Caddy module that requires AT Protocol (atproto) authentication and restricts routes to allowed DIDs
1package caddyatprotoauth
2
3import (
4 "context"
5 "fmt"
6 "net/http"
7 "net/url"
8 "strings"
9 "sync"
10
11 "strconv"
12
13 "github.com/caddyserver/caddy/v2"
14 "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
15 "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
16 "github.com/caddyserver/caddy/v2/modules/caddyhttp"
17 "go.uber.org/zap"
18 "tangled.org/vvill.dev/caddy-atproto-auth/internal/oauth"
19 "tangled.org/vvill.dev/caddy-atproto-auth/internal/resolver"
20 "tangled.org/vvill.dev/caddy-atproto-auth/internal/session"
21)
22
23func init() {
24 caddy.RegisterModule(&Gate{})
25 httpcaddyfile.RegisterHandlerDirective("atproto_gate", parseCaddyfileGate)
26}
27
28// Gate acts as a middleware that guards endpoints
29// and validates the session cookie.
30type Gate struct {
31 Allow []string `json:"allow,omitempty"`
32 PortalURL string `json:"portal_url,omitempty"` // URL of the auth portal (e.g. https://auth.example.com or /auth)
33 CookieName string `json:"cookie_name,omitempty"`
34 CookieDomain string `json:"cookie_domain,omitempty"`
35 ResolveHandlesOnRequest bool `json:"resolve_handles_on_request,omitempty"`
36
37 // Dependencies
38 app *App
39 sessions *session.Manager
40 oauthManagers map[string]*oauth.Manager
41 oauthMu sync.RWMutex
42 logger *zap.Logger
43 resolvedDIDs []string
44 handleCache sync.Map
45 resolver *resolver.Resolver
46}
47
48// CaddyModule returns the Caddy module information.
49func (*Gate) CaddyModule() caddy.ModuleInfo {
50 return caddy.ModuleInfo{
51 ID: "http.handlers.atproto_gate",
52 New: func() caddy.Module { return new(Gate) },
53 }
54}
55
56// Provision sets up the module.
57func (g *Gate) Provision(ctx caddy.Context) error {
58 g.logger = ctx.Logger()
59
60 // Get Global App
61 app, err := ctx.App("atproto")
62 if err != nil {
63 return fmt.Errorf("getting atproto app: %w", err)
64 }
65 g.app = app.(*App)
66
67 // Initialize Session Manager (using global secret)
68 g.sessions = session.NewManager(g.app.CookieSecret, g.CookieName, g.CookieDomain)
69
70 g.oauthManagers = make(map[string]*oauth.Manager)
71
72 // Normalize PortalURL (ensure it doesn't end with /)
73 if g.PortalURL == "" {
74 g.PortalURL = "/"
75 } else if len(g.PortalURL) > 0 && g.PortalURL[len(g.PortalURL)-1] == '/' && g.PortalURL != "/" {
76 g.PortalURL = g.PortalURL[:len(g.PortalURL)-1]
77 }
78
79 // Initialize OAuth Manager for transparent refresh
80 // We derive the ClientID from the PortalURL if it's absolute,
81 // or from the Host header at request time if it's relative.
82 // For now, if it's absolute, we can init the OAuth manager immediately.
83 if strings.HasPrefix(g.PortalURL, "http://") || strings.HasPrefix(g.PortalURL, "https://") {
84 parsedURL, err := url.Parse(g.PortalURL)
85 if err == nil {
86 clientID := fmt.Sprintf("%s://%s/.well-known/oauth-client-metadata.json", parsedURL.Scheme, parsedURL.Host)
87 mgr, err := oauth.NewManager(g.app.Store, clientID, "")
88 if err != nil {
89 return fmt.Errorf("failed to init oauth manager for refresh: %w", err)
90 }
91 g.oauthManagers[parsedURL.Host] = mgr
92 }
93 }
94
95 // Pre-resolve allowed handles to DIDs
96 g.resolver = resolver.New()
97
98 g.resolvedDIDs = make([]string, 0, len(g.Allow))
99 ctxResolver := context.Background() // Use background context for boot-time resolution
100 for _, allow := range g.Allow {
101 if allow == "*" {
102 g.resolvedDIDs = append(g.resolvedDIDs, "*")
103 continue
104 }
105
106 // If it's already a DID, append it directly
107 if strings.HasPrefix(allow, "did:") {
108 g.resolvedDIDs = append(g.resolvedDIDs, allow)
109 continue
110 }
111
112 // Treat as handle and resolve
113 did, err := g.resolver.ResolveIdentifier(ctxResolver, allow)
114 if err != nil {
115 g.logger.Warn("failed to resolve handle during provision", zap.String("handle", allow), zap.Error(err))
116 } else {
117 g.resolvedDIDs = append(g.resolvedDIDs, did)
118 g.handleCache.Store(allow, did)
119 }
120 }
121
122 return nil
123}
124
125// Validate checks that the configuration is valid.
126func (g *Gate) Validate() error {
127 return nil
128}
129
130// UnmarshalCaddyfile implements caddyfile.Unmarshaler.
131func (g *Gate) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
132 for d.Next() {
133 for nesting := d.Nesting(); d.NextBlock(nesting); {
134 switch d.Val() {
135 case "allow":
136 g.Allow = append(g.Allow, d.RemainingArgs()...)
137 case "cookie_name":
138 if !d.NextArg() {
139 return d.ArgErr()
140 }
141 g.CookieName = d.Val()
142 case "cookie_domain":
143 if !d.NextArg() {
144 return d.ArgErr()
145 }
146 g.CookieDomain = d.Val()
147 case "portal_url":
148 if !d.NextArg() {
149 return d.ArgErr()
150 }
151 g.PortalURL = d.Val()
152 case "resolve_handles_on_request":
153 if d.NextArg() {
154 val, err := strconv.ParseBool(d.Val())
155 if err != nil {
156 return d.Errf("invalid boolean value '%s'", d.Val())
157 }
158 g.ResolveHandlesOnRequest = val
159 } else {
160 g.ResolveHandlesOnRequest = true
161 }
162 default:
163 return d.Errf("unrecognized subdirective '%s'", d.Val())
164 }
165 }
166 }
167 return nil
168}
169
170// parseCaddyfileGate parses the atproto_gate directive from a Caddyfile.
171func parseCaddyfileGate(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
172 var g Gate
173 err := g.UnmarshalCaddyfile(h.Dispenser)
174 return &g, err
175}
176
177// getOAuthManager gets or initializes the OAuth manager for a specific host.
178func (g *Gate) getOAuthManager(r *http.Request) (*oauth.Manager, error) {
179 host := getRequestHost(r)
180
181 // If PortalURL is absolute, we already cached it under parsedURL.Host
182 if strings.HasPrefix(g.PortalURL, "http://") || strings.HasPrefix(g.PortalURL, "https://") {
183 parsedURL, err := url.Parse(g.PortalURL)
184 if err == nil {
185 host = parsedURL.Host
186 }
187 }
188
189 g.oauthMu.RLock()
190 mgr, exists := g.oauthManagers[host]
191 g.oauthMu.RUnlock()
192
193 if exists {
194 return mgr, nil
195 }
196
197 g.oauthMu.Lock()
198 defer g.oauthMu.Unlock()
199
200 if mgr, exists := g.oauthManagers[host]; exists {
201 return mgr, nil
202 }
203
204 if len(g.oauthManagers) >= g.app.OAuthManagerCacheSize {
205 // Prevent DoS from unbounded map growth
206 g.logger.Warn("oauth managers cache full, clearing to prevent OOM")
207 g.oauthManagers = make(map[string]*oauth.Manager)
208 }
209
210 scheme := getRequestScheme(r)
211
212 clientID := fmt.Sprintf("%s://%s/.well-known/oauth-client-metadata.json", scheme, host)
213 mgr, err := oauth.NewManager(g.app.Store, clientID, "")
214 if err != nil {
215 return nil, err
216 }
217
218 g.oauthManagers[host] = mgr
219 return mgr, nil
220}
221
222// ServeHTTP implements caddyhttp.MiddlewareHandler.
223func (g *Gate) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
224 // 1. Verify stateless cookie here
225 sess, err := g.sessions.VerifyCookie(r)
226 if err == session.ErrExpired {
227 // Attempt transparent refresh if we are in a mode that supports it.
228 // We need an OAuth manager to refresh.
229 oauthMgr, _ := g.getOAuthManager(r)
230
231 if oauthMgr != nil && sess != nil {
232 clientSession, errRefresh := oauthMgr.ResumeSession(r.Context(), sess.DID, sess.SessionID)
233 if errRefresh == nil {
234 // Refresh tokens
235 if _, errRefresh := clientSession.RefreshTokens(r.Context()); errRefresh == nil {
236 // Success! Update cookie.
237
238 // Resolve fresh handle, fallback to old if unavailable
239 handle := sess.Handle
240 ident, errDir := oauthMgr.App.Dir.LookupDID(r.Context(), clientSession.Data.AccountDID)
241 if errDir == nil && ident != nil {
242 handle = ident.Handle.String()
243 }
244
245 cookie, errCookie := g.sessions.CreateCookie(
246 clientSession.Data.AccountDID,
247 handle,
248 clientSession.Data.SessionID,
249 g.app.SessionDuration,
250 )
251 if errCookie == nil {
252 http.SetCookie(w, cookie)
253 r.AddCookie(cookie)
254
255 // Update local session for authorization checks below
256 sess.DID = clientSession.Data.AccountDID.String()
257 sess.Handle = handle
258 err = nil // clear expiration error to proceed
259 }
260 }
261 }
262 // If refresh failed, err remains ErrExpired, falling through to redirect
263 }
264 }
265
266 if err == nil {
267 // Session valid!
268 // Check authorization against allowlist
269 allowed := false
270 for _, allow := range g.resolvedDIDs {
271 if allow == "*" || allow == sess.DID {
272 allowed = true
273 break
274 }
275 }
276
277 if !allowed && g.ResolveHandlesOnRequest {
278 // Try dynamic resolution for handles that weren't in the pre-resolved list
279 // or might have changed.
280 for _, allow := range g.Allow {
281 if allow != "*" && !strings.HasPrefix(allow, "did:") {
282 if cachedDID, ok := g.handleCache.Load(allow); ok && cachedDID == sess.DID {
283 allowed = true
284 break
285 }
286 // Resolve and cache
287 did, resErr := g.resolver.ResolveIdentifier(r.Context(), allow)
288 if resErr == nil {
289 g.handleCache.Store(allow, did)
290 if did == sess.DID {
291 allowed = true
292 break
293 }
294 }
295 }
296 }
297 }
298
299 if allowed {
300 // Inject headers
301 r.Header.Set("X-Atproto-Did", sess.DID)
302 r.Header.Set("X-Atproto-Handle", sess.Handle)
303 return next.ServeHTTP(w, r)
304 }
305
306 // Authenticated but not authorized
307 if g.PortalURL != "" {
308 scheme := getRequestScheme(r)
309 host := r.Host
310 currentURL := fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI())
311
312 portalURL := g.PortalURL
313 if portalURL == "/" {
314 portalURL = ""
315 }
316 portalForbidden := fmt.Sprintf("%s/forbidden?redirect_to=%s", portalURL, url.QueryEscape(currentURL))
317 http.Redirect(w, r, portalForbidden, http.StatusFound)
318 return nil
319 }
320
321 w.Header().Set("Content-Type", "text/plain; charset=utf-8")
322 w.WriteHeader(http.StatusForbidden)
323 w.Write([]byte("Forbidden"))
324 return nil
325 }
326
327 // 2. If invalid/missing, initiate redirect to Portal
328 if g.PortalURL != "" {
329 // Construct redirect URL: ${PortalURL}/login?redirect_to=${CurrentURL}
330 scheme := getRequestScheme(r)
331 host := r.Host
332 currentURL := fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI())
333
334 portalURL := g.PortalURL
335 if portalURL == "/" {
336 portalURL = ""
337 }
338 portalLogin := fmt.Sprintf("%s/login?redirect_to=%s", portalURL, url.QueryEscape(currentURL))
339 http.Redirect(w, r, portalLogin, http.StatusFound)
340 return nil
341 }
342
343 // Fallback: 401
344 return caddyhttp.Error(http.StatusUnauthorized, fmt.Errorf("unauthorized"))
345}
346
347// Interface guards
348var (
349 _ caddy.Provisioner = (*Gate)(nil)
350 _ caddy.Validator = (*Gate)(nil)
351 _ caddyhttp.MiddlewareHandler = (*Gate)(nil)
352 _ caddyfile.Unmarshaler = (*Gate)(nil)
353)