Caddy module that requires AT Protocol (atproto) authentication and restricts routes to allowed DIDs
1package caddyatprotoauth
2
3import (
4 "encoding/json"
5 "fmt"
6 "net"
7 "net/http"
8 "net/url"
9 "strings"
10 "sync"
11
12 "github.com/caddyserver/caddy/v2"
13 "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
14 "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
15 "github.com/caddyserver/caddy/v2/modules/caddyhttp"
16 "go.uber.org/zap"
17 "tangled.org/vvill.dev/caddy-atproto-auth/internal/oauth"
18 "tangled.org/vvill.dev/caddy-atproto-auth/internal/session"
19 "tangled.org/vvill.dev/caddy-atproto-auth/internal/ui"
20)
21
22func init() {
23 caddy.RegisterModule(&Portal{})
24 httpcaddyfile.RegisterHandlerDirective("atproto_portal", parseCaddyfilePortal)
25}
26
27// Portal is the centralized authentication portal for Path B (Auth Hub).
28type Portal struct {
29 Name string `json:"name,omitempty"`
30 LoginTemplatePath string `json:"login_template,omitempty"`
31 ForbiddenTemplatePath string `json:"forbidden_template,omitempty"`
32 CookieName string `json:"cookie_name,omitempty"`
33 CookieDomain string `json:"cookie_domain,omitempty"`
34 AllowedRedirectDomains []string `json:"allowed_redirect_domains,omitempty"`
35 LogoutRedirectURL string `json:"logout_redirect_url,omitempty"`
36
37 // Paths configuration
38 PathPrefix string `json:"path_prefix,omitempty"`
39
40 // Dependencies
41 app *App
42 sessions *session.Manager
43 oauthManagers map[string]*oauth.Manager
44 oauthMu sync.RWMutex
45 renderer *ui.Renderer
46 logger *zap.Logger
47}
48
49// CaddyModule returns the Caddy module information.
50func (*Portal) CaddyModule() caddy.ModuleInfo {
51 return caddy.ModuleInfo{
52 ID: "http.handlers.atproto_portal",
53 New: func() caddy.Module { return new(Portal) },
54 }
55}
56
57// Provision sets up the module.
58func (p *Portal) Provision(ctx caddy.Context) error {
59 p.logger = ctx.Logger()
60
61 // Get Global App
62 app, err := ctx.App("atproto")
63 if err != nil {
64 return fmt.Errorf("getting atproto app: %w", err)
65 }
66 p.app = app.(*App)
67
68 // Initialize Session Manager from global app secret
69 p.sessions = session.NewManager(p.app.CookieSecret, p.CookieName, p.CookieDomain)
70
71 // Initialize UI Renderer
72 renderer, err := ui.NewRenderer(ui.Config{
73 LoginTemplatePath: p.LoginTemplatePath,
74 ForbiddenTemplatePath: p.ForbiddenTemplatePath,
75 })
76 if err != nil {
77 return fmt.Errorf("failed to init ui renderer: %w", err)
78 }
79 p.renderer = renderer
80
81 p.oauthManagers = make(map[string]*oauth.Manager)
82
83 // Normalize PathPrefix (ensure it starts with / and doesn't end with /)
84 if p.PathPrefix != "" {
85 if !strings.HasPrefix(p.PathPrefix, "/") {
86 p.PathPrefix = "/" + p.PathPrefix
87 }
88 p.PathPrefix = strings.TrimSuffix(p.PathPrefix, "/")
89 }
90
91 return nil
92}
93
94// getOAuthManager gets or initializes the OAuth manager for a specific host.
95func (p *Portal) getOAuthManager(r *http.Request) (*oauth.Manager, error) {
96 host := getRequestHost(r)
97
98 p.oauthMu.RLock()
99 mgr, exists := p.oauthManagers[host]
100 p.oauthMu.RUnlock()
101
102 if exists {
103 return mgr, nil
104 }
105
106 p.oauthMu.Lock()
107 defer p.oauthMu.Unlock()
108
109 // Double-check after acquiring write lock
110 if mgr, exists := p.oauthManagers[host]; exists {
111 return mgr, nil
112 }
113
114 if len(p.oauthManagers) >= p.app.OAuthManagerCacheSize {
115 // Prevent DoS from unbounded map growth
116 p.logger.Warn("oauth managers cache full, clearing to prevent OOM")
117 p.oauthManagers = make(map[string]*oauth.Manager)
118 }
119
120 scheme := getRequestScheme(r)
121
122 clientID := fmt.Sprintf("%s://%s/.well-known/oauth-client-metadata.json", scheme, host)
123 callbackURL := fmt.Sprintf("%s://%s%s/callback", scheme, host, p.PathPrefix)
124
125 mgr, err := oauth.NewManager(p.app.Store, clientID, callbackURL, p.app.AllowPrivateCIDRs)
126 if err != nil {
127 return nil, err
128 }
129
130 p.oauthManagers[host] = mgr
131 return mgr, nil
132}
133
134// Validate checks that the configuration is valid.
135func (p *Portal) Validate() error {
136 if p.Name == "" {
137 p.Name = "Authentication Portal"
138 }
139 return nil
140}
141
142// UnmarshalCaddyfile implements caddyfile.Unmarshaler.
143func (p *Portal) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
144 for d.Next() {
145 for nesting := d.Nesting(); d.NextBlock(nesting); {
146 switch d.Val() {
147 case "name":
148 if !d.NextArg() {
149 return d.ArgErr()
150 }
151 p.Name = d.Val()
152 case "login_template":
153 if !d.NextArg() {
154 return d.ArgErr()
155 }
156 p.LoginTemplatePath = d.Val()
157 case "forbidden_template":
158 if !d.NextArg() {
159 return d.ArgErr()
160 }
161 p.ForbiddenTemplatePath = d.Val()
162 case "cookie_name":
163 if !d.NextArg() {
164 return d.ArgErr()
165 }
166 p.CookieName = d.Val()
167 case "cookie_domain":
168 if !d.NextArg() {
169 return d.ArgErr()
170 }
171 p.CookieDomain = d.Val()
172 case "allowed_redirect_domains":
173 p.AllowedRedirectDomains = append(p.AllowedRedirectDomains, d.RemainingArgs()...)
174 case "logout_redirect_url":
175 if !d.NextArg() {
176 return d.ArgErr()
177 }
178 p.LogoutRedirectURL = d.Val()
179 case "path_prefix":
180 if !d.NextArg() {
181 return d.ArgErr()
182 }
183 p.PathPrefix = d.Val()
184 default:
185 return d.Errf("unrecognized subdirective '%s'", d.Val())
186 }
187 }
188 }
189 return nil
190}
191
192func parseCaddyfilePortal(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
193 var p Portal
194 err := p.UnmarshalCaddyfile(h.Dispenser)
195 return &p, err
196}
197
198func (p *Portal) setSecurityHeaders(w http.ResponseWriter) {
199 w.Header().Set("Content-Security-Policy", "default-src 'self'; style-src 'self' 'unsafe-inline'")
200 w.Header().Set("X-Content-Type-Options", "nosniff")
201 w.Header().Set("X-Frame-Options", "DENY")
202}
203
204// ServeHTTP implements caddyhttp.MiddlewareHandler.
205func (p *Portal) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
206 oauthMgr, err := p.getOAuthManager(r)
207 if err != nil {
208 p.logger.Error("failed to get oauth manager", zap.Error(err))
209 http.Error(w, "Internal Server Error", http.StatusInternalServerError)
210 return nil
211 }
212
213 // 1. Metadata Endpoint
214 if r.URL.Path == "/.well-known/oauth-client-metadata.json" {
215 meta, err := oauthMgr.GetClientMetadata()
216 if err != nil {
217 p.logger.Error("failed to get client metadata", zap.Error(err))
218 http.Error(w, "Internal Server Error", http.StatusInternalServerError)
219 return nil
220 }
221 w.Header().Set("Content-Type", "application/json")
222 json.NewEncoder(w).Encode(meta)
223 return nil
224 }
225
226 // 2. Callback Endpoint
227 callbackPath := p.PathPrefix + "/callback"
228 if r.URL.Path == callbackPath {
229 // Process callback
230 ctx := r.Context()
231 query := r.URL.Query()
232
233 sessionData, handle, err := oauthMgr.ProcessCallback(ctx, query)
234 if err != nil {
235 p.logger.Error("oauth callback failed", zap.Error(err))
236 p.setSecurityHeaders(w)
237 w.Header().Set("Content-Type", "text/html; charset=utf-8")
238 w.WriteHeader(http.StatusBadRequest)
239 _ = p.renderer.RenderLogin(w, ui.LoginData{
240 AppName: p.Name,
241 Error: "Authentication failed. Please try again.",
242 LoginURL: p.PathPrefix + "/login",
243 })
244 return nil
245 }
246
247 reqDomain := getRequestHost(r)
248
249 // Create Session Cookie
250 cookie, err := p.sessions.CreateCookie(
251 sessionData.AccountDID,
252 handle,
253 sessionData.SessionID,
254 p.app.SessionDuration,
255 )
256 if err != nil {
257 p.logger.Error("failed to create session cookie", zap.Error(err))
258 http.Error(w, "Internal Error", http.StatusInternalServerError)
259 return nil
260 }
261
262 http.SetCookie(w, cookie)
263
264 appNameForLog := p.Name
265 if appNameForLog == "" {
266 appNameForLog = reqDomain
267 }
268 p.logger.Info(fmt.Sprintf("@%s (did: %s) has logged in for %s", handle, sessionData.AccountDID.String(), appNameForLog))
269
270 // Check for redirect_to cookie
271 redirectTo := "/"
272 state := r.URL.Query().Get("state")
273 cookieName := fmt.Sprintf("atproto_redirect_to_%s", state)
274 if redirectCookie, err := r.Cookie(cookieName); err == nil && redirectCookie.Value != "" {
275 redirectTo = redirectCookie.Value
276 // Clear cookie
277 http.SetCookie(w, &http.Cookie{
278 Name: cookieName,
279 Value: "",
280 Path: "/",
281 MaxAge: -1,
282 HttpOnly: true,
283 Secure: true,
284 })
285
286 // Basic open redirect mitigation: ensure it's a relative path or matches CookieDomain/Domain
287 if strings.HasPrefix(redirectTo, "http://") || strings.HasPrefix(redirectTo, "https://") {
288 parsed, err := url.Parse(redirectTo)
289 // Allow redirect if host is our exact domain, or if cookie domain is a parent of the host
290 isAllowedDomain := false
291 if err == nil {
292 h := parsed.Host
293 // Strip port if present
294 hostNoPort, _, hostErr := net.SplitHostPort(h)
295 if hostErr != nil {
296 hostNoPort = h
297 }
298
299 if hostNoPort == reqDomain {
300 isAllowedDomain = true
301 } else {
302 isAllowedDomain = checkAllowedDomain(hostNoPort, p.AllowedRedirectDomains)
303 }
304 }
305
306 if !isAllowedDomain {
307 p.logger.Warn("blocked cross-domain redirect", zap.String("url", redirectTo))
308 redirectTo = "/" // Fallback to home if invalid or not matching domain
309 }
310 }
311 }
312
313 // Redirect to home or saved location
314 http.Redirect(w, r, redirectTo, http.StatusFound)
315 return nil
316 }
317
318 // 3. Login Start (Form Action)
319 loginPath := p.PathPrefix + "/login"
320 logoutPath := p.PathPrefix + "/logout"
321 forbiddenPath := p.PathPrefix + "/forbidden"
322
323 if r.URL.Path == loginPath && r.Method == "POST" {
324 handle := r.FormValue("handle")
325 // Strip leading @ if present
326 if len(handle) > 0 && handle[0] == '@' {
327 handle = handle[1:]
328 }
329
330 if handle == "" {
331 http.Error(w, "Handle required", http.StatusBadRequest)
332 return nil
333 }
334
335 // Start Auth Flow
336 redirectURI, err := oauthMgr.StartAuthFlow(r.Context(), handle)
337 if err != nil {
338 p.logger.Error("failed to start auth flow", zap.Error(err), zap.String("handle", handle))
339 // Render error on login page
340 p.setSecurityHeaders(w)
341 w.Header().Set("Content-Type", "text/html; charset=utf-8")
342 w.WriteHeader(http.StatusBadRequest)
343 if renderErr := p.renderer.RenderLogin(w, ui.LoginData{
344 AppName: p.Name,
345 Redirect: r.FormValue("redirect_to"),
346 Error: "Authentication failed. Please try again.",
347 LoginURL: loginPath,
348 }); renderErr != nil {
349 p.logger.Error("failed to render login error", zap.Error(renderErr))
350 }
351 return nil
352 }
353
354 if redirectTo := r.FormValue("redirect_to"); redirectTo != "" {
355 u, _ := url.Parse(redirectURI)
356 state := u.Query().Get("state")
357 http.SetCookie(w, &http.Cookie{
358 Name: fmt.Sprintf("atproto_redirect_to_%s", state),
359 Value: redirectTo,
360 Path: "/",
361 MaxAge: 300,
362 HttpOnly: true,
363 Secure: true,
364 SameSite: http.SameSiteLaxMode,
365 })
366 }
367
368 http.Redirect(w, r, redirectURI, http.StatusFound)
369 return nil
370 }
371
372 // 4. Default: Login Page
373 if r.URL.Path == loginPath {
374 if sess, err := p.sessions.VerifyCookie(r); err == nil && sess != nil {
375 // Already logged in
376 redirectTo := r.URL.Query().Get("redirect_to")
377 if redirectTo == "" {
378 redirectTo = "/"
379 }
380 http.Redirect(w, r, redirectTo, http.StatusFound)
381 return nil
382 }
383
384 p.setSecurityHeaders(w)
385 w.Header().Set("Content-Type", "text/html; charset=utf-8")
386 if err := p.renderer.RenderLogin(w, ui.LoginData{
387 AppName: p.Name,
388 Redirect: r.URL.Query().Get("redirect_to"),
389 LoginURL: loginPath,
390 }); err != nil {
391 p.logger.Error("failed to render login page", zap.Error(err))
392 return caddyhttp.Error(http.StatusInternalServerError, err)
393 }
394 return nil
395 }
396
397 // 5. Logout
398 if r.URL.Path == logoutPath {
399 // Invalidate credential if session exists
400 sess, err := p.sessions.VerifyCookie(r)
401
402 reqDomain := getRequestHost(r)
403
404 if err == nil || err == session.ErrExpired {
405 appNameForLog := p.Name
406 if appNameForLog == "" {
407 appNameForLog = reqDomain
408 }
409 p.logger.Info(fmt.Sprintf("@%s (did: %s) has logged out for %s", sess.Handle, sess.DID, appNameForLog))
410
411 if err := oauthMgr.Logout(r.Context(), sess.DID, sess.SessionID); err != nil {
412 p.logger.Error("failed to revoke session during logout", zap.Error(err))
413 }
414 }
415
416 http.SetCookie(w, p.sessions.ClearCookie())
417
418 // Handle redirect_to for logout
419 redirectTo := r.URL.Query().Get("redirect_to")
420 if redirectTo == "" {
421 if p.LogoutRedirectURL != "" {
422 redirectTo = p.LogoutRedirectURL
423 } else {
424 redirectTo = loginPath
425 }
426 } else {
427 // Basic open redirect mitigation: ensure it's a relative path or matches CookieDomain/Domain
428 if strings.HasPrefix(redirectTo, "http://") || strings.HasPrefix(redirectTo, "https://") {
429 parsed, err := url.Parse(redirectTo)
430 isAllowedDomain := false
431 if err == nil {
432 h := parsed.Host
433 hostNoPort, _, hostErr := net.SplitHostPort(h)
434 if hostErr != nil {
435 hostNoPort = h
436 }
437
438 if hostNoPort == reqDomain {
439 isAllowedDomain = true
440 } else {
441 isAllowedDomain = checkAllowedDomain(hostNoPort, p.AllowedRedirectDomains)
442 }
443 }
444
445 if !isAllowedDomain {
446 p.logger.Warn("blocked cross-domain redirect on logout", zap.String("url", redirectTo))
447 if p.LogoutRedirectURL != "" {
448 redirectTo = p.LogoutRedirectURL
449 } else {
450 redirectTo = loginPath // Fallback to login page
451 }
452 }
453 }
454 }
455
456 http.Redirect(w, r, redirectTo, http.StatusFound)
457 return nil
458 }
459
460 // 6. Forbidden Page
461 if r.URL.Path == forbiddenPath {
462 p.setSecurityHeaders(w)
463 w.Header().Set("Content-Type", "text/html; charset=utf-8")
464 w.WriteHeader(http.StatusForbidden)
465
466 var did, handle string
467 sess, err := p.sessions.VerifyCookie(r)
468 if err == nil || err == session.ErrExpired {
469 did = sess.DID
470 handle = sess.Handle
471 }
472
473 if err := p.renderer.RenderForbidden(w, ui.ForbiddenData{
474 AppName: p.Name,
475 DID: did,
476 Handle: handle,
477 LogoutURL: logoutPath,
478 }); err != nil {
479 p.logger.Error("failed to render forbidden page", zap.Error(err))
480 }
481 return nil
482 }
483
484 return next.ServeHTTP(w, r)
485}
486
487// Interface guards
488var (
489 _ caddy.Provisioner = (*Portal)(nil)
490 _ caddy.Validator = (*Portal)(nil)
491 _ caddyhttp.MiddlewareHandler = (*Portal)(nil)
492 _ caddyfile.Unmarshaler = (*Portal)(nil)
493)