Caddy module that requires AT Protocol (atproto) authentication and restricts routes to allowed DIDs
3

Configure Feed

Select the types of activity you want to include in your feed.

1package caddyatprotoauth 2 3import ( 4 "context" 5 "fmt" 6 "net/http" 7 "net/url" 8 "strings" 9 "sync" 10 11 "strconv" 12 13 "github.com/caddyserver/caddy/v2" 14 "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" 15 "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" 16 "github.com/caddyserver/caddy/v2/modules/caddyhttp" 17 "go.uber.org/zap" 18 "tangled.org/vvill.dev/caddy-atproto-auth/internal/oauth" 19 "tangled.org/vvill.dev/caddy-atproto-auth/internal/resolver" 20 "tangled.org/vvill.dev/caddy-atproto-auth/internal/session" 21) 22 23func init() { 24 caddy.RegisterModule(&Gate{}) 25 httpcaddyfile.RegisterHandlerDirective("atproto_gate", parseCaddyfileGate) 26} 27 28// Gate acts as a middleware that guards endpoints 29// and validates the session cookie. 30type Gate struct { 31 Allow []string `json:"allow,omitempty"` 32 PortalURL string `json:"portal_url,omitempty"` // URL of the auth portal (e.g. https://auth.example.com or /auth) 33 CookieName string `json:"cookie_name,omitempty"` 34 CookieDomain string `json:"cookie_domain,omitempty"` 35 ResolveHandlesOnRequest bool `json:"resolve_handles_on_request,omitempty"` 36 37 // Dependencies 38 app *App 39 sessions *session.Manager 40 oauthManagers map[string]*oauth.Manager 41 oauthMu sync.RWMutex 42 logger *zap.Logger 43 resolvedDIDs []string 44 handleCache sync.Map 45 resolver *resolver.Resolver 46} 47 48// CaddyModule returns the Caddy module information. 49func (*Gate) CaddyModule() caddy.ModuleInfo { 50 return caddy.ModuleInfo{ 51 ID: "http.handlers.atproto_gate", 52 New: func() caddy.Module { return new(Gate) }, 53 } 54} 55 56// Provision sets up the module. 
57func (g *Gate) Provision(ctx caddy.Context) error { 58 g.logger = ctx.Logger() 59 60 // Get Global App 61 app, err := ctx.App("atproto") 62 if err != nil { 63 return fmt.Errorf("getting atproto app: %w", err) 64 } 65 g.app = app.(*App) 66 67 // Initialize Session Manager (using global secret) 68 g.sessions = session.NewManager(g.app.CookieSecret, g.CookieName, g.CookieDomain) 69 70 g.oauthManagers = make(map[string]*oauth.Manager) 71 72 // Normalize PortalURL (ensure it doesn't end with /) 73 if g.PortalURL == "" { 74 g.PortalURL = "/" 75 } else if len(g.PortalURL) > 0 && g.PortalURL[len(g.PortalURL)-1] == '/' && g.PortalURL != "/" { 76 g.PortalURL = g.PortalURL[:len(g.PortalURL)-1] 77 } 78 79 // Initialize OAuth Manager for transparent refresh 80 // We derive the ClientID from the PortalURL if it's absolute, 81 // or from the Host header at request time if it's relative. 82 // For now, if it's absolute, we can init the OAuth manager immediately. 83 if strings.HasPrefix(g.PortalURL, "http://") || strings.HasPrefix(g.PortalURL, "https://") { 84 parsedURL, err := url.Parse(g.PortalURL) 85 if err == nil { 86 clientID := fmt.Sprintf("%s://%s/.well-known/oauth-client-metadata.json", parsedURL.Scheme, parsedURL.Host) 87 mgr, err := oauth.NewManager(g.app.Store, clientID, "") 88 if err != nil { 89 return fmt.Errorf("failed to init oauth manager for refresh: %w", err) 90 } 91 g.oauthManagers[parsedURL.Host] = mgr 92 } 93 } 94 95 // Pre-resolve allowed handles to DIDs 96 g.resolver = resolver.New() 97 98 g.resolvedDIDs = make([]string, 0, len(g.Allow)) 99 ctxResolver := context.Background() // Use background context for boot-time resolution 100 for _, allow := range g.Allow { 101 if allow == "*" { 102 g.resolvedDIDs = append(g.resolvedDIDs, "*") 103 continue 104 } 105 106 // If it's already a DID, append it directly 107 if strings.HasPrefix(allow, "did:") { 108 g.resolvedDIDs = append(g.resolvedDIDs, allow) 109 continue 110 } 111 112 // Treat as handle and resolve 113 
did, err := g.resolver.ResolveIdentifier(ctxResolver, allow) 114 if err != nil { 115 g.logger.Warn("failed to resolve handle during provision", zap.String("handle", allow), zap.Error(err)) 116 } else { 117 g.resolvedDIDs = append(g.resolvedDIDs, did) 118 g.handleCache.Store(allow, did) 119 } 120 } 121 122 return nil 123} 124 125// Validate checks that the configuration is valid. 126func (g *Gate) Validate() error { 127 return nil 128} 129 130// UnmarshalCaddyfile implements caddyfile.Unmarshaler. 131func (g *Gate) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { 132 for d.Next() { 133 for nesting := d.Nesting(); d.NextBlock(nesting); { 134 switch d.Val() { 135 case "allow": 136 g.Allow = append(g.Allow, d.RemainingArgs()...) 137 case "cookie_name": 138 if !d.NextArg() { 139 return d.ArgErr() 140 } 141 g.CookieName = d.Val() 142 case "cookie_domain": 143 if !d.NextArg() { 144 return d.ArgErr() 145 } 146 g.CookieDomain = d.Val() 147 case "portal_url": 148 if !d.NextArg() { 149 return d.ArgErr() 150 } 151 g.PortalURL = d.Val() 152 case "resolve_handles_on_request": 153 if d.NextArg() { 154 val, err := strconv.ParseBool(d.Val()) 155 if err != nil { 156 return d.Errf("invalid boolean value '%s'", d.Val()) 157 } 158 g.ResolveHandlesOnRequest = val 159 } else { 160 g.ResolveHandlesOnRequest = true 161 } 162 default: 163 return d.Errf("unrecognized subdirective '%s'", d.Val()) 164 } 165 } 166 } 167 return nil 168} 169 170// parseCaddyfileGate parses the atproto_gate directive from a Caddyfile. 171func parseCaddyfileGate(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) { 172 var g Gate 173 err := g.UnmarshalCaddyfile(h.Dispenser) 174 return &g, err 175} 176 177// getOAuthManager gets or initializes the OAuth manager for a specific host. 
178func (g *Gate) getOAuthManager(r *http.Request) (*oauth.Manager, error) { 179 host := getRequestHost(r) 180 181 // If PortalURL is absolute, we already cached it under parsedURL.Host 182 if strings.HasPrefix(g.PortalURL, "http://") || strings.HasPrefix(g.PortalURL, "https://") { 183 parsedURL, err := url.Parse(g.PortalURL) 184 if err == nil { 185 host = parsedURL.Host 186 } 187 } 188 189 g.oauthMu.RLock() 190 mgr, exists := g.oauthManagers[host] 191 g.oauthMu.RUnlock() 192 193 if exists { 194 return mgr, nil 195 } 196 197 g.oauthMu.Lock() 198 defer g.oauthMu.Unlock() 199 200 if mgr, exists := g.oauthManagers[host]; exists { 201 return mgr, nil 202 } 203 204 if len(g.oauthManagers) >= g.app.OAuthManagerCacheSize { 205 // Prevent DoS from unbounded map growth 206 g.logger.Warn("oauth managers cache full, clearing to prevent OOM") 207 g.oauthManagers = make(map[string]*oauth.Manager) 208 } 209 210 scheme := getRequestScheme(r) 211 212 clientID := fmt.Sprintf("%s://%s/.well-known/oauth-client-metadata.json", scheme, host) 213 mgr, err := oauth.NewManager(g.app.Store, clientID, "") 214 if err != nil { 215 return nil, err 216 } 217 218 g.oauthManagers[host] = mgr 219 return mgr, nil 220} 221 222// ServeHTTP implements caddyhttp.MiddlewareHandler. 223func (g *Gate) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { 224 // 1. Verify stateless cookie here 225 sess, err := g.sessions.VerifyCookie(r) 226 if err == session.ErrExpired { 227 // Attempt transparent refresh if we are in a mode that supports it. 228 // We need an OAuth manager to refresh. 229 oauthMgr, _ := g.getOAuthManager(r) 230 231 if oauthMgr != nil && sess != nil { 232 clientSession, errRefresh := oauthMgr.ResumeSession(r.Context(), sess.DID, sess.SessionID) 233 if errRefresh == nil { 234 // Refresh tokens 235 if _, errRefresh := clientSession.RefreshTokens(r.Context()); errRefresh == nil { 236 // Success! Update cookie. 
237 238 // Resolve fresh handle, fallback to old if unavailable 239 handle := sess.Handle 240 ident, errDir := oauthMgr.App.Dir.LookupDID(r.Context(), clientSession.Data.AccountDID) 241 if errDir == nil && ident != nil { 242 handle = ident.Handle.String() 243 } 244 245 cookie, errCookie := g.sessions.CreateCookie( 246 clientSession.Data.AccountDID, 247 handle, 248 clientSession.Data.SessionID, 249 g.app.SessionDuration, 250 ) 251 if errCookie == nil { 252 http.SetCookie(w, cookie) 253 r.AddCookie(cookie) 254 255 // Update local session for authorization checks below 256 sess.DID = clientSession.Data.AccountDID.String() 257 sess.Handle = handle 258 err = nil // clear expiration error to proceed 259 } 260 } 261 } 262 // If refresh failed, err remains ErrExpired, falling through to redirect 263 } 264 } 265 266 if err == nil { 267 // Session valid! 268 // Check authorization against allowlist 269 allowed := false 270 for _, allow := range g.resolvedDIDs { 271 if allow == "*" || allow == sess.DID { 272 allowed = true 273 break 274 } 275 } 276 277 if !allowed && g.ResolveHandlesOnRequest { 278 // Try dynamic resolution for handles that weren't in the pre-resolved list 279 // or might have changed. 
280 for _, allow := range g.Allow { 281 if allow != "*" && !strings.HasPrefix(allow, "did:") { 282 if cachedDID, ok := g.handleCache.Load(allow); ok && cachedDID == sess.DID { 283 allowed = true 284 break 285 } 286 // Resolve and cache 287 did, resErr := g.resolver.ResolveIdentifier(r.Context(), allow) 288 if resErr == nil { 289 g.handleCache.Store(allow, did) 290 if did == sess.DID { 291 allowed = true 292 break 293 } 294 } 295 } 296 } 297 } 298 299 if allowed { 300 // Inject headers 301 r.Header.Set("X-Atproto-Did", sess.DID) 302 r.Header.Set("X-Atproto-Handle", sess.Handle) 303 return next.ServeHTTP(w, r) 304 } 305 306 // Authenticated but not authorized 307 if g.PortalURL != "" { 308 scheme := getRequestScheme(r) 309 host := r.Host 310 currentURL := fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()) 311 312 portalURL := g.PortalURL 313 if portalURL == "/" { 314 portalURL = "" 315 } 316 portalForbidden := fmt.Sprintf("%s/forbidden?redirect_to=%s", portalURL, url.QueryEscape(currentURL)) 317 http.Redirect(w, r, portalForbidden, http.StatusFound) 318 return nil 319 } 320 321 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 322 w.WriteHeader(http.StatusForbidden) 323 w.Write([]byte("Forbidden")) 324 return nil 325 } 326 327 // 2. 
If invalid/missing, initiate redirect to Portal 328 if g.PortalURL != "" { 329 // Construct redirect URL: ${PortalURL}/login?redirect_to=${CurrentURL} 330 scheme := getRequestScheme(r) 331 host := r.Host 332 currentURL := fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI()) 333 334 portalURL := g.PortalURL 335 if portalURL == "/" { 336 portalURL = "" 337 } 338 portalLogin := fmt.Sprintf("%s/login?redirect_to=%s", portalURL, url.QueryEscape(currentURL)) 339 http.Redirect(w, r, portalLogin, http.StatusFound) 340 return nil 341 } 342 343 // Fallback: 401 344 return caddyhttp.Error(http.StatusUnauthorized, fmt.Errorf("unauthorized")) 345} 346 347// Interface guards 348var ( 349 _ caddy.Provisioner = (*Gate)(nil) 350 _ caddy.Validator = (*Gate)(nil) 351 _ caddyhttp.MiddlewareHandler = (*Gate)(nil) 352 _ caddyfile.Unmarshaler = (*Gate)(nil) 353)