Monorepo for Tangled
tangled.org
1package oauth
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "net/url"
10 "sync"
11 "time"
12
13 comatproto "github.com/bluesky-social/indigo/api/atproto"
14 "github.com/bluesky-social/indigo/atproto/atclient"
15 "github.com/bluesky-social/indigo/atproto/atcrypto"
16 "github.com/bluesky-social/indigo/atproto/auth/oauth"
17 "github.com/bluesky-social/indigo/atproto/syntax"
18 xrpc "github.com/bluesky-social/indigo/xrpc"
19 "github.com/gorilla/sessions"
20 "github.com/hashicorp/golang-lru/v2/expirable"
21 "github.com/posthog/posthog-go"
22 "golang.org/x/sync/singleflight"
23 "tangled.org/core/appview/config"
24 "tangled.org/core/appview/db"
25 "tangled.org/core/idresolver"
26 "tangled.org/core/rbac"
27)
28
// Capacity and per-entry TTL of the in-memory cache of resumed OAuth client
// sessions; New passes both to expirable.NewLRU.
const (
	sessionCacheSize = 10_000
	sessionCacheTTL  = 60 * time.Minute
)
33
34type OAuth struct {
35 ClientApp *oauth.ClientApp
36 SessStore *sessions.CookieStore
37 Config *config.Config
38 JwksUri string
39 ClientName string
40 ClientUri string
41 Posthog posthog.Client
42 Db *db.DB
43 Enforcer *rbac.Enforcer
44 IdResolver *idresolver.Resolver
45 Logger *slog.Logger
46
47 appPasswordSession *AppPasswordSession
48 appPasswordSessionMu sync.Mutex
49
50 sessionCache *expirable.LRU[string, *oauth.ClientSession]
51 sessionSF singleflight.Group
52}
53
54func sessionCacheKey(did syntax.DID, sessionId string) string {
55 return string(did) + ":" + sessionId
56}
57
58func (o *OAuth) resumeSession(ctx context.Context, did syntax.DID, sessionId string) (*oauth.ClientSession, error) {
59 key := sessionCacheKey(did, sessionId)
60 if v, ok := o.sessionCache.Get(key); ok {
61 return v, nil
62 }
63 v, err, _ := o.sessionSF.Do(key, func() (any, error) {
64 if v, ok := o.sessionCache.Get(key); ok {
65 return v, nil
66 }
67 sess, err := o.ClientApp.ResumeSession(ctx, did, sessionId)
68 if err != nil {
69 return nil, err
70 }
71 o.sessionCache.Add(key, sess)
72 return sess, nil
73 })
74 if err != nil {
75 return nil, err
76 }
77 return v.(*oauth.ClientSession), nil
78}
79
80func (o *OAuth) EvictSession(did syntax.DID, sessionId string) {
81 o.sessionCache.Remove(sessionCacheKey(did, sessionId))
82}
83
84func (o *OAuth) HandlePermanentAuthErr(ctx context.Context, did syntax.DID, sessionId string, err error) bool {
85 if !IsPermanentAuthErr(err) {
86 return false
87 }
88 o.EvictSession(did, sessionId)
89 if logoutErr := o.ClientApp.Logout(ctx, did, sessionId); logoutErr != nil {
90 o.Logger.Warn("store logout after permanent auth error failed", "did", did, "err", logoutErr)
91 }
92 return true
93}
94
95func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) {
96 var oauthConfig oauth.ClientConfig
97 var clientUri string
98 if config.Core.Dev {
99 clientUri = "http://127.0.0.1:3000"
100 callbackUri := clientUri + "/oauth/callback"
101 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes)
102 } else {
103 clientUri = "https://" + config.Core.AppviewHost
104 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
105 callbackUri := clientUri + "/oauth/callback"
106 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes)
107 }
108
109 // configure client secret
110 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret)
111 if err != nil {
112 return nil, err
113 }
114 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil {
115 return nil, err
116 }
117
118 jwksUri := clientUri + "/oauth/jwks.json"
119
120 authStore, err := NewRedisStore(&RedisStoreConfig{
121 RedisURL: config.Redis.ToURL(),
122 SessionExpiryDuration: time.Hour * 24 * 90,
123 SessionInactivityDuration: time.Hour * 24 * 14,
124 AuthRequestExpiryDuration: time.Minute * 30,
125 })
126 if err != nil {
127 return nil, err
128 }
129
130 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
131 sessStore.Options.SameSite = http.SameSiteLaxMode
132 sessStore.Options.HttpOnly = true
133 sessStore.Options.Secure = !config.Core.Dev
134
135 clientApp := oauth.NewClientApp(&oauthConfig, authStore)
136 clientApp.Dir = res.Directory()
137 // allow non-public transports in dev mode
138 if config.Core.Dev {
139 clientApp.Resolver.Client.Transport = http.DefaultTransport
140 }
141
142 clientName := config.Core.AppviewName
143
144 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential())
145 return &OAuth{
146 ClientApp: clientApp,
147 Config: config,
148 SessStore: sessStore,
149 JwksUri: jwksUri,
150 ClientName: clientName,
151 ClientUri: clientUri,
152 Posthog: ph,
153 Db: db,
154 Enforcer: enforcer,
155 IdResolver: res,
156 Logger: logger,
157 sessionCache: expirable.NewLRU[string, *oauth.ClientSession](sessionCacheSize, nil, sessionCacheTTL),
158 }, nil
159}
160
161func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
162 userSession, err := o.SessStore.Get(r, SessionName)
163 if err != nil {
164 o.Logger.Warn("failed to decode existing session cookie, will create new", "err", err)
165 }
166
167 userSession.Values[SessionDid] = sessData.AccountDID.String()
168 userSession.Values[SessionPds] = sessData.HostURL
169 userSession.Values[SessionId] = sessData.SessionID
170 userSession.Values[SessionAuthenticated] = true
171
172 if err := userSession.Save(r, w); err != nil {
173 return err
174 }
175
176 handle := ""
177 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String())
178 if err == nil && resolved.Handle.String() != "" {
179 handle = resolved.Handle.String()
180 }
181
182 registry := o.GetAccounts(r)
183 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil {
184 return err
185 }
186 return o.saveAccounts(w, r, registry)
187}
188
189func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
190 userSession, err := o.SessStore.Get(r, SessionName)
191 if err != nil {
192 return nil, fmt.Errorf("error getting user session: %w", err)
193 }
194 if userSession.IsNew {
195 return nil, fmt.Errorf("no session available for user")
196 }
197
198 d := userSession.Values[SessionDid].(string)
199 sessDid, err := syntax.ParseDID(d)
200 if err != nil {
201 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
202 }
203
204 sessId := userSession.Values[SessionId].(string)
205
206 clientSess, err := o.resumeSession(r.Context(), sessDid, sessId)
207 if err != nil {
208 return nil, fmt.Errorf("failed to resume session: %w", err)
209 }
210
211 return clientSess, nil
212}
213
214func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
215 userSession, err := o.SessStore.Get(r, SessionName)
216 if err != nil {
217 return fmt.Errorf("error getting user session: %w", err)
218 }
219 if userSession.IsNew {
220 return fmt.Errorf("no session available for user")
221 }
222
223 d := userSession.Values[SessionDid].(string)
224 sessDid, err := syntax.ParseDID(d)
225 if err != nil {
226 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
227 }
228
229 sessId := userSession.Values[SessionId].(string)
230
231 o.EvictSession(sessDid, sessId)
232
233 // delete the session
234 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
235 if err1 != nil {
236 err1 = fmt.Errorf("failed to logout: %w", err1)
237 }
238 o.EvictSession(sessDid, sessId)
239
240 // remove the cookie
241 userSession.Options.MaxAge = -1
242 err2 := o.SessStore.Save(r, w, userSession)
243 if err2 != nil {
244 err2 = fmt.Errorf("failed to save into session store: %w", err2)
245 }
246
247 return errors.Join(err1, err2)
248}
249
250func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
251 registry := o.GetAccounts(r)
252 account := registry.FindAccount(targetDid)
253 if account == nil {
254 return fmt.Errorf("account not found in registry: %s", targetDid)
255 }
256
257 did, err := syntax.ParseDID(targetDid)
258 if err != nil {
259 return fmt.Errorf("invalid DID: %w", err)
260 }
261
262 sess, err := o.resumeSession(r.Context(), did, account.SessionId)
263 if err != nil {
264 registry.RemoveAccount(targetDid)
265 _ = o.saveAccounts(w, r, registry)
266 return fmt.Errorf("session expired for account: %w", err)
267 }
268
269 userSession, err := o.SessStore.Get(r, SessionName)
270 if err != nil {
271 return err
272 }
273
274 userSession.Values[SessionDid] = sess.Data.AccountDID.String()
275 userSession.Values[SessionPds] = sess.Data.HostURL
276 userSession.Values[SessionId] = sess.Data.SessionID
277 userSession.Values[SessionAuthenticated] = true
278
279 return userSession.Save(r, w)
280}
281
282func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
283 registry := o.GetAccounts(r)
284 account := registry.FindAccount(targetDid)
285 if account == nil {
286 return nil
287 }
288
289 did, err := syntax.ParseDID(targetDid)
290 if err == nil {
291 o.EvictSession(did, account.SessionId)
292 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId)
293 o.EvictSession(did, account.SessionId)
294 }
295
296 registry.RemoveAccount(targetDid)
297 return o.saveAccounts(w, r, registry)
298}
299
300func (o *OAuth) GetDid(r *http.Request) string {
301 if u := o.GetMultiAccountUser(r); u != nil {
302 return u.Did
303 }
304
305 return ""
306}
307
308func (o *OAuth) GetDidFromCookie(r *http.Request) syntax.DID {
309 userSession, err := o.SessStore.Get(r, SessionName)
310 if err != nil || userSession.IsNew {
311 return ""
312 }
313 d, ok := userSession.Values[SessionDid].(string)
314 if !ok {
315 return ""
316 }
317 parsed, err := syntax.ParseDID(d)
318 if err != nil {
319 return ""
320 }
321 return parsed
322}
323
324func (o *OAuth) GetSessIdFromCookie(r *http.Request) string {
325 userSession, err := o.SessStore.Get(r, SessionName)
326 if err != nil || userSession.IsNew {
327 return ""
328 }
329 s, ok := userSession.Values[SessionId].(string)
330 if !ok {
331 return ""
332 }
333 return s
334}
335
336func (o *OAuth) AuthorizedClient(r *http.Request) (*atclient.APIClient, error) {
337 session, err := o.ResumeSession(r)
338 if err != nil {
339 return nil, fmt.Errorf("error getting session: %w", err)
340 }
341 return session.APIClient(), nil
342}
343
// ServiceClientOpts configures ServiceClient, a higher-level wrapper around
// com.atproto.server.getServiceAuth (ServerGetServiceAuth).
type ServiceClientOpts struct {
	service string        // service host; used for both Audience and Host
	exp     int64         // absolute token expiry in unix seconds (see WithExp)
	lxm     string        // lexicon method passed to getServiceAuth
	dev     bool          // when true, Host uses http:// instead of https://
	timeout time.Duration // timeout of the returned client's http.Client
}

