Caddy module that requires AT Protocol (atproto) authentication and restricts routes to allowed DIDs
1package caddyatprotoauth
2
3import (
4 "context"
5 "fmt"
6 "net/http"
7 "net/url"
8 "strings"
9 "time"
10
11 "github.com/caddyserver/caddy/v2"
12 "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
13 "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
14 "github.com/caddyserver/caddy/v2/modules/caddyhttp"
15 "go.uber.org/zap"
16 "tangled.org/vvill.dev/caddy-atproto-auth/internal/oauth"
17 "tangled.org/vvill.dev/caddy-atproto-auth/internal/resolver"
18 "tangled.org/vvill.dev/caddy-atproto-auth/internal/session"
19 "tangled.org/vvill.dev/caddy-atproto-auth/internal/ui"
20)
21
22func init() {
23 caddy.RegisterModule(Gate{})
24 httpcaddyfile.RegisterHandlerDirective("atproto_gate", parseCaddyfileGate)
25}
26
27// Gate acts as a middleware that guards endpoints
28// and validates the session cookie.
29type Gate struct {
30 Allow []string `json:"allow,omitempty"`
31 ClientID string `json:"client_id,omitempty"` // ClientID for session refreshing (e.g. https://example.com/client-metadata.json)
32 PortalURL string `json:"portal_url,omitempty"` // URL of the auth portal (e.g. http://localhost:8080 or /)
33 UI ui.Config `json:"ui,omitempty"` // Custom UI configuration
34
35 // Dependencies
36 app *App
37 sessions *session.Manager
38 oauth *oauth.Manager
39 renderer *ui.Renderer
40 logger *zap.Logger
41 resolvedDIDs []string
42}
43
44// CaddyModule returns the Caddy module information.
45func (Gate) CaddyModule() caddy.ModuleInfo {
46 return caddy.ModuleInfo{
47 ID: "http.handlers.atproto_gate",
48 New: func() caddy.Module { return new(Gate) },
49 }
50}
51
52// Provision sets up the module.
53func (g *Gate) Provision(ctx caddy.Context) error {
54 g.logger = ctx.Logger()
55
56 // 1. Get Global App
57 app, err := ctx.App("atproto")
58 if err != nil {
59 return fmt.Errorf("getting atproto app: %w", err)
60 }
61 g.app = app.(*App)
62
63 // 2. Initialize Session Manager (using global secret)
64 g.sessions = g.app.SessionManager
65
66 // 4. Initialize UI Renderer
67 renderer, err := ui.NewRenderer(g.UI)
68 if err != nil {
69 return fmt.Errorf("failed to init ui renderer: %w", err)
70 }
71 g.renderer = renderer
72
73 // 5. Initialize OAuth Manager (if client_id set for refresh)
74 if g.ClientID != "" {
75 // We don't strictly need callbackURL for refresh, but we pass empty string.
76 // If Manager needs it, we might need to add it to config.
77 mgr, err := oauth.NewManager(g.app.Store, g.ClientID, "")
78 if err != nil {
79 return fmt.Errorf("failed to init oauth manager for refresh: %w", err)
80 }
81 g.oauth = mgr
82 }
83
84 // Default PortalURL if empty?
85 // If empty, we can't really redirect anywhere meaningful unless we assume /login.
86 if g.PortalURL == "" {
87 g.PortalURL = "/"
88 }
89
90 // 6. Pre-resolve allowed handles to DIDs
91 // We need a resolver for this
92 resolverInstance := resolver.New()
93
94 g.resolvedDIDs = make([]string, 0, len(g.Allow))
95 ctxResolver := context.Background() // Use background context for boot-time resolution
96 for _, allow := range g.Allow {
97 if allow == "*" {
98 g.resolvedDIDs = append(g.resolvedDIDs, "*")
99 continue
100 }
101
102 // If it's already a DID, append it directly
103 if strings.HasPrefix(allow, "did:") {
104 g.resolvedDIDs = append(g.resolvedDIDs, allow)
105 continue
106 }
107
108 // Treat as handle and resolve
109 did, err := resolverInstance.ResolveIdentifier(ctxResolver, allow)
110 if err != nil {
111 g.logger.Warn("failed to resolve handle during provision", zap.String("handle", allow), zap.Error(err))
112 } else {
113 g.resolvedDIDs = append(g.resolvedDIDs, did)
114 }
115 }
116
117 return nil
118}
119
120// Validate checks that the configuration is valid.
121func (g *Gate) Validate() error {
122 return nil
123}
124
125// UnmarshalCaddyfile implements caddyfile.Unmarshaler.
126func (g *Gate) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
127 for d.Next() {
128 for nesting := d.Nesting(); d.NextBlock(nesting); {
129 switch d.Val() {
130 case "allow":
131 g.Allow = append(g.Allow, d.RemainingArgs()...)
132 case "client_id":
133 if !d.NextArg() {
134 return d.ArgErr()
135 }
136 g.ClientID = d.Val()
137 case "portal_url":
138 if !d.NextArg() {
139 return d.ArgErr()
140 }
141 g.PortalURL = d.Val()
142 case "ui":
143 for nesting := d.Nesting(); d.NextBlock(nesting); {
144 switch d.Val() {
145 case "login_template":
146 if !d.NextArg() {
147 return d.ArgErr()
148 }
149 g.UI.LoginTemplatePath = d.Val()
150 case "forbidden_template":
151 if !d.NextArg() {
152 return d.ArgErr()
153 }
154 g.UI.ForbiddenTemplatePath = d.Val()
155 default:
156 return d.Errf("unrecognized subdirective '%s'", d.Val())
157 }
158 }
159 default:
160 return d.Errf("unrecognized subdirective '%s'", d.Val())
161 }
162 }
163 }
164 return nil
165}
166
167// parseCaddyfileGate parses the atproto_gate directive from a Caddyfile.
168func parseCaddyfileGate(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
169 var g Gate
170 err := g.UnmarshalCaddyfile(h.Dispenser)
171 return &g, err
172}
173
174// ServeHTTP implements caddyhttp.MiddlewareHandler.
175func (g *Gate) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
176 if r.URL.Path == "/logout" && g.PortalURL != "" {
177 scheme := "https"
178 if r.TLS == nil && r.Header.Get("X-Forwarded-Proto") != "https" {
179 scheme = "http"
180 }
181 host := r.Host
182 currentURL := fmt.Sprintf("%s://%s", scheme, host)
183
184 // Ensure PortalURL doesn't end with /
185 portalURL := g.PortalURL
186 if portalURL == "/" {
187 portalURL = ""
188 } else if len(portalURL) > 0 && portalURL[len(portalURL)-1] == '/' {
189 portalURL = portalURL[:len(portalURL)-1]
190 }
191
192 // Also perform local credential invalidation if possible (composite mode)
193 sess, err := g.sessions.VerifyCookie(r)
194 if err == nil || err == session.ErrExpired {
195 if g.oauth != nil {
196 if err := g.oauth.Logout(r.Context(), sess.DID, sess.SessionID); err != nil {
197 g.logger.Error("failed to revoke session during local logout", zap.Error(err))
198 }
199 }
200 }
201
202 // Clear local session cookie
203 http.SetCookie(w, g.sessions.ClearCookie(strings.Split(host, ":")[0]))
204
205 portalLogout := fmt.Sprintf("%s/logout?redirect_to=%s", portalURL, url.QueryEscape(currentURL))
206 http.Redirect(w, r, portalLogout, http.StatusFound)
207 return nil
208 }
209
210 // 1. Verify stateless cookie here
211 sess, err := g.sessions.VerifyCookie(r)
212 if err == session.ErrExpired {
213 // Attempt transparent refresh if we are in a mode that supports it.
214 // We need an OAuth manager to refresh.
215 // If ClientID is set, g.oauth is set.
216
217 if g.oauth != nil && sess != nil {
218 clientSession, err := g.oauth.ResumeSession(r.Context(), sess.DID, sess.SessionID)
219 if err == nil {
220 // Refresh tokens
221 if _, err := clientSession.RefreshTokens(r.Context()); err == nil {
222 // Success! Update cookie.
223 // We need to extend expiration.
224 // Handle lookup might be needed if not in session?
225 // Sess has Handle.
226 cookie, err := g.sessions.CreateCookie(
227 clientSession.Data.AccountDID,
228 sess.Handle, // Keep handle from old cookie
229 clientSession.Data.SessionID,
230 24*7*time.Hour,
231 strings.Split(r.Host, ":")[0],
232 )
233 if err == nil {
234 http.SetCookie(w, cookie)
235 r.AddCookie(cookie)
236 // Proceed as authorized
237 r.Header.Set("X-Atproto-Did", sess.DID)
238 r.Header.Set("X-Atproto-Handle", sess.Handle)
239 return next.ServeHTTP(w, r)
240 }
241 }
242 }
243 // If refresh failed, fall through to re-login logic
244 }
245 } else if err == nil {
246 // Session valid!
247 // Check authorization against allowlist
248 allowed := false
249 for _, allow := range g.resolvedDIDs {
250 if allow == "*" || allow == sess.DID {
251 allowed = true
252 break
253 }
254 }
255
256 if allowed {
257 // Inject headers
258 r.Header.Set("X-Atproto-Did", sess.DID)
259 r.Header.Set("X-Atproto-Handle", sess.Handle)
260 return next.ServeHTTP(w, r)
261 }
262
263 // Authenticated but not authorized
264 w.Header().Set("Content-Type", "text/html; charset=utf-8")
265 w.WriteHeader(http.StatusForbidden)
266 if err := g.renderer.RenderForbidden(w, ui.ForbiddenData{
267 AppName: "Gate", // We don't have Domain/AppName anymore, maybe use Host?
268 DID: sess.DID,
269 Handle: sess.Handle,
270 }); err != nil {
271 g.logger.Error("failed to render forbidden page", zap.Error(err))
272 }
273 return nil
274 }
275
276 // 2. If invalid/missing, initiate redirect to Portal
277 if g.PortalURL != "" {
278 // Construct redirect URL: ${PortalURL}/login?redirect_to=${CurrentURL}
279 scheme := "https"
280 if r.TLS == nil && r.Header.Get("X-Forwarded-Proto") != "https" {
281 scheme = "http"
282 }
283 host := r.Host
284 currentURL := fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI())
285
286 // Ensure PortalURL doesn't end with / if we append /login
287 portalURL := g.PortalURL
288 if portalURL == "/" {
289 portalURL = ""
290 } else if len(portalURL) > 0 && portalURL[len(portalURL)-1] == '/' {
291 portalURL = portalURL[:len(portalURL)-1]
292 }
293
294 portalLogin := fmt.Sprintf("%s/login?redirect_to=%s", portalURL, url.QueryEscape(currentURL))
295 http.Redirect(w, r, portalLogin, http.StatusFound)
296 return nil
297 }
298
299 // Fallback: 401
300 return caddyhttp.Error(http.StatusUnauthorized, fmt.Errorf("unauthorized"))
301}
302
303// Interface guards
304var (
305 _ caddy.Provisioner = (*Gate)(nil)
306 _ caddy.Validator = (*Gate)(nil)
307 _ caddyhttp.MiddlewareHandler = (*Gate)(nil)
308 _ caddyfile.Unmarshaler = (*Gate)(nil)
309)