Caddy module that requires AT Protocol (atproto) authentication and restricts routes to allowlisted DIDs
1package caddyatprotoauth
2
import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
	"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
	"github.com/caddyserver/caddy/v2/modules/caddyhttp"
	"go.uber.org/zap"

	"tangled.org/vvill.dev/caddy-atproto-auth/internal/oauth"
	"tangled.org/vvill.dev/caddy-atproto-auth/internal/resolver"
	"tangled.org/vvill.dev/caddy-atproto-auth/internal/session"
)
22
23func init() {
24 caddy.RegisterModule(&Gate{})
25 httpcaddyfile.RegisterHandlerDirective("atproto_gate", parseCaddyfileGate)
26}
27
28// Gate acts as a middleware that guards endpoints
29// and validates the session cookie.
30type Gate struct {
31 Allow []string `json:"allow,omitempty"`
32 PortalURL string `json:"portal_url,omitempty"` // URL of the auth portal (e.g. https://auth.example.com or /auth)
33 CookieName string `json:"cookie_name,omitempty"`
34 CookieDomain string `json:"cookie_domain,omitempty"`
35 ResolveHandlesOnRequest bool `json:"resolve_handles_on_request,omitempty"`
36
37 // Dependencies
38 app *App
39 sessions *session.Manager
40 oauthManagers map[string]*oauth.Manager
41 oauthMu sync.RWMutex
42 logger *zap.Logger
43 resolvedDIDs []string
44 handleCache sync.Map
45 resolver *resolver.Resolver
46}
47
48// CaddyModule returns the Caddy module information.
49func (*Gate) CaddyModule() caddy.ModuleInfo {
50 return caddy.ModuleInfo{
51 ID: "http.handlers.atproto_gate",
52 New: func() caddy.Module { return new(Gate) },
53 }
54}
55
56// Provision sets up the module.
57func (g *Gate) Provision(ctx caddy.Context) error {
58 g.logger = ctx.Logger()
59
60 // Get Global App
61 app, err := ctx.App("atproto")
62 if err != nil {
63 return fmt.Errorf("getting atproto app: %w", err)
64 }
65 g.app = app.(*App)
66
67 // Initialize Session Manager (using global secret)
68 g.sessions = session.NewManager(g.app.CookieSecret, g.CookieName, g.CookieDomain)
69
70 g.oauthManagers = make(map[string]*oauth.Manager)
71
72 // Normalize PortalURL (ensure it doesn't end with /)
73 if g.PortalURL == "" {
74 g.PortalURL = "/"
75 } else if len(g.PortalURL) > 0 && g.PortalURL[len(g.PortalURL)-1] == '/' && g.PortalURL != "/" {
76 g.PortalURL = g.PortalURL[:len(g.PortalURL)-1]
77 }
78
79 // Initialize OAuth Manager for transparent refresh
80 // We derive the ClientID from the PortalURL if it's absolute,
81 // or from the Host header at request time if it's relative.
82 // For now, if it's absolute, we can init the OAuth manager immediately.
83 if strings.HasPrefix(g.PortalURL, "http://") || strings.HasPrefix(g.PortalURL, "https://") {
84 parsedURL, err := url.Parse(g.PortalURL)
85 if err == nil {
86 clientID := fmt.Sprintf("%s://%s/.well-known/oauth-client-metadata.json", parsedURL.Scheme, parsedURL.Host)
87 mgr, err := oauth.NewManager(g.app.Store, clientID, "", g.app.AllowPrivateCIDRs)
88 if err != nil {
89 return fmt.Errorf("failed to init oauth manager for refresh: %w", err)
90 }
91 g.oauthManagers[parsedURL.Host] = mgr
92 }
93 }
94
95 // Pre-resolve allowed handles to DIDs
96 if len(g.app.AllowPrivateCIDRs) > 0 {
97 g.resolver = resolver.NewWithAllowedCIDRs(g.app.AllowPrivateCIDRs)
98 } else {
99 g.resolver = resolver.New()
100 }
101
102 g.resolvedDIDs = make([]string, 0, len(g.Allow))
103 ctxResolver := context.Background() // Use background context for boot-time resolution
104 for _, allow := range g.Allow {
105 if allow == "*" {
106 g.resolvedDIDs = append(g.resolvedDIDs, "*")
107 continue
108 }
109
110 // If it's already a DID, append it directly
111 if strings.HasPrefix(allow, "did:") {
112 g.resolvedDIDs = append(g.resolvedDIDs, allow)
113 continue
114 }
115
116 // Treat as handle and resolve
117 did, err := g.resolver.ResolveIdentifier(ctxResolver, allow)
118 if err != nil {
119 g.logger.Warn("failed to resolve handle during provision", zap.String("handle", allow), zap.Error(err))
120 } else {
121 g.resolvedDIDs = append(g.resolvedDIDs, did)
122 g.handleCache.Store(allow, did)
123 }
124 }
125
126 return nil
127}
128
129// Validate checks that the configuration is valid.
130func (g *Gate) Validate() error {
131 return nil
132}
133
134// UnmarshalCaddyfile implements caddyfile.Unmarshaler.
135func (g *Gate) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
136 for d.Next() {
137 for nesting := d.Nesting(); d.NextBlock(nesting); {
138 switch d.Val() {
139 case "allow":
140 g.Allow = append(g.Allow, d.RemainingArgs()...)
141 case "cookie_name":
142 if !d.NextArg() {
143 return d.ArgErr()
144 }
145 g.CookieName = d.Val()
146 case "cookie_domain":
147 if !d.NextArg() {
148 return d.ArgErr()
149 }
150 g.CookieDomain = d.Val()
151 case "portal_url":
152 if !d.NextArg() {
153 return d.ArgErr()
154 }
155 g.PortalURL = d.Val()
156 case "resolve_handles_on_request":
157 if d.NextArg() {
158 val, err := strconv.ParseBool(d.Val())
159 if err != nil {
160 return d.Errf("invalid boolean value '%s'", d.Val())
161 }
162 g.ResolveHandlesOnRequest = val
163 } else {
164 g.ResolveHandlesOnRequest = true
165 }
166 default:
167 return d.Errf("unrecognized subdirective '%s'", d.Val())
168 }
169 }
170 }
171 return nil
172}
173
174// parseCaddyfileGate parses the atproto_gate directive from a Caddyfile.
175func parseCaddyfileGate(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
176 var g Gate
177 err := g.UnmarshalCaddyfile(h.Dispenser)
178 return &g, err
179}
180
181// getOAuthManager gets or initializes the OAuth manager for a specific host.
182func (g *Gate) getOAuthManager(r *http.Request) (*oauth.Manager, error) {
183 host := getRequestHost(r)
184
185 // If PortalURL is absolute, we already cached it under parsedURL.Host
186 if strings.HasPrefix(g.PortalURL, "http://") || strings.HasPrefix(g.PortalURL, "https://") {
187 parsedURL, err := url.Parse(g.PortalURL)
188 if err == nil {
189 host = parsedURL.Host
190 }
191 }
192
193 g.oauthMu.RLock()
194 mgr, exists := g.oauthManagers[host]
195 g.oauthMu.RUnlock()
196
197 if exists {
198 return mgr, nil
199 }
200
201 g.oauthMu.Lock()
202 defer g.oauthMu.Unlock()
203
204 if mgr, exists := g.oauthManagers[host]; exists {
205 return mgr, nil
206 }
207
208 if len(g.oauthManagers) >= g.app.OAuthManagerCacheSize {
209 // Prevent DoS from unbounded map growth
210 g.logger.Warn("oauth managers cache full, clearing to prevent OOM")
211 g.oauthManagers = make(map[string]*oauth.Manager)
212 }
213
214 scheme := getRequestScheme(r)
215
216 clientID := fmt.Sprintf("%s://%s/.well-known/oauth-client-metadata.json", scheme, host)
217 mgr, err := oauth.NewManager(g.app.Store, clientID, "", g.app.AllowPrivateCIDRs)
218 if err != nil {
219 return nil, err
220 }
221
222 g.oauthManagers[host] = mgr
223 return mgr, nil
224}
225
226// ServeHTTP implements caddyhttp.MiddlewareHandler.
227func (g *Gate) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
228 // 1. Verify stateless cookie here
229 sess, err := g.sessions.VerifyCookie(r)
230 if err == session.ErrExpired {
231 // Attempt transparent refresh if we are in a mode that supports it.
232 // We need an OAuth manager to refresh.
233 oauthMgr, _ := g.getOAuthManager(r)
234
235 if oauthMgr != nil && sess != nil {
236 clientSession, errRefresh := oauthMgr.ResumeSession(r.Context(), sess.DID, sess.SessionID)
237 if errRefresh == nil {
238 // Refresh tokens
239 if _, errRefresh := clientSession.RefreshTokens(r.Context()); errRefresh == nil {
240 // Success! Update cookie.
241
242 // Resolve fresh handle, fallback to old if unavailable
243 handle := sess.Handle
244 ident, errDir := oauthMgr.App.Dir.LookupDID(r.Context(), clientSession.Data.AccountDID)
245 if errDir == nil && ident != nil {
246 handle = ident.Handle.String()
247 }
248
249 cookie, errCookie := g.sessions.CreateCookie(
250 clientSession.Data.AccountDID,
251 handle,
252 clientSession.Data.SessionID,
253 g.app.SessionDuration,
254 )
255 if errCookie == nil {
256 http.SetCookie(w, cookie)
257 r.AddCookie(cookie)
258
259 // Update local session for authorization checks below
260 sess.DID = clientSession.Data.AccountDID.String()
261 sess.Handle = handle
262 err = nil // clear expiration error to proceed
263 }
264 }
265 }
266 // If refresh failed, err remains ErrExpired, falling through to redirect
267 }
268 }
269
270 if err == nil {
271 // Session valid!
272 // Check authorization against allowlist
273 allowed := false
274 for _, allow := range g.resolvedDIDs {
275 if allow == "*" || allow == sess.DID {
276 allowed = true
277 break
278 }
279 }
280
281 if !allowed && g.ResolveHandlesOnRequest {
282 // Try dynamic resolution for handles that weren't in the pre-resolved list
283 // or might have changed.
284 for _, allow := range g.Allow {
285 if allow != "*" && !strings.HasPrefix(allow, "did:") {
286 if cachedDID, ok := g.handleCache.Load(allow); ok && cachedDID == sess.DID {
287 allowed = true
288 break
289 }
290 // Resolve and cache
291 did, resErr := g.resolver.ResolveIdentifier(r.Context(), allow)
292 if resErr == nil {
293 g.handleCache.Store(allow, did)
294 if did == sess.DID {
295 allowed = true
296 break
297 }
298 }
299 }
300 }
301 }
302
303 if allowed {
304 // Inject headers
305 r.Header.Set("X-Atproto-Did", sess.DID)
306 r.Header.Set("X-Atproto-Handle", sess.Handle)
307 return next.ServeHTTP(w, r)
308 }
309
310 // Authenticated but not authorized
311 if g.PortalURL != "" {
312 scheme := getRequestScheme(r)
313 host := r.Host
314 currentURL := fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI())
315
316 portalURL := g.PortalURL
317 if portalURL == "/" {
318 portalURL = ""
319 }
320 portalForbidden := fmt.Sprintf("%s/forbidden?redirect_to=%s", portalURL, url.QueryEscape(currentURL))
321 http.Redirect(w, r, portalForbidden, http.StatusFound)
322 return nil
323 }
324
325 w.Header().Set("Content-Type", "text/plain; charset=utf-8")
326 w.WriteHeader(http.StatusForbidden)
327 w.Write([]byte("Forbidden"))
328 return nil
329 }
330
331 // 2. If invalid/missing, initiate redirect to Portal
332 if g.PortalURL != "" {
333 // Construct redirect URL: ${PortalURL}/login?redirect_to=${CurrentURL}
334 scheme := getRequestScheme(r)
335 host := r.Host
336 currentURL := fmt.Sprintf("%s://%s%s", scheme, host, r.URL.RequestURI())
337
338 portalURL := g.PortalURL
339 if portalURL == "/" {
340 portalURL = ""
341 }
342 portalLogin := fmt.Sprintf("%s/login?redirect_to=%s", portalURL, url.QueryEscape(currentURL))
343 http.Redirect(w, r, portalLogin, http.StatusFound)
344 return nil
345 }
346
347 // Fallback: 401
348 return caddyhttp.Error(http.StatusUnauthorized, fmt.Errorf("unauthorized"))
349}
350
351// Interface guards
352var (
353 _ caddy.Provisioner = (*Gate)(nil)
354 _ caddy.Validator = (*Gate)(nil)
355 _ caddyhttp.MiddlewareHandler = (*Gate)(nil)
356 _ caddyfile.Unmarshaler = (*Gate)(nil)
357)