Monorepo for Tangled
tangled.org
1package oauth
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "net/url"
10 "sync"
11 "time"
12
13 comatproto "github.com/bluesky-social/indigo/api/atproto"
14 "github.com/bluesky-social/indigo/atproto/atclient"
15 "github.com/bluesky-social/indigo/atproto/atcrypto"
16 "github.com/bluesky-social/indigo/atproto/auth/oauth"
17 "github.com/bluesky-social/indigo/atproto/syntax"
18 xrpc "github.com/bluesky-social/indigo/xrpc"
19 "github.com/gorilla/sessions"
20 "github.com/hashicorp/golang-lru/v2/expirable"
21 "github.com/posthog/posthog-go"
22 "golang.org/x/sync/singleflight"
23 "tangled.org/core/appview/config"
24 "tangled.org/core/appview/db"
25 "tangled.org/core/idresolver"
26 "tangled.org/core/rbac"
27 "tangled.org/core/xrpc/serviceauth"
28)
29
// Bounds for the in-memory cache of resumed OAuth client sessions
// (see OAuth.sessionCache).
const (
	// sessionCacheSize is the maximum number of client sessions held in the LRU.
	sessionCacheSize = 10000

	// sessionCacheTTL is how long a cached client session may be served before
	// it has to be resumed again through ClientApp.ResumeSession.
	sessionCacheTTL = 60 * time.Minute
)
34
// KnotMembership is the knot-membership dependency of the OAuth layer.
type KnotMembership interface {
	// IsKnotMember reports whether userDid is a member of the knot at host.
	IsKnotMember(ctx context.Context, host string, userDid string) bool

	// InvalidateMembers signals that membership data for host is stale
	// (implementations presumably drop any cached entries — confirm per impl).
	InvalidateMembers(host string)
}
39
40type OAuth struct {
41 ClientApp *oauth.ClientApp
42 SessStore *sessions.CookieStore
43 Config *config.Config
44 JwksUri string
45 ClientName string
46 ClientUri string
47 Posthog posthog.Client
48 Db *db.DB
49 Enforcer *rbac.Enforcer
50 Acl KnotMembership
51 IdResolver *idresolver.Resolver
52 Logger *slog.Logger
53
54 appPasswordSession *AppPasswordSession
55 appPasswordSessionMu sync.Mutex
56
57 sessionCache *expirable.LRU[string, *oauth.ClientSession]
58 sessionSF singleflight.Group
59}
60
61func sessionCacheKey(did syntax.DID, sessionId string) string {
62 return string(did) + ":" + sessionId
63}
64
65func (o *OAuth) resumeSession(ctx context.Context, did syntax.DID, sessionId string) (*oauth.ClientSession, error) {
66 key := sessionCacheKey(did, sessionId)
67 if v, ok := o.sessionCache.Get(key); ok {
68 return v, nil
69 }
70 v, err, _ := o.sessionSF.Do(key, func() (any, error) {
71 if v, ok := o.sessionCache.Get(key); ok {
72 return v, nil
73 }
74 sess, err := o.ClientApp.ResumeSession(ctx, did, sessionId)
75 if err != nil {
76 return nil, err
77 }
78 o.sessionCache.Add(key, sess)
79 return sess, nil
80 })
81 if err != nil {
82 return nil, err
83 }
84 return v.(*oauth.ClientSession), nil
85}
86
87func (o *OAuth) EvictSession(did syntax.DID, sessionId string) {
88 o.sessionCache.Remove(sessionCacheKey(did, sessionId))
89}
90
91func (o *OAuth) HandlePermanentAuthErr(ctx context.Context, did syntax.DID, sessionId string, err error) bool {
92 if !IsPermanentAuthErr(err) {
93 return false
94 }
95 o.EvictSession(did, sessionId)
96 if logoutErr := o.ClientApp.Logout(ctx, did, sessionId); logoutErr != nil {
97 o.Logger.Warn("store logout after permanent auth error failed", "did", did, "err", logoutErr)
98 }
99 return true
100}
101
102func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, acl KnotMembership, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) {
103 var oauthConfig oauth.ClientConfig
104 var clientUri string
105 if config.Core.Dev {
106 clientUri = "http://127.0.0.1:3000"
107 callbackUri := clientUri + "/oauth/callback"
108 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes)
109 } else {
110 clientUri = "https://" + config.Core.AppviewHost
111 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
112 callbackUri := clientUri + "/oauth/callback"
113 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes)
114 }
115
116 // configure client secret
117 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret)
118 if err != nil {
119 return nil, err
120 }
121 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil {
122 return nil, err
123 }
124
125 jwksUri := clientUri + "/oauth/jwks.json"
126
127 authStore, err := NewRedisStore(&RedisStoreConfig{
128 RedisURL: config.Redis.ToURL(),
129 SessionExpiryDuration: time.Hour * 24 * 90,
130 SessionInactivityDuration: time.Hour * 24 * 14,
131 AuthRequestExpiryDuration: time.Minute * 30,
132 })
133 if err != nil {
134 return nil, err
135 }
136
137 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
138 sessStore.Options.SameSite = http.SameSiteLaxMode
139 sessStore.Options.HttpOnly = true
140 sessStore.Options.Secure = !config.Core.Dev
141
142 clientApp := oauth.NewClientApp(&oauthConfig, authStore)
143 clientApp.Dir = res.Directory()
144 // allow non-public transports in dev mode
145 if config.Core.Dev {
146 clientApp.Resolver.Client.Transport = http.DefaultTransport
147 }
148
149 clientName := config.Core.AppviewName
150
151 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential())
152 return &OAuth{
153 ClientApp: clientApp,
154 Config: config,
155 SessStore: sessStore,
156 JwksUri: jwksUri,
157 ClientName: clientName,
158 ClientUri: clientUri,
159 Posthog: ph,
160 Db: db,
161 Enforcer: enforcer,
162 Acl: acl,
163 IdResolver: res,
164 Logger: logger,
165 sessionCache: expirable.NewLRU[string, *oauth.ClientSession](sessionCacheSize, nil, sessionCacheTTL),
166 }, nil
167}
168
169func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
170 userSession, err := o.SessStore.Get(r, SessionName)
171 if err != nil {
172 o.Logger.Warn("failed to decode existing session cookie, will create new", "err", err)
173 }
174
175 userSession.Values[SessionDid] = sessData.AccountDID.String()
176 userSession.Values[SessionPds] = sessData.HostURL
177 userSession.Values[SessionId] = sessData.SessionID
178 userSession.Values[SessionAuthenticated] = true
179
180 if err := userSession.Save(r, w); err != nil {
181 return err
182 }
183
184 handle := ""
185 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String())
186 if err == nil && resolved.Handle.String() != "" {
187 handle = resolved.Handle.String()
188 }
189
190 registry := o.GetAccounts(r)
191 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil {
192 return err
193 }
194 return o.saveAccounts(w, r, registry)
195}
196
197func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
198 userSession, err := o.SessStore.Get(r, SessionName)
199 if err != nil {
200 return nil, fmt.Errorf("error getting user session: %w", err)
201 }
202 if userSession.IsNew {
203 return nil, fmt.Errorf("no session available for user")
204 }
205
206 d := userSession.Values[SessionDid].(string)
207 sessDid, err := syntax.ParseDID(d)
208 if err != nil {
209 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
210 }
211
212 sessId := userSession.Values[SessionId].(string)
213
214 clientSess, err := o.resumeSession(r.Context(), sessDid, sessId)
215 if err != nil {
216 return nil, fmt.Errorf("failed to resume session: %w", err)
217 }
218
219 return clientSess, nil
220}
221
222func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
223 userSession, err := o.SessStore.Get(r, SessionName)
224 if err != nil {
225 return fmt.Errorf("error getting user session: %w", err)
226 }
227 if userSession.IsNew {
228 return fmt.Errorf("no session available for user")
229 }
230
231 d := userSession.Values[SessionDid].(string)
232 sessDid, err := syntax.ParseDID(d)
233 if err != nil {
234 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
235 }
236
237 sessId := userSession.Values[SessionId].(string)
238
239 o.EvictSession(sessDid, sessId)
240
241 // delete the session
242 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
243 if err1 != nil {
244 err1 = fmt.Errorf("failed to logout: %w", err1)
245 }
246 o.EvictSession(sessDid, sessId)
247
248 // remove the cookie
249 userSession.Options.MaxAge = -1
250 err2 := o.SessStore.Save(r, w, userSession)
251 if err2 != nil {
252 err2 = fmt.Errorf("failed to save into session store: %w", err2)
253 }
254
255 return errors.Join(err1, err2)
256}
257
258func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
259 registry := o.GetAccounts(r)
260 account := registry.FindAccount(targetDid)
261 if account == nil {
262 return fmt.Errorf("account not found in registry: %s", targetDid)
263 }
264
265 did, err := syntax.ParseDID(targetDid)
266 if err != nil {
267 return fmt.Errorf("invalid DID: %w", err)
268 }
269
270 sess, err := o.resumeSession(r.Context(), did, account.SessionId)
271 if err != nil {
272 registry.RemoveAccount(targetDid)
273 _ = o.saveAccounts(w, r, registry)
274 return fmt.Errorf("session expired for account: %w", err)
275 }
276
277 userSession, err := o.SessStore.Get(r, SessionName)
278 if err != nil {
279 return err
280 }
281
282 userSession.Values[SessionDid] = sess.Data.AccountDID.String()
283 userSession.Values[SessionPds] = sess.Data.HostURL
284 userSession.Values[SessionId] = sess.Data.SessionID
285 userSession.Values[SessionAuthenticated] = true
286
287 return userSession.Save(r, w)
288}
289
290func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
291 registry := o.GetAccounts(r)
292 account := registry.FindAccount(targetDid)
293 if account == nil {
294 return nil
295 }
296
297 did, err := syntax.ParseDID(targetDid)
298 if err == nil {
299 o.EvictSession(did, account.SessionId)
300 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId)
301 o.EvictSession(did, account.SessionId)
302 }
303
304 registry.RemoveAccount(targetDid)
305 return o.saveAccounts(w, r, registry)
306}
307
308func (o *OAuth) GetDid(r *http.Request) string {
309 if u := o.GetMultiAccountUser(r); u != nil {
310 return u.Did
311 }
312
313 return ""
314}
315
316func (o *OAuth) GetDidFromCookie(r *http.Request) syntax.DID {
317 userSession, err := o.SessStore.Get(r, SessionName)
318 if err != nil || userSession.IsNew {
319 return ""
320 }
321 d, ok := userSession.Values[SessionDid].(string)
322 if !ok {
323 return ""
324 }
325 parsed, err := syntax.ParseDID(d)
326 if err != nil {
327 return ""
328 }
329 return parsed
330}
331
332func (o *OAuth) GetSessIdFromCookie(r *http.Request) string {
333 userSession, err := o.SessStore.Get(r, SessionName)
334 if err != nil || userSession.IsNew {
335 return ""
336 }
337 s, ok := userSession.Values[SessionId].(string)
338 if !ok {
339 return ""
340 }
341 return s
342}
343
344func (o *OAuth) AuthorizedClient(r *http.Request) (*atclient.APIClient, error) {
345 session, err := o.ResumeSession(r)
346 if err != nil {
347 return nil, fmt.Errorf("error getting session: %w", err)
348 }
349 return session.APIClient(), nil
350}
351
// ServiceClientOpts configures ServiceClient, a higher-level wrapper around
// com.atproto.server.getServiceAuth. Start from DefaultServiceClientOpts and
// adjust it with ServiceClientOpt values.
type ServiceClientOpts struct {
	service string        // target service host; used for the audience and URL
	exp     int64         // absolute token expiry in Unix seconds (see WithExp)
	lxm     string        // lxm value passed to getServiceAuth
	dev     bool          // use http:// instead of https:// for the host URL
	timeout time.Duration // timeout of the returned HTTP client
}

// ServiceClientOpt mutates a ServiceClientOpts.
type ServiceClientOpt func(*ServiceClientOpts)

// DefaultServiceClientOpts returns options with a 5 second timeout and every
// other field at its zero value.
func DefaultServiceClientOpts() ServiceClientOpts {
	var opts ServiceClientOpts
	opts.timeout = 5 * time.Second
	return opts
}

// WithService sets the target service host.
func WithService(service string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.service = service }
}

// WithExp sets the token to expire exp seconds after the option is applied:
// the stored expiry is time.Now().Unix() + exp at application time.
func WithExp(exp int64) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		now := time.Now().Unix()
		opts.exp = now + exp
	}
}

// WithLxm sets the lxm value requested for the token.
func WithLxm(lxm string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.lxm = lxm }
}

// WithDev selects plain http:// for the target host when dev is true.
func WithDev(dev bool) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.dev = dev }
}

// WithTimeout sets the timeout of the returned HTTP client.
func WithTimeout(timeout time.Duration) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.timeout = timeout }
}
401
402func (s *ServiceClientOpts) Audience() string {
403 return serviceauth.DidWeb(s.service).String()
404}
405
406func (s *ServiceClientOpts) Host() string {
407 scheme := "https://"
408 if s.dev {
409 scheme = "http://"
410 }
411
412 return scheme + s.service
413}
414
415func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
416 client, err := o.AuthorizedClient(r)
417 if err != nil {
418 return nil, err
419 }
420
421 opts := DefaultServiceClientOpts()
422 for _, o := range os {
423 o(&opts)
424 }
425
426 // force expiry to atleast 60 seconds in the future
427 sixty := time.Now().Unix() + 60
428 if opts.exp < sixty {
429 opts.exp = sixty
430 }
431
432 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
433 if err != nil {
434 return nil, err
435 }
436
437 return &xrpc.Client{
438 Auth: &xrpc.AuthInfo{
439 AccessJwt: resp.Token,
440 },
441 Host: opts.Host(),
442 Client: &http.Client{
443 Timeout: opts.timeout,
444 },
445 }, nil
446}
447
448func (o *OAuth) StartElevatedAuthFlow(ctx context.Context, w http.ResponseWriter, r *http.Request, did string, extraScopes []string, returnURL string) (string, error) {
449 parsedDid, err := syntax.ParseDID(did)
450 if err != nil {
451 return "", fmt.Errorf("invalid DID: %w", err)
452 }
453
454 ident, err := o.ClientApp.Dir.Lookup(ctx, parsedDid.AtIdentifier())
455 if err != nil {
456 return "", fmt.Errorf("failed to resolve DID (%s): %w", did, err)
457 }
458
459 host := ident.PDSEndpoint()
460 if host == "" {
461 return "", fmt.Errorf("identity does not link to an atproto host (PDS)")
462 }
463
464 authserverURL, err := o.ClientApp.Resolver.ResolveAuthServerURL(ctx, host)
465 if err != nil {
466 return "", fmt.Errorf("resolving auth server: %w", err)
467 }
468
469 authserverMeta, err := o.ClientApp.Resolver.ResolveAuthServerMetadata(ctx, authserverURL)
470 if err != nil {
471 return "", fmt.Errorf("fetching auth server metadata: %w", err)
472 }
473
474 scopes := make([]string, 0, len(TangledScopes)+len(extraScopes))
475 scopes = append(scopes, TangledScopes...)
476 scopes = append(scopes, extraScopes...)
477
478 loginHint := did
479 if ident.Handle != "" && !ident.Handle.IsInvalidHandle() {
480 loginHint = ident.Handle.String()
481 }
482
483 info, err := o.ClientApp.SendAuthRequest(ctx, authserverMeta, scopes, loginHint)
484 if err != nil {
485 return "", fmt.Errorf("auth request failed: %w", err)
486 }
487
488 info.AccountDID = &parsedDid
489 o.ClientApp.Store.SaveAuthRequestInfo(ctx, *info)
490
491 if err := o.SetAuthReturn(w, r, returnURL); err != nil {
492 return "", fmt.Errorf("failed to set auth return: %w", err)
493 }
494
495 redirectURL := fmt.Sprintf("%s?client_id=%s&request_uri=%s",
496 authserverMeta.AuthorizationEndpoint,
497 url.QueryEscape(o.ClientApp.Config.ClientID),
498 url.QueryEscape(info.RequestURI),
499 )
500
501 return redirectURL, nil
502}