Caddy module that requires AT Protocol (atproto) authentication and restricts routes to allowed DIDs
2

Configure Feed

Select the types of activity you want to include in your feed.

at main 10 kB View raw
1package caddyatprotoauth 2 3import ( 4 "context" 5 "fmt" 6 "net/http" 7 "net/url" 8 "strings" 9 "sync" 10 11 "strconv" 12 13 "github.com/caddyserver/caddy/v2" 14 "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" 15 "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" 16 "github.com/caddyserver/caddy/v2/modules/caddyhttp" 17 "go.uber.org/zap" 18 "tangled.org/vvill.dev/caddy-atproto-auth/internal/oauth" 19 "tangled.org/vvill.dev/caddy-atproto-auth/internal/resolver" 20 "tangled.org/vvill.dev/caddy-atproto-auth/internal/session" 21) 22 23func init() { 24 caddy.RegisterModule(&Gate{}) 25 httpcaddyfile.RegisterHandlerDirective("atproto_gate", parseCaddyfileGate) 26} 27 28// Gate acts as a middleware that guards endpoints 29// and validates the session cookie. 30type Gate struct { 31 Allow []string `json:"allow,omitempty"` 32 PortalURL string `json:"portal_url,omitempty"` // URL of the auth portal (e.g. https://auth.example.com or /auth) 33 CookieName string `json:"cookie_name,omitempty"` 34 CookieDomain string `json:"cookie_domain,omitempty"` 35 ResolveHandlesOnRequest bool `json:"resolve_handles_on_request,omitempty"` 36 37 // Dependencies 38 app *App 39 sessions *session.Manager 40 oauthManagers map[string]*oauth.Manager 41 oauthMu sync.RWMutex 42 logger *zap.Logger 43 resolvedDIDs []string 44 handleCache sync.Map 45 resolver *resolver.Resolver 46} 47 48// CaddyModule returns the Caddy module information. 49func (*Gate) CaddyModule() caddy.ModuleInfo { 50 return caddy.ModuleInfo{ 51 ID: "http.handlers.atproto_gate", 52 New: func() caddy.Module { return new(Gate) }, 53 } 54} 55 56// Provision sets up the module. 
57func (g *Gate) Provision(ctx caddy.Context) error { 58 g.logger = ctx.Logger() 59 60 // Get Global App 61 app, err := ctx.App("atproto") 62 if err != nil { 63 return fmt.Errorf("getting atproto app: %w", err) 64 } 65 g.app = app.(*App) 66 67 // Initialize Session Manager (using global secret) 68 g.sessions = session.NewManager(g.app.CookieSecret, g.CookieName, g.CookieDomain) 69 70 g.oauthManagers = make(map[string]*oauth.Manager) 71 72 // Normalize PortalURL (ensure it doesn't end with /) 73 if g.PortalURL == "" { 74 g.PortalURL = "/" 75 } else if len(g.PortalURL) > 0 && g.PortalURL[len(g.PortalURL)-1] == '/' && g.PortalURL != "/" { 76 g.PortalURL = g.PortalURL[:len(g.PortalURL)-1] 77 } 78 79 // Initialize OAuth Manager for transparent refresh 80 // We derive the ClientID from the PortalURL if it's absolute, 81 // or from the Host header at request time if it's relative. 82 // For now, if it's absolute, we can init the OAuth manager immediately. 83 if strings.HasPrefix(g.PortalURL, "http://") || strings.HasPrefix(g.PortalURL, "https://") { 84 parsedURL, err := url.Parse(g.PortalURL) 85 if err == nil { 86 clientID := fmt.Sprintf("%s://%s/.well-known/oauth-client-metadata.json", parsedURL.Scheme, parsedURL.Host) 87 mgr, err := oauth.NewManager(g.app.Store, clientID, "", g.app.AllowPrivateCIDRs) 88 if err != nil { 89 return fmt.Errorf("failed to init oauth manager for refresh: %w", err) 90 } 91 g.oauthManagers[parsedURL.Host] = mgr 92 } 93 } 94 95 // Pre-resolve allowed handles to DIDs 96 if len(g.app.AllowPrivateCIDRs) > 0 { 97 g.resolver = resolver.NewWithAllowedCIDRs(g.app.AllowPrivateCIDRs) 98 } else { 99 g.resolver = resolver.New() 100 } 101 102 g.resolvedDIDs = make([]string, 0, len(g.Allow)) 103 ctxResolver := context.Background() // Use background context for boot-time resolution 104 for _, allow := range g.Allow { 105 if allow == "*" { 106 g.resolvedDIDs = append(g.resolvedDIDs, "*") 107 continue 108 } 109 110 // If it's already a DID, append it directly 
111 if strings.HasPrefix(allow, "did:") { 112 g.resolvedDIDs = append(g.resolvedDIDs, allow) 113 continue 114 } 115 116 // Treat as handle and resolve 117 did, err := g.resolver.ResolveIdentifier(ctxResolver, allow) 118 if err != nil { 119 g.logger.Warn("failed to resolve handle during provision", zap.String("handle", allow), zap.Error(err)) 120 } else { 121 g.resolvedDIDs = append(g.resolvedDIDs, did) 122 g.handleCache.Store(allow, did) 123 } 124 } 125 126 return nil 127} 128 129// Validate checks that the configuration is valid. 130func (g *Gate) Validate() error { 131 return nil 132} 133 134// UnmarshalCaddyfile implements caddyfile.Unmarshaler. 135func (g *Gate) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { 136 for d.Next() { 137 for nesting := d.Nesting(); d.NextBlock(nesting); { 138 switch d.Val() { 139 case "allow": 140 g.Allow = append(g.Allow, d.RemainingArgs()...) 141 case "cookie_name": 142 if !d.NextArg() { 143 return d.ArgErr() 144 } 145 g.CookieName = d.Val() 146 case "cookie_domain": 147 if !d.NextArg() { 148 return d.ArgErr() 149 } 150 g.CookieDomain = d.Val() 151 case "portal_url": 152 if !d.NextArg() { 153 return d.ArgErr() 154 } 155 g.PortalURL = d.Val() 156 case "resolve_handles_on_request": 157 if d.NextArg() { 158 val, err := strconv.ParseBool(d.Val()) 159 if err != nil { 160 return d.Errf("invalid boolean value '%s'", d.Val()) 161 } 162 g.ResolveHandlesOnRequest = val 163 } else { 164 g.ResolveHandlesOnRequest = true 165 } 166 default: 167 return d.Errf("unrecognized subdirective '%s'", d.Val()) 168 } 169 } 170 } 171 return nil 172} 173 174// parseCaddyfileGate parses the atproto_gate directive from a Caddyfile. 175func parseCaddyfileGate(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) { 176 var g Gate 177 err := g.UnmarshalCaddyfile(h.Dispenser) 178 return &g, err 179} 180 181// getOAuthManager gets or initializes the OAuth manager for a specific host. 
182func (g *Gate) getOAuthManager(r *http.Request) (*oauth.Manager, error) { 183 host := getRequestHost(r) 184 185 // If PortalURL is absolute, we already cached it under parsedURL.Host 186 if strings.HasPrefix(g.PortalURL, "http://") || strings.HasPrefix(g.PortalURL, "https://") { 187 parsedURL, err := url.Parse(g.PortalURL) 188 if err == nil { 189 host = parsedURL.Host 190 } 191 } 192 193 g.oauthMu.RLock() 194 mgr, exists := g.oauthManagers[host] 195 g.oauthMu.RUnlock() 196 197 if exists { 198 return mgr, nil 199 } 200 201 g.oauthMu.Lock() 202 defer g.oauthMu.Unlock() 203 204 if mgr, exists := g.oauthManagers[host]; exists { 205 return mgr, nil 206 } 207 208 if len(g.oauthManagers) >= g.app.OAuthManagerCacheSize { 209 // Prevent DoS from unbounded map growth 210 g.logger.Warn("oauth managers cache full, clearing to prevent OOM") 211 g.oauthManagers = make(map[string]*oauth.Manager) 212 } 213 214 scheme := getRequestScheme(r) 215 216 clientID := fmt.Sprintf("%s://%s/.well-known/oauth-client-metadata.json", scheme, host) 217 mgr, err := oauth.NewManager(g.app.Store, clientID, "", g.app.AllowPrivateCIDRs) 218 if err != nil { 219 return nil, err 220 } 221 222 g.oauthManagers[host] = mgr 223 return mgr, nil 224} 225 226// ServeHTTP implements caddyhttp.MiddlewareHandler. 227func (g *Gate) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { 228 // 1. Verify stateless cookie here 229 sess, err := g.sessions.VerifyCookie(r) 230 if err == session.ErrExpired { 231 // Attempt transparent refresh if we are in a mode that supports it. 232 // We need an OAuth manager to refresh. 233 oauthMgr, _ := g.getOAuthManager(r) 234 235 if oauthMgr != nil && sess != nil { 236 clientSession, errRefresh := oauthMgr.ResumeSession(r.Context(), sess.DID, sess.SessionID) 237 if errRefresh == nil { 238 // Refresh tokens 239 if _, errRefresh := clientSession.RefreshTokens(r.Context()); errRefresh == nil { 240 // Success! Update cookie. 
241 242 // Resolve fresh handle, fallback to old if unavailable 243 handle := sess.Handle 244 ident, errDir := oauthMgr.App.Dir.LookupDID(r.Context(), clientSession.Data.AccountDID) 245 if errDir == nil && ident != nil { 246 handle = ident.Handle.String() 247 } 248 249 cookie, errCookie := g.sessions.CreateCookie( 250 clientSession.Data.AccountDID, 251 handle, 252 clientSession.Data.SessionID, 253 g.app.SessionDuration, 254 ) 255 if errCookie == nil { 256 http.SetCookie(w, cookie) 257 r.AddCookie(cookie) 258 259 // Update local session for authorization checks below 260 sess.DID = clientSession.Data.AccountDID.String() 261 sess.Handle = handle 262 err = nil // clear expiration error to proceed 263 } 264 } 265 } 266 // If refresh failed, err remains ErrExpired, falling through to redirect 267 } 268 } 269 270 if err == nil { 271 // Session valid! 272 // Check authorization against allowlist 273 allowed := false 274 for _, allow := range g.resolvedDIDs { 275 if allow == "*" || allow == sess.DID { 276 allowed = true 277 break 278 } 279 } 280 281 if !allowed && g.ResolveHandlesOnRequest { 282 // Try dynamic resolution for handles that weren't in the pre-resolved list 283 // or might have changed. 
284 for _, allow := range g.Allow { 285 if allow != "*" && !strings.HasPrefix(allow, "did:") { 286 if cachedDID, ok := g.handleCache.Load(allow); ok && cachedDID == sess.DID { 287 allowed = true 288 break 289 } 290 // Resolve and cache 291 did, resErr := g.resolver.ResolveIdentifier(r.Context(), allow) 292 if resErr == nil { 293 g.handleCache.Store(allow, did) 294 if did == sess.DID { 295 allowed = true 296 break 297 } 298 } 299 } 300 } 301 } 302 303 if allowed { 304 // Inject headers 305 r.Header.Set("X-Atproto-Did", sess.DID) 306 r.Header.Set("X-Atproto-Handle", sess.Handle) 307 return next.ServeHTTP(w, r) 308 } 309 310 // Authenticated but not authorized 311 if g.PortalURL != "" { 312 scheme := getRequestScheme(r) 313 host := r.Host 314 currentURL := fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()) 315 316 portalURL := g.PortalURL 317 if portalURL == "/" { 318 portalURL = "" 319 } 320 portalForbidden := fmt.Sprintf("%s/forbidden?redirect_to=%s", portalURL, url.QueryEscape(currentURL)) 321 http.Redirect(w, r, portalForbidden, http.StatusFound) 322 return nil 323 } 324 325 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 326 w.WriteHeader(http.StatusForbidden) 327 w.Write([]byte("Forbidden")) 328 return nil 329 } 330 331 // 2. 
If invalid/missing, initiate redirect to Portal 332 if g.PortalURL != "" { 333 // Construct redirect URL: ${PortalURL}/login?redirect_to=${CurrentURL} 334 scheme := getRequestScheme(r) 335 host := r.Host 336 currentURL := fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()) 337 338 portalURL := g.PortalURL 339 if portalURL == "/" { 340 portalURL = "" 341 } 342 portalLogin := fmt.Sprintf("%s/login?redirect_to=%s", portalURL, url.QueryEscape(currentURL)) 343 http.Redirect(w, r, portalLogin, http.StatusFound) 344 return nil 345 } 346 347 // Fallback: 401 348 return caddyhttp.Error(http.StatusUnauthorized, fmt.Errorf("unauthorized")) 349} 350 351// Interface guards 352var ( 353 _ caddy.Provisioner = (*Gate)(nil) 354 _ caddy.Validator = (*Gate)(nil) 355 _ caddyhttp.MiddlewareHandler = (*Gate)(nil) 356 _ caddyfile.Unmarshaler = (*Gate)(nil) 357)