Caddy module to require at-proto authentication and restrict routes to DIDs
2

Configure Feed

Select the types of activity you want to include in your feed.

AI: Resolve various bugs to do with redirects and syntax

+428 -230
+151 -123
gate.go
··· 6 6 "net/http" 7 7 "net/url" 8 8 "strings" 9 - "time" 9 + "sync" 10 10 11 11 "github.com/caddyserver/caddy/v2" 12 12 "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" ··· 16 16 "tangled.org/vvill.dev/caddy-atproto-auth/internal/oauth" 17 17 "tangled.org/vvill.dev/caddy-atproto-auth/internal/resolver" 18 18 "tangled.org/vvill.dev/caddy-atproto-auth/internal/session" 19 - "tangled.org/vvill.dev/caddy-atproto-auth/internal/ui" 20 19 ) 21 20 22 21 func init() { 23 - caddy.RegisterModule(Gate{}) 22 + caddy.RegisterModule(&Gate{}) 24 23 httpcaddyfile.RegisterHandlerDirective("atproto_gate", parseCaddyfileGate) 25 24 } 26 25 27 26 // Gate acts as a middleware that guards endpoints 28 27 // and validates the session cookie. 29 28 type Gate struct { 30 - Allow []string `json:"allow,omitempty"` 31 - ClientID string `json:"client_id,omitempty"` // ClientID for session refreshing (e.g. https://example.com/client-metadata.json) 32 - PortalURL string `json:"portal_url,omitempty"` // URL of the auth portal (e.g. http://localhost:8080 or /) 33 - UI ui.Config `json:"ui,omitempty"` // Custom UI configuration 29 + Allow []string `json:"allow,omitempty"` 30 + PortalURL string `json:"portal_url,omitempty"` // URL of the auth portal (e.g. https://auth.example.com or /auth) 31 + CookieName string `json:"cookie_name,omitempty"` 32 + CookieDomain string `json:"cookie_domain,omitempty"` 33 + ResolveHandlesOnRequest bool `json:"resolve_handles_on_request,omitempty"` 34 34 35 35 // Dependencies 36 - app *App 37 - sessions *session.Manager 38 - oauth *oauth.Manager 39 - renderer *ui.Renderer 40 - logger *zap.Logger 41 - resolvedDIDs []string 36 + app *App 37 + sessions *session.Manager 38 + oauthManagers map[string]*oauth.Manager 39 + oauthMu sync.RWMutex 40 + logger *zap.Logger 41 + resolvedDIDs []string 42 + handleCache sync.Map 43 + resolver *resolver.Resolver 42 44 } 43 45 44 46 // CaddyModule returns the Caddy module information. 45 - func (Gate) CaddyModule() caddy.ModuleInfo { 47 + func (*Gate) CaddyModule() caddy.ModuleInfo { 46 48 return caddy.ModuleInfo{ 47 49 ID: "http.handlers.atproto_gate", 48 50 New: func() caddy.Module { return new(Gate) }, ··· 61 63 g.app = app.(*App) 62 64 63 65 // 2. Initialize Session Manager (using global secret) 64 - g.sessions = g.app.SessionManager 66 + g.sessions = session.NewManager(g.app.CookieSecret, g.CookieName, g.CookieDomain) 67 + 68 + g.oauthManagers = make(map[string]*oauth.Manager) 65 69 66 - // 4. Initialize UI Renderer 67 - renderer, err := ui.NewRenderer(g.UI) 68 - if err != nil { 69 - return fmt.Errorf("failed to init ui renderer: %w", err) 70 + // Normalize PortalURL (ensure it doesn't end with /) 71 + if g.PortalURL == "" { 72 + g.PortalURL = "/" 73 + } else if len(g.PortalURL) > 0 && g.PortalURL[len(g.PortalURL)-1] == '/' && g.PortalURL != "/" { 74 + g.PortalURL = g.PortalURL[:len(g.PortalURL)-1] 70 75 } 71 - g.renderer = renderer 72 76 73 - // 5. Initialize OAuth Manager (if client_id set for refresh) 74 - if g.ClientID != "" { 75 - // We don't strictly need callbackURL for refresh, but we pass empty string. 76 - // If Manager needs it, we might need to add it to config. 77 - mgr, err := oauth.NewManager(g.app.Store, g.ClientID, "") 78 - if err != nil { 79 - return fmt.Errorf("failed to init oauth manager for refresh: %w", err) 77 + // 5. Initialize OAuth Manager for transparent refresh 78 + // We derive the ClientID from the PortalURL if it's absolute, 79 + // or from the Host header at request time if it's relative. 80 + // For now, if it's absolute, we can init the OAuth manager immediately. 81 + if strings.HasPrefix(g.PortalURL, "http://") || strings.HasPrefix(g.PortalURL, "https://") { 82 + parsedURL, err := url.Parse(g.PortalURL) 83 + if err == nil { 84 + clientID := fmt.Sprintf("%s://%s/.well-known/oauth-client-metadata.json", parsedURL.Scheme, parsedURL.Host) 85 + mgr, err := oauth.NewManager(g.app.Store, clientID, "") 86 + if err != nil { 87 + return fmt.Errorf("failed to init oauth manager for refresh: %w", err) 88 + } 89 + g.oauthManagers[parsedURL.Host] = mgr 80 90 } 81 - g.oauth = mgr 82 - } 83 - 84 - // Default PortalURL if empty? 85 - // If empty, we can't really redirect anywhere meaningful unless we assume /login. 86 - if g.PortalURL == "" { 87 - g.PortalURL = "/" 88 91 } 89 92 90 93 // 6. Pre-resolve allowed handles to DIDs 91 - // We need a resolver for this 92 - resolverInstance := resolver.New() 94 + g.resolver = resolver.New() 93 95 94 96 g.resolvedDIDs = make([]string, 0, len(g.Allow)) 95 97 ctxResolver := context.Background() // Use background context for boot-time resolution ··· 106 108 } 107 109 108 110 // Treat as handle and resolve 109 - did, err := resolverInstance.ResolveIdentifier(ctxResolver, allow) 111 + did, err := g.resolver.ResolveIdentifier(ctxResolver, allow) 110 112 if err != nil { 111 113 g.logger.Warn("failed to resolve handle during provision", zap.String("handle", allow), zap.Error(err)) 112 114 } else { 113 115 g.resolvedDIDs = append(g.resolvedDIDs, did) 116 + g.handleCache.Store(allow, did) 114 117 } 115 118 } 116 119 ··· 129 132 switch d.Val() { 130 133 case "allow": 131 134 g.Allow = append(g.Allow, d.RemainingArgs()...) 132 - case "client_id": 135 + case "cookie_name": 133 136 if !d.NextArg() { 134 137 return d.ArgErr() 135 138 } 136 - g.ClientID = d.Val() 139 + g.CookieName = d.Val() 140 + case "cookie_domain": 141 + if !d.NextArg() { 142 + return d.ArgErr() 143 + } 144 + g.CookieDomain = d.Val() 137 145 case "portal_url": 138 146 if !d.NextArg() { 139 147 return d.ArgErr() 140 148 } 141 149 g.PortalURL = d.Val() 142 - case "ui": 143 - for nesting := d.Nesting(); d.NextBlock(nesting); { 144 - switch d.Val() { 145 - case "login_template": 146 - if !d.NextArg() { 147 - return d.ArgErr() 148 - } 149 - g.UI.LoginTemplatePath = d.Val() 150 - case "forbidden_template": 151 - if !d.NextArg() { 152 - return d.ArgErr() 153 - } 154 - g.UI.ForbiddenTemplatePath = d.Val() 155 - default: 156 - return d.Errf("unrecognized subdirective '%s'", d.Val()) 157 - } 158 - } 150 + case "resolve_handles_on_request": 151 + g.ResolveHandlesOnRequest = true 159 152 default: 160 153 return d.Errf("unrecognized subdirective '%s'", d.Val()) 161 154 } ··· 171 164 return &g, err 172 165 } 173 166 174 - // ServeHTTP implements caddyhttp.MiddlewareHandler. 175 - func (g *Gate) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { 176 - if r.URL.Path == "/logout" && g.PortalURL != "" { 177 - scheme := "https" 178 - if r.TLS == nil && r.Header.Get("X-Forwarded-Proto") != "https" { 179 - scheme = "http" 180 - } 181 - host := r.Host 182 - currentURL := fmt.Sprintf("%s://%s", scheme, host) 167 + // getOAuthManager gets or initializes the OAuth manager for a specific host. 168 + func (g *Gate) getOAuthManager(r *http.Request) (*oauth.Manager, error) { 169 + host := getRequestHost(r) 183 170 184 - // Ensure PortalURL doesn't end with / 185 - portalURL := g.PortalURL 186 - if portalURL == "/" { 187 - portalURL = "" 188 - } else if len(portalURL) > 0 && portalURL[len(portalURL)-1] == '/' { 189 - portalURL = portalURL[:len(portalURL)-1] 171 + // If PortalURL is absolute, we already cached it under parsedURL.Host 172 + if strings.HasPrefix(g.PortalURL, "http://") || strings.HasPrefix(g.PortalURL, "https://") { 173 + parsedURL, err := url.Parse(g.PortalURL) 174 + if err == nil { 175 + host = parsedURL.Host 190 176 } 177 + } 191 178 192 - // Also perform local credential invalidation if possible (composite mode) 193 - sess, err := g.sessions.VerifyCookie(r) 194 - if err == nil || err == session.ErrExpired { 195 - if g.oauth != nil { 196 - if err := g.oauth.Logout(r.Context(), sess.DID, sess.SessionID); err != nil { 197 - g.logger.Error("failed to revoke session during local logout", zap.Error(err)) 198 - } 199 - } 200 - } 179 + g.oauthMu.RLock() 180 + mgr, exists := g.oauthManagers[host] 181 + g.oauthMu.RUnlock() 182 + 183 + if exists { 184 + return mgr, nil 185 + } 186 + 187 + g.oauthMu.Lock() 188 + defer g.oauthMu.Unlock() 189 + 190 + if mgr, exists := g.oauthManagers[host]; exists { 191 + return mgr, nil 192 + } 193 + 194 + if len(g.oauthManagers) >= g.app.OAuthManagerCacheSize { 195 + // Prevent DoS from unbounded map growth 196 + g.logger.Warn("oauth managers cache full, clearing to prevent OOM") 197 + g.oauthManagers = make(map[string]*oauth.Manager) 198 + } 201 199 202 - // Clear local session cookie 203 - http.SetCookie(w, g.sessions.ClearCookie(strings.Split(host, ":")[0])) 200 + scheme := getRequestScheme(r) 204 201 205 - portalLogout := fmt.Sprintf("%s/logout?redirect_to=%s", portalURL, url.QueryEscape(currentURL)) 206 - http.Redirect(w, r, portalLogout, http.StatusFound) 207 - return nil 202 + clientID := fmt.Sprintf("%s://%s/.well-known/oauth-client-metadata.json", scheme, host) 203 + mgr, err := oauth.NewManager(g.app.Store, clientID, "") 204 + if err != nil { 205 + return nil, err 208 206 } 209 207 208 + g.oauthManagers[host] = mgr 209 + return mgr, nil 210 + } 211 + 212 + // ServeHTTP implements caddyhttp.MiddlewareHandler. 213 + func (g *Gate) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { 210 214 // 1. Verify stateless cookie here 211 215 sess, err := g.sessions.VerifyCookie(r) 212 216 if err == session.ErrExpired { 213 217 // Attempt transparent refresh if we are in a mode that supports it. 214 218 // We need an OAuth manager to refresh. 215 - // If ClientID is set, g.oauth is set. 219 + oauthMgr, _ := g.getOAuthManager(r) 216 220 217 - if g.oauth != nil && sess != nil { 218 - clientSession, err := g.oauth.ResumeSession(r.Context(), sess.DID, sess.SessionID) 219 - if err == nil { 221 + if oauthMgr != nil && sess != nil { 222 + clientSession, errRefresh := oauthMgr.ResumeSession(r.Context(), sess.DID, sess.SessionID) 223 + if errRefresh == nil { 220 224 // Refresh tokens 221 - if _, err := clientSession.RefreshTokens(r.Context()); err == nil { 225 + if _, errRefresh := clientSession.RefreshTokens(r.Context()); errRefresh == nil { 222 226 // Success! Update cookie. 223 - // We need to extend expiration. 224 - // Handle lookup might be needed if not in session? 225 - // Sess has Handle. 226 - cookie, err := g.sessions.CreateCookie( 227 + 228 + // Resolve fresh handle, fallback to old if unavailable 229 + handle := sess.Handle 230 + ident, errDir := oauthMgr.App.Dir.LookupDID(r.Context(), clientSession.Data.AccountDID) 231 + if errDir == nil && ident != nil { 232 + handle = ident.Handle.String() 233 + } 234 + 235 + cookie, errCookie := g.sessions.CreateCookie( 227 236 clientSession.Data.AccountDID, 228 - sess.Handle, // Keep handle from old cookie 237 + handle, 229 238 clientSession.Data.SessionID, 230 - 24*7*time.Hour, 231 - strings.Split(r.Host, ":")[0], 239 + g.app.SessionDuration, 232 240 ) 233 - if err == nil { 241 + if errCookie == nil { 234 242 http.SetCookie(w, cookie) 235 243 r.AddCookie(cookie) 236 - // Proceed as authorized 237 - r.Header.Set("X-Atproto-Did", sess.DID) 238 - r.Header.Set("X-Atproto-Handle", sess.Handle) 239 - return next.ServeHTTP(w, r) 244 + 245 + // Update local session for authorization checks below 246 + sess.DID = clientSession.Data.AccountDID.String() 247 + sess.Handle = handle 248 + err = nil // clear expiration error to proceed 240 249 } 241 250 } 242 251 } 243 - // If refresh failed, fall through to re-login logic 252 + // If refresh failed, err remains ErrExpired, falling through to redirect 244 253 } 245 - } else if err == nil { 254 + } 255 + 256 + if err == nil { 246 257 // Session valid! 247 258 // Check authorization against allowlist 248 259 allowed := false ··· 253 264 } 254 265 } 255 266 267 + if !allowed && g.ResolveHandlesOnRequest { 268 + // Try dynamic resolution for handles that weren't in the pre-resolved list 269 + // or might have changed. 270 + for _, allow := range g.Allow { 271 + if allow != "*" && !strings.HasPrefix(allow, "did:") { 272 + if cachedDID, ok := g.handleCache.Load(allow); ok && cachedDID == sess.DID { 273 + allowed = true 274 + break 275 + } 276 + // Resolve and cache 277 + did, resErr := g.resolver.ResolveIdentifier(r.Context(), allow) 278 + if resErr == nil { 279 + g.handleCache.Store(allow, did) 280 + if did == sess.DID { 281 + allowed = true 282 + break 283 + } 284 + } 285 + } 286 + } 287 + } 288 + 256 289 if allowed { 257 290 // Inject headers 258 291 r.Header.Set("X-Atproto-Did", sess.DID) ··· 261 294 } 262 295 263 296 // Authenticated but not authorized 264 - w.Header().Set("Content-Type", "text/html; charset=utf-8") 265 - w.WriteHeader(http.StatusForbidden) 266 - if err := g.renderer.RenderForbidden(w, ui.ForbiddenData{ 267 - AppName: "Gate", // We don't have Domain/AppName anymore, maybe use Host? 268 - DID: sess.DID, 269 - Handle: sess.Handle, 270 - }); err != nil { 271 - g.logger.Error("failed to render forbidden page", zap.Error(err)) 297 + if g.PortalURL != "" { 298 + scheme := getRequestScheme(r) 299 + host := r.Host 300 + currentURL := fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()) 301 + 302 + portalURL := g.PortalURL 303 + portalForbidden := fmt.Sprintf("%s/forbidden?redirect_to=%s", portalURL, url.QueryEscape(currentURL)) 304 + http.Redirect(w, r, portalForbidden, http.StatusFound) 305 + return nil 272 306 } 307 + 308 + w.Header().Set("Content-Type", "text/plain; charset=utf-8") 309 + w.WriteHeader(http.StatusForbidden) 310 + w.Write([]byte("Forbidden")) 273 311 return nil 274 312 } 275 313 276 314 // 2. If invalid/missing, initiate redirect to Portal 277 315 if g.PortalURL != "" { 278 316 // Construct redirect URL: ${PortalURL}/login?redirect_to=${CurrentURL} 279 - scheme := "https" 280 - if r.TLS == nil && r.Header.Get("X-Forwarded-Proto") != "https" { 281 - scheme = "http" 282 - } 317 + scheme := getRequestScheme(r) 283 318 host := r.Host 284 319 currentURL := fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()) 285 320 286 - // Ensure PortalURL doesn't end with / if we append /login 287 321 portalURL := g.PortalURL 288 - if portalURL == "/" { 289 - portalURL = "" 290 - } else if len(portalURL) > 0 && portalURL[len(portalURL)-1] == '/' { 291 - portalURL = portalURL[:len(portalURL)-1] 292 - } 293 - 294 322 portalLogin := fmt.Sprintf("%s/login?redirect_to=%s", portalURL, url.QueryEscape(currentURL)) 295 323 http.Redirect(w, r, portalLogin, http.StatusFound) 296 324 return nil
+31 -16
global.go
··· 2 2 3 3 import ( 4 4 "fmt" 5 + "strconv" 6 + "time" 5 7 6 8 "github.com/caddyserver/caddy/v2" 7 9 "github.com/caddyserver/caddy/v2/caddyconfig" ··· 9 11 "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" 10 12 11 13 "tangled.org/vvill.dev/caddy-atproto-auth/internal/db" 12 - "tangled.org/vvill.dev/caddy-atproto-auth/internal/session" 13 14 ) 14 15 15 16 func init() { 16 - caddy.RegisterModule(App{}) 17 + caddy.RegisterModule(&App{}) 17 18 httpcaddyfile.RegisterGlobalOption("atproto", parseGlobalAtproto) 18 19 } 19 20 20 21 // App configures the global atproto integration. 21 22 type App struct { 22 - StoragePath string `json:"storage_path,omitempty"` 23 - CookieSecret string `json:"cookie_secret,omitempty"` 24 - CookieName string `json:"cookie_name,omitempty"` 25 - 26 - CookieDomain string `json:"cookie_domain,omitempty"` 23 + StoragePath string `json:"storage_path,omitempty"` 24 + CookieSecret string `json:"cookie_secret,omitempty"` 25 + SessionDurationStr string `json:"session_duration,omitempty"` 26 + OAuthManagerCacheSize int `json:"oauth_manager_cache_size,omitempty"` 27 27 28 28 // Internal state 29 - Store *db.Store `json:"-"` 30 - SessionManager *session.Manager `json:"-"` 29 + Store *db.Store `json:"-"` 30 + SessionDuration time.Duration `json:"-"` 31 31 } 32 32 33 33 // CaddyModule returns the Caddy module information. 34 - func (App) CaddyModule() caddy.ModuleInfo { 34 + func (*App) CaddyModule() caddy.ModuleInfo { 35 35 return caddy.ModuleInfo{ 36 36 ID: "atproto", 37 37 New: func() caddy.Module { return new(App) }, ··· 62 62 a.CookieSecret = secret 63 63 } 64 64 65 - // Initialize Session Manager globally 66 - a.SessionManager = session.NewManager(a.CookieSecret, a.CookieName, a.CookieDomain) 65 + // Parse session duration 66 + a.SessionDuration = 24 * 7 * time.Hour 67 + if a.SessionDurationStr != "" { 68 + d, err := caddy.ParseDuration(a.SessionDurationStr) 69 + if err != nil { 70 + return fmt.Errorf("invalid session_duration: %w", err) 71 + } 72 + a.SessionDuration = time.Duration(d) 73 + } 74 + 75 + if a.OAuthManagerCacheSize <= 0 { 76 + a.OAuthManagerCacheSize = 100 // Default max oauth managers 77 + } 67 78 68 79 return nil 69 80 } ··· 109 120 return nil, d.ArgErr() 110 121 } 111 122 app.CookieSecret = d.Val() 112 - case "cookie_name": 123 + case "session_duration": 113 124 if !d.NextArg() { 114 125 return nil, d.ArgErr() 115 126 } 116 - app.CookieName = d.Val() 117 - case "cookie_domain": 127 + app.SessionDurationStr = d.Val() 128 + case "oauth_manager_cache_size": 118 129 if !d.NextArg() { 119 130 return nil, d.ArgErr() 120 131 } 121 - app.CookieDomain = d.Val() 132 + val, err := strconv.Atoi(d.Val()) 133 + if err != nil { 134 + return nil, d.Errf("invalid oauth_manager_cache_size: %v", err) 135 + } 136 + app.OAuthManagerCacheSize = val 122 137 default: 123 138 return nil, d.Errf("unrecognized subdirective '%s'", d.Val()) 124 139 }
+6 -4
internal/db/db.go
··· 4 4 "context" 5 5 "crypto/rand" 6 6 "database/sql" 7 + "encoding/hex" 7 8 "encoding/json" 8 9 "fmt" 9 10 "sync/atomic" ··· 112 113 err := s.db.QueryRowContext(ctx, "SELECT key_data FROM system_keys WHERE id = 'cookie_secret'").Scan(&secret) 113 114 if err == sql.ErrNoRows { 114 115 // Generate new random 32 byte secret 115 - secret = make([]byte, 32) 116 - if _, err := rand.Read(secret); err != nil { 116 + rawSecret := make([]byte, 32) 117 + if _, err := rand.Read(rawSecret); err != nil { 117 118 return "", fmt.Errorf("failed to generate cookie secret: %w", err) 118 119 } 120 + secretStr := hex.EncodeToString(rawSecret) 119 121 120 - _, err = s.db.ExecContext(ctx, "INSERT INTO system_keys (id, key_data) VALUES ('cookie_secret', ?)", secret) 122 + _, err = s.db.ExecContext(ctx, "INSERT INTO system_keys (id, key_data) VALUES ('cookie_secret', ?)", []byte(secretStr)) 121 123 if err != nil { 122 124 return "", fmt.Errorf("failed to save cookie secret: %w", err) 123 125 } 124 - return string(secret), nil 126 + return secretStr, nil 125 127 } else if err != nil { 126 128 return "", fmt.Errorf("failed to load cookie secret: %w", err) 127 129 }
+5 -11
internal/session/session.go
··· 52 52 } 53 53 54 54 // CreateCookie generates a signed http.Cookie for the session. 55 - func (m *Manager) CreateCookie(did syntax.DID, handle string, sessionID string, duration time.Duration, reqDomain string) (*http.Cookie, error) { 55 + func (m *Manager) CreateCookie(did syntax.DID, handle string, sessionID string, duration time.Duration) (*http.Cookie, error) { 56 56 exp := time.Now().Add(duration).Unix() 57 57 sess := Session{ 58 58 DID: did.String(), ··· 71 71 value := fmt.Sprintf("%s.%s", encoded, signature) 72 72 73 73 cookieDomain := m.CookieDomain 74 - if cookieDomain == "" { 75 - cookieDomain = reqDomain 76 - } 74 + // If cookieDomain is empty, we leave Domain empty for a host-only cookie. 75 + // We no longer fallback to reqDomain. 77 76 78 77 cookie := &http.Cookie{ 79 78 Name: m.CookieName, ··· 130 129 var ErrExpired = errors.New("session expired") 131 130 132 131 // ClearCookie returns a cookie that clears the session. 133 - func (m *Manager) ClearCookie(reqDomain string) *http.Cookie { 134 - cookieDomain := m.CookieDomain 135 - if cookieDomain == "" { 136 - cookieDomain = reqDomain 137 - } 138 - 132 + func (m *Manager) ClearCookie() *http.Cookie { 139 133 return &http.Cookie{ 140 134 Name: m.CookieName, 141 135 Value: "", 142 136 Path: "/", 143 - Domain: cookieDomain, 137 + Domain: m.CookieDomain, 144 138 Expires: time.Unix(0, 0), 145 139 MaxAge: -1, 146 140 Secure: true,
+5 -7
internal/ui/templates/forbidden.html
··· 81 81 color: var(--text-color); 82 82 } 83 83 84 - a.button { 85 - display: inline-block; 84 + button { 86 85 background-color: transparent; 87 86 color: var(--text-color); 88 87 border: 1px solid var(--border-color); 89 88 padding: 0.75rem 1.5rem; 90 - text-decoration: none; 91 89 border-radius: 6px; 92 - font-weight: 500; 93 90 transition: background-color 0.2s; 91 + cursor: pointer; 94 92 } 95 93 96 - a.button:hover { 94 + button:hover { 97 95 background-color: rgba(0, 0, 0, 0.05); 98 96 } 99 97 100 98 @media (prefers-color-scheme: dark) { 101 - a.button:hover { 99 + button:hover { 102 100 background-color: rgba(255, 255, 255, 0.1); 103 101 } 104 102 } ··· 118 116 You are logged in as <strong>{{ .Handle }}</strong> ({{ .DID 119 117 }}), but you are not authorized to access this resource. 120 118 </p> 121 - <a href="/logout" class="button">Log Out</a> 119 + <a href="{{ .LogoutURL }}"><button>Log Out</button></a> 122 120 </main> 123 121 <footer> 124 122 <p class="subtext">
+6 -2
internal/ui/templates/login.html
··· 152 152 <div class="error-alert">{{ .Error }}</div> 153 153 {{ end }} 154 154 155 - <form action="/login" method="POST"> 155 + <form action="{{ .LoginURL }}" method="POST"> 156 156 {{ if .Redirect }} 157 - <input type="hidden" name="redirect_to" value="{{ .Redirect }}" /> 157 + <input 158 + type="hidden" 159 + name="redirect_to" 160 + value="{{ .Redirect }}" 161 + /> 158 162 {{ end }} 159 163 <div class="input-group"> 160 164 <label for="handle">Handle</label>
+5 -3
internal/ui/ui.go
··· 66 66 AppName string 67 67 Error string 68 68 Redirect string 69 + LoginURL string 69 70 } 70 71 71 72 // ForbiddenData is the context for forbidden.html 72 73 type ForbiddenData struct { 73 - AppName string 74 - DID string 75 - Handle string 74 + AppName string 75 + DID string 76 + Handle string 77 + LogoutURL string 76 78 } 77 79 78 80 func (r *Renderer) RenderLogin(w io.Writer, data LoginData) error {
+195 -64
portal.go
··· 3 3 import ( 4 4 "encoding/json" 5 5 "fmt" 6 + "net" 6 7 "net/http" 7 8 "net/url" 9 + "slices" 8 10 "strings" 9 - "time" 11 + "sync" 10 12 11 13 "github.com/caddyserver/caddy/v2" 12 14 "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" ··· 19 21 ) 20 22 21 23 func init() { 22 - caddy.RegisterModule(Portal{}) 24 + caddy.RegisterModule(&Portal{}) 23 25 httpcaddyfile.RegisterHandlerDirective("atproto_portal", parseCaddyfilePortal) 24 26 } 25 27 26 28 // Portal is the centralized authentication portal for Path B (Auth Hub). 27 29 type Portal struct { 28 - Name string `json:"name,omitempty"` 29 - Domain string `json:"domain,omitempty"` // Public domain of the portal (e.g. auth.example.com) 30 - UI ui.Config `json:"ui,omitempty"` // Custom UI configuration 30 + Name string `json:"name,omitempty"` 31 + LoginTemplatePath string `json:"login_template,omitempty"` 32 + ForbiddenTemplatePath string `json:"forbidden_template,omitempty"` 33 + CookieName string `json:"cookie_name,omitempty"` 34 + CookieDomain string `json:"cookie_domain,omitempty"` 35 + AllowedRedirectDomains []string `json:"allowed_redirect_domains,omitempty"` 36 + LogoutRedirectURL string `json:"logout_redirect_url,omitempty"` 31 37 32 38 // Paths configuration 33 39 PathPrefix string `json:"path_prefix,omitempty"` 34 40 35 41 // Dependencies 36 - app *App 37 - oauth *oauth.Manager 38 - sessions *session.Manager 39 - renderer *ui.Renderer 40 - logger *zap.Logger 42 + app *App 43 + sessions *session.Manager 44 + oauthManagers map[string]*oauth.Manager 45 + oauthMu sync.RWMutex 46 + renderer *ui.Renderer 47 + logger *zap.Logger 41 48 } 42 49 43 50 // CaddyModule returns the Caddy module information. 44 - func (Portal) CaddyModule() caddy.ModuleInfo { 51 + func (*Portal) CaddyModule() caddy.ModuleInfo { 45 52 return caddy.ModuleInfo{ 46 53 ID: "http.handlers.atproto_portal", 47 54 New: func() caddy.Module { return new(Portal) }, ··· 59 66 } 60 67 p.app = app.(*App) 61 68 62 - // 2. Initialize Session Manager from global app 63 - p.sessions = p.app.SessionManager 69 + // 2. Initialize Session Manager from global app secret 70 + p.sessions = session.NewManager(p.app.CookieSecret, p.CookieName, p.CookieDomain) 64 71 65 72 // 4. Initialize UI Renderer 66 - renderer, err := ui.NewRenderer(p.UI) 73 + renderer, err := ui.NewRenderer(ui.Config{ 74 + LoginTemplatePath: p.LoginTemplatePath, 75 + ForbiddenTemplatePath: p.ForbiddenTemplatePath, 76 + }) 67 77 if err != nil { 68 78 return fmt.Errorf("failed to init ui renderer: %w", err) 69 79 } 70 80 p.renderer = renderer 71 81 72 - // 5. Initialize OAuth Manager 73 - // We need the domain to construct ClientID and CallbackURL. 74 - // If domain is missing, we might defer initialization? No, Manager needs it. 75 - // User must configure 'domain' in Caddyfile for now. 76 - if p.Domain == "" { 77 - return fmt.Errorf("atproto_portal requires 'domain' to be set (e.g. auth.example.com)") 82 + p.oauthManagers = make(map[string]*oauth.Manager) 83 + 84 + // Normalize PathPrefix (ensure it starts with / and doesn't end with /) 85 + if p.PathPrefix != "" { 86 + if !strings.HasPrefix(p.PathPrefix, "/") { 87 + p.PathPrefix = "/" + p.PathPrefix 88 + } 89 + p.PathPrefix = strings.TrimSuffix(p.PathPrefix, "/") 78 90 } 79 91 80 - clientID := fmt.Sprintf("https://%s/.well-known/oauth-client-metadata.json", p.Domain) 81 - callbackURL := fmt.Sprintf("https://%s/callback", p.Domain) 92 + return nil 93 + } 94 + 95 + // getOAuthManager gets or initializes the OAuth manager for a specific host. 96 + func (p *Portal) getOAuthManager(r *http.Request) (*oauth.Manager, error) { 97 + host := getRequestHost(r) 98 + 99 + p.oauthMu.RLock() 100 + mgr, exists := p.oauthManagers[host] 101 + p.oauthMu.RUnlock() 102 + 103 + if exists { 104 + return mgr, nil 105 + } 106 + 107 + p.oauthMu.Lock() 108 + defer p.oauthMu.Unlock() 109 + 110 + // Double-check after acquiring write lock 111 + if mgr, exists := p.oauthManagers[host]; exists { 112 + return mgr, nil 113 + } 114 + 115 + if len(p.oauthManagers) >= p.app.OAuthManagerCacheSize { 116 + // Prevent DoS from unbounded map growth 117 + p.logger.Warn("oauth managers cache full, clearing to prevent OOM") 118 + p.oauthManagers = make(map[string]*oauth.Manager) 119 + } 120 + 121 + scheme := getRequestScheme(r) 122 + 123 + clientID := fmt.Sprintf("%s://%s/.well-known/oauth-client-metadata.json", scheme, host) 124 + callbackURL := fmt.Sprintf("%s://%s%s/callback", scheme, host, p.PathPrefix) 82 125 83 126 mgr, err := oauth.NewManager(p.app.Store, clientID, callbackURL) 84 127 if err != nil { 85 - return fmt.Errorf("failed to init oauth manager: %w", err) 128 + return nil, err 86 129 } 87 - p.oauth = mgr 88 130 89 - // Defaults for paths 90 - // If PathPrefix is set (e.g. /auth), endpoints become /auth/login and /auth/logout 91 - // If PathPrefix is empty, endpoints are /login and /logout 92 - 93 - return nil 131 + p.oauthManagers[host] = mgr 132 + return mgr, nil 94 133 } 95 134 96 135 // Validate checks that the configuration is valid. ··· 111 150 return d.ArgErr() 112 151 } 113 152 p.Name = d.Val() 114 - case "domain": 153 + case "login_template": 115 154 if !d.NextArg() { 116 155 return d.ArgErr() 117 156 } 118 - p.Domain = d.Val() 157 + p.LoginTemplatePath = d.Val() 158 + case "forbidden_template": 159 + if !d.NextArg() { 160 + return d.ArgErr() 161 + } 162 + p.ForbiddenTemplatePath = d.Val() 163 + case "cookie_name": 164 + if !d.NextArg() { 165 + return d.ArgErr() 166 + } 167 + p.CookieName = d.Val() 168 + case "cookie_domain": 169 + if !d.NextArg() { 170 + return d.ArgErr() 171 + } 172 + p.CookieDomain = d.Val() 173 + case "allowed_redirect_domains": 174 + p.AllowedRedirectDomains = append(p.AllowedRedirectDomains, d.RemainingArgs()...) 175 + case "logout_redirect_url": 176 + if !d.NextArg() { 177 + return d.ArgErr() 178 + } 179 + p.LogoutRedirectURL = d.Val() 119 180 case "path_prefix": 120 181 if !d.NextArg() { 121 182 return d.ArgErr() 122 183 } 123 184 p.PathPrefix = d.Val() 124 - case "ui": 125 - for nesting := d.Nesting(); d.NextBlock(nesting); { 126 - switch d.Val() { 127 - case "login_template": 128 - if !d.NextArg() { 129 - return d.ArgErr() 130 - } 131 - p.UI.LoginTemplatePath = d.Val() 132 - case "forbidden_template": 133 - if !d.NextArg() { 134 - return d.ArgErr() 135 - } 136 - p.UI.ForbiddenTemplatePath = d.Val() 137 - default: 138 - return d.Errf("unrecognized subdirective '%s'", d.Val()) 139 - } 140 - } 141 185 default: 142 186 return d.Errf("unrecognized subdirective '%s'", d.Val()) 143 187 } ··· 152 196 return &p, err 153 197 } 154 198 199 + func (p *Portal) setSecurityHeaders(w http.ResponseWriter) { 200 + w.Header().Set("Content-Security-Policy", "default-src 'self'; style-src 'self' 'unsafe-inline'") 201 + w.Header().Set("X-Content-Type-Options", "nosniff") 202 + w.Header().Set("X-Frame-Options", "DENY") 203 + } 204 + 155 205 // ServeHTTP implements caddyhttp.MiddlewareHandler. 156 206 func (p *Portal) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { 207 + oauthMgr, err := p.getOAuthManager(r) 208 + if err != nil { 209 + p.logger.Error("failed to get oauth manager", zap.Error(err)) 210 + http.Error(w, "Internal Server Error", http.StatusInternalServerError) 211 + return nil 212 + } 213 + 157 214 // 1. Metadata Endpoint 158 215 if r.URL.Path == "/.well-known/oauth-client-metadata.json" { 159 - meta, err := p.oauth.GetClientMetadata() 216 + meta, err := oauthMgr.GetClientMetadata() 160 217 if err != nil { 161 218 p.logger.Error("failed to get client metadata", zap.Error(err)) 162 219 http.Error(w, "Internal Server Error", http.StatusInternalServerError) ··· 168 225 } 169 226 170 227 // 2. Callback Endpoint 171 - if r.URL.Path == "/callback" { 228 + callbackPath := p.PathPrefix + "/callback" 229 + if r.URL.Path == callbackPath { 172 230 // Process callback 173 231 ctx := r.Context() 174 232 query := r.URL.Query() 175 233 176 - sessionData, handle, err := p.oauth.ProcessCallback(ctx, query) 234 + sessionData, handle, err := oauthMgr.ProcessCallback(ctx, query) 177 235 if err != nil { 178 236 p.logger.Error("oauth callback failed", zap.Error(err)) 237 + p.setSecurityHeaders(w) 179 238 w.Header().Set("Content-Type", "text/html; charset=utf-8") 180 239 w.WriteHeader(http.StatusBadRequest) 181 - _ = p.renderer.RenderLogin(w, ui.LoginData{AppName: p.Name, Error: fmt.Sprintf("Authentication failed: %v", err)}) 240 + _ = p.renderer.RenderLogin(w, ui.LoginData{ 241 + AppName: p.Name, 242 + Error: "Authentication failed. Please try again.", 243 + LoginURL: p.PathPrefix + "/login", 244 + }) 182 245 return nil 183 246 } 184 247 185 - reqDomain := strings.Split(p.Domain, ":")[0] 248 + reqDomain := getRequestHost(r) 186 249 187 250 // Create Session Cookie 188 251 cookie, err := p.sessions.CreateCookie( 189 252 sessionData.AccountDID, 190 253 handle, 191 254 sessionData.SessionID, 192 - 24*7*time.Hour, 193 - reqDomain, 255 + p.app.SessionDuration, 194 256 ) 195 257 if err != nil { 196 258 p.logger.Error("failed to create session cookie", zap.Error(err)) ··· 229 291 isAllowedDomain := false 230 292 if err == nil { 231 293 h := parsed.Host 232 - if h == p.Domain { 294 + // Strip port if present 295 + hostNoPort, _, hostErr := net.SplitHostPort(h) 296 + if hostErr != nil { 297 + hostNoPort = h 298 + } 299 + 300 + if hostNoPort == reqDomain { 233 301 isAllowedDomain = true 302 + } else { 303 + for _, allowed := range p.AllowedRedirectDomains { 304 + if hostNoPort == allowed { 305 + isAllowedDomain = true 306 + break 307 + } 308 + } 234 309 } 235 310 } 236 311 ··· 249 324 // 3. Login Start (Form Action) 250 325 loginPath := p.PathPrefix + "/login" 251 326 logoutPath := p.PathPrefix + "/logout" 327 + forbiddenPath := p.PathPrefix + "/forbidden" 252 328 253 329 if r.URL.Path == loginPath && r.Method == "POST" { 254 330 handle := r.FormValue("handle") ··· 263 339 } 264 340 265 341 // Start Auth Flow 266 - redirectURI, err := p.oauth.StartAuthFlow(r.Context(), handle) 342 + redirectURI, err := oauthMgr.StartAuthFlow(r.Context(), handle) 267 343 if err != nil { 268 344 // Render error on login page 345 + p.setSecurityHeaders(w) 269 346 w.Header().Set("Content-Type", "text/html; charset=utf-8") 270 347 w.WriteHeader(http.StatusBadRequest) 271 348 if renderErr := p.renderer.RenderLogin(w, ui.LoginData{ 272 349 AppName: p.Name, 273 350 Redirect: r.FormValue("redirect_to"), 274 - Error: fmt.Sprintf("Authentication failed: %v", err), 351 + Error: "Authentication failed. Please try again.", 352 + LoginURL: loginPath, 275 353 }); renderErr != nil { 276 354 p.logger.Error("failed to render login error", zap.Error(renderErr)) 277 355 } ··· 298 376 299 377 // 4. Default: Login Page 300 378 if r.URL.Path == loginPath || (loginPath == "/login" && r.URL.Path == "/") { 379 + if sess, err := p.sessions.VerifyCookie(r); err == nil && sess != nil { 380 + // Already logged in 381 + redirectTo := r.URL.Query().Get("redirect_to") 382 + if redirectTo == "" { 383 + redirectTo = "/" 384 + } 385 + http.Redirect(w, r, redirectTo, http.StatusFound) 386 + return nil 387 + } 388 + 389 + p.setSecurityHeaders(w) 301 390 w.Header().Set("Content-Type", "text/html; charset=utf-8") 302 391 if err := p.renderer.RenderLogin(w, ui.LoginData{ 303 392 AppName: p.Name, 304 393 Redirect: r.URL.Query().Get("redirect_to"), 394 + LoginURL: loginPath, 305 395 }); err != nil { 306 396 p.logger.Error("failed to render login page", zap.Error(err)) 307 397 return caddyhttp.Error(http.StatusInternalServerError, err) ··· 314 404 // Invalidate credential if session exists 315 405 sess, err := p.sessions.VerifyCookie(r) 316 406 317 - reqDomain := strings.Split(p.Domain, ":")[0] 407 + reqDomain := getRequestHost(r) 318 408 319 409 if err == nil || err == session.ErrExpired { 320 410 appNameForLog := p.Name ··· 323 413 } 324 414 p.logger.Info(fmt.Sprintf("@%s (did: %s) has logged out for %s", sess.Handle, sess.DID, appNameForLog)) 325 415 326 - if err := p.oauth.Logout(r.Context(), sess.DID, sess.SessionID); err != nil { 416 + if err := oauthMgr.Logout(r.Context(), sess.DID, sess.SessionID); err != nil { 327 417 p.logger.Error("failed to revoke session during logout", zap.Error(err)) 328 418 } 329 419 } 330 420 331 - http.SetCookie(w, p.sessions.ClearCookie(reqDomain)) 421 + http.SetCookie(w, p.sessions.ClearCookie()) 332 422 333 423 // Handle redirect_to for logout 334 424 redirectTo := r.URL.Query().Get("redirect_to") 335 425 if redirectTo == "" { 336 - redirectTo = loginPath 426 + if p.LogoutRedirectURL != "" { 427 + redirectTo = p.LogoutRedirectURL 428 + } else { 429 + redirectTo = loginPath 430 + } 337 431 } else { 338 432 // Basic open redirect mitigation: ensure it's a relative path or matches CookieDomain/Domain 339 433 if strings.HasPrefix(redirectTo, "http://") || strings.HasPrefix(redirectTo, "https://") { ··· 341 435 isAllowedDomain := false 342 436 if err == nil { 343 437 h := parsed.Host 344 - if h == p.Domain { 438 + hostNoPort, _, hostErr := net.SplitHostPort(h) 439 + if hostErr != nil { 440 + hostNoPort = h 441 + } 442 + 443 + if hostNoPort == reqDomain { 345 444 isAllowedDomain = true 445 + } else { 446 + if slices.Contains(p.AllowedRedirectDomains, hostNoPort) { 447 + isAllowedDomain = true 448 + } 346 449 } 347 450 } 348 451 349 452 if !isAllowedDomain { 350 453 p.logger.Warn("blocked cross-domain redirect on logout", zap.String("url", redirectTo)) 351 - redirectTo = loginPath // Fallback to login page 454 + if p.LogoutRedirectURL != "" { 455 + redirectTo = p.LogoutRedirectURL 456 + } else { 457 + redirectTo = loginPath // Fallback to login page 458 + } 352 459 } 353 460 } 354 461 } 355 462 356 463 http.Redirect(w, r, redirectTo, http.StatusFound) 464 + return nil 465 + } 466 + 467 + // 6. Forbidden Page 468 + if r.URL.Path == forbiddenPath { 469 + p.setSecurityHeaders(w) 470 + w.Header().Set("Content-Type", "text/html; charset=utf-8") 471 + w.WriteHeader(http.StatusForbidden) 472 + 473 + var did, handle string 474 + sess, err := p.sessions.VerifyCookie(r) 475 + if err == nil || err == session.ErrExpired { 476 + did = sess.DID 477 + handle = sess.Handle 478 + } 479 + 480 + if err := p.renderer.RenderForbidden(w, ui.ForbiddenData{ 481 + AppName: p.Name, 482 + DID: did, 483 + Handle: handle, 484 + LogoutURL: logoutPath, 485 + }); err != nil { 486 + p.logger.Error("failed to render forbidden page", zap.Error(err)) 487 + } 357 488 return nil 358 489 } 359 490
+24
util.go
··· 1 + package caddyatprotoauth 2 + 3 + import ( 4 + "net" 5 + "net/http" 6 + ) 7 + 8 + // getRequestScheme infers the protocol scheme of the incoming request. 9 + func getRequestScheme(r *http.Request) string { 10 + if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { 11 + return "https" 12 + } 13 + return "http" 14 + } 15 + 16 + // getRequestHost safely extracts the hostname, stripping the port and handling IPv6 literals. 17 + func getRequestHost(r *http.Request) string { 18 + host, _, err := net.SplitHostPort(r.Host) 19 + if err != nil { 20 + // Fallback if there is no port or if it's malformed 21 + return r.Host 22 + } 23 + return host 24 + }