// ServiceClientOpt mutates a ServiceClientOpts; see the With* constructors.
type ServiceClientOpt func(*ServiceClientOpts)

// DefaultServiceClientOpts returns options with a 5 second timeout and every
// other field at its zero value.
func DefaultServiceClientOpts() ServiceClientOpts {
	var opts ServiceClientOpts
	opts.timeout = 5 * time.Second
	return opts
}

// WithService sets the target service host.
func WithService(service string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		opts.service = service
	}
}

// WithExp requests a token that expires exp seconds from now.
//
// The absolute expiry is time.Now().Unix() + exp, evaluated when the returned
// option is applied.
func WithExp(exp int64) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		opts.exp = exp + time.Now().Unix()
	}
}

// WithLxm sets the lexicon method (lxm) requested for the service token.
func WithLxm(lxm string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		opts.lxm = lxm
	}
}

// WithDev selects plain http:// URLs (see Host) when dev is true.
func WithDev(dev bool) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		opts.dev = dev
	}
}

// WithTimeout sets the timeout of the returned client's http.Client.
func WithTimeout(timeout time.Duration) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		opts.timeout = timeout
	}
}

// Audience returns the service's did:web identifier, used as the audience of
// the requested service token.
func (s *ServiceClientOpts) Audience() string {
	return "did:web:" + s.service
}

// Host returns the service base URL: https://<service>, or http://<service>
// when dev is set.
func (s *ServiceClientOpts) Host() string {
	if s.dev {
		return "http://" + s.service
	}
	return "https://" + s.service
}
406
407func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
408 opts := DefaultServiceClientOpts()
409 for _, o := range os {
410 o(&opts)
411 }
412
413 client, err := o.AuthorizedClient(r)
414 if err != nil {
415 return nil, err
416 }
417
418 // force expiry to atleast 60 seconds in the future
419 sixty := time.Now().Unix() + 60
420 if opts.exp < sixty {
421 opts.exp = sixty
422 }
423
424 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
425 if err != nil {
426 return nil, err
427 }
428
429 return &xrpc.Client{
430 Auth: &xrpc.AuthInfo{
431 AccessJwt: resp.Token,
432 },
433 Host: opts.Host(),
434 Client: &http.Client{
435 Timeout: opts.timeout,
436 },
437 }, nil
438}
439
440func (o *OAuth) StartElevatedAuthFlow(ctx context.Context, w http.ResponseWriter, r *http.Request, did string, extraScopes []string, returnURL string) (string, error) {
441 parsedDid, err := syntax.ParseDID(did)
442 if err != nil {
443 return "", fmt.Errorf("invalid DID: %w", err)
444 }
445
446 ident, err := o.ClientApp.Dir.Lookup(ctx, parsedDid.AtIdentifier())
447 if err != nil {
448 return "", fmt.Errorf("failed to resolve DID (%s): %w", did, err)
449 }
450
451 host := ident.PDSEndpoint()
452 if host == "" {
453 return "", fmt.Errorf("identity does not link to an atproto host (PDS)")
454 }
455
456 authserverURL, err := o.ClientApp.Resolver.ResolveAuthServerURL(ctx, host)
457 if err != nil {
458 return "", fmt.Errorf("resolving auth server: %w", err)
459 }
460
461 authserverMeta, err := o.ClientApp.Resolver.ResolveAuthServerMetadata(ctx, authserverURL)
462 if err != nil {
463 return "", fmt.Errorf("fetching auth server metadata: %w", err)
464 }
465
466 scopes := make([]string, 0, len(TangledScopes)+len(extraScopes))
467 scopes = append(scopes, TangledScopes...)
468 scopes = append(scopes, extraScopes...)
469
470 loginHint := did
471 if ident.Handle != "" && !ident.Handle.IsInvalidHandle() {
472 loginHint = ident.Handle.String()
473 }
474
475 info, err := o.ClientApp.SendAuthRequest(ctx, authserverMeta, scopes, loginHint)
476 if err != nil {
477 return "", fmt.Errorf("auth request failed: %w", err)
478 }
479
480 info.AccountDID = &parsedDid
481 o.ClientApp.Store.SaveAuthRequestInfo(ctx, *info)
482
483 if err := o.SetAuthReturn(w, r, returnURL); err != nil {
484 return "", fmt.Errorf("failed to set auth return: %w", err)
485 }
486
487 redirectURL := fmt.Sprintf("%s?client_id=%s&request_uri=%s",
488 authserverMeta.AuthorizationEndpoint,
489 url.QueryEscape(o.ClientApp.Config.ClientID),
490 url.QueryEscape(info.RequestURI),
491 )
492
493 return redirectURL, nil
494}