Monorepo for Tangled (tangled.org)
6

Configure Feed

Select the types of activity you want to include in your feed.

1package oauth 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "log/slog" 8 "net/http" 9 "net/url" 10 "sync" 11 "time" 12 13 comatproto "github.com/bluesky-social/indigo/api/atproto" 14 "github.com/bluesky-social/indigo/atproto/atclient" 15 "github.com/bluesky-social/indigo/atproto/atcrypto" 16 "github.com/bluesky-social/indigo/atproto/auth/oauth" 17 "github.com/bluesky-social/indigo/atproto/syntax" 18 xrpc "github.com/bluesky-social/indigo/xrpc" 19 "github.com/gorilla/sessions" 20 "github.com/hashicorp/golang-lru/v2/expirable" 21 "github.com/posthog/posthog-go" 22 "golang.org/x/sync/singleflight" 23 "tangled.org/core/appview/config" 24 "tangled.org/core/appview/db" 25 "tangled.org/core/idresolver" 26 "tangled.org/core/rbac" 27) 28 29const ( 30 sessionCacheSize = 10000 31 sessionCacheTTL = time.Hour 32) 33 34type OAuth struct { 35 ClientApp *oauth.ClientApp 36 SessStore *sessions.CookieStore 37 Config *config.Config 38 JwksUri string 39 ClientName string 40 ClientUri string 41 Posthog posthog.Client 42 Db *db.DB 43 Enforcer *rbac.Enforcer 44 IdResolver *idresolver.Resolver 45 Logger *slog.Logger 46 47 appPasswordSession *AppPasswordSession 48 appPasswordSessionMu sync.Mutex 49 50 sessionCache *expirable.LRU[string, *oauth.ClientSession] 51 sessionSF singleflight.Group 52} 53 54func sessionCacheKey(did syntax.DID, sessionId string) string { 55 return string(did) + ":" + sessionId 56} 57 58func (o *OAuth) resumeSession(ctx context.Context, did syntax.DID, sessionId string) (*oauth.ClientSession, error) { 59 key := sessionCacheKey(did, sessionId) 60 if v, ok := o.sessionCache.Get(key); ok { 61 return v, nil 62 } 63 v, err, _ := o.sessionSF.Do(key, func() (any, error) { 64 if v, ok := o.sessionCache.Get(key); ok { 65 return v, nil 66 } 67 sess, err := o.ClientApp.ResumeSession(ctx, did, sessionId) 68 if err != nil { 69 return nil, err 70 } 71 o.sessionCache.Add(key, sess) 72 return sess, nil 73 }) 74 if err != nil { 75 return nil, err 76 } 77 return 
v.(*oauth.ClientSession), nil 78} 79 80func (o *OAuth) EvictSession(did syntax.DID, sessionId string) { 81 o.sessionCache.Remove(sessionCacheKey(did, sessionId)) 82} 83 84func (o *OAuth) HandlePermanentAuthErr(ctx context.Context, did syntax.DID, sessionId string, err error) bool { 85 if !IsPermanentAuthErr(err) { 86 return false 87 } 88 o.EvictSession(did, sessionId) 89 if logoutErr := o.ClientApp.Logout(ctx, did, sessionId); logoutErr != nil { 90 o.Logger.Warn("store logout after permanent auth error failed", "did", did, "err", logoutErr) 91 } 92 return true 93} 94 95func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) { 96 var oauthConfig oauth.ClientConfig 97 var clientUri string 98 if config.Core.Dev { 99 clientUri = "http://127.0.0.1:3000" 100 callbackUri := clientUri + "/oauth/callback" 101 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes) 102 } else { 103 clientUri = "https://" + config.Core.AppviewHost 104 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 105 callbackUri := clientUri + "/oauth/callback" 106 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes) 107 } 108 109 // configure client secret 110 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret) 111 if err != nil { 112 return nil, err 113 } 114 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil { 115 return nil, err 116 } 117 118 jwksUri := clientUri + "/oauth/jwks.json" 119 120 authStore, err := NewRedisStore(&RedisStoreConfig{ 121 RedisURL: config.Redis.ToURL(), 122 SessionExpiryDuration: time.Hour * 24 * 90, 123 SessionInactivityDuration: time.Hour * 24 * 14, 124 AuthRequestExpiryDuration: time.Minute * 30, 125 }) 126 if err != nil { 127 return nil, err 128 } 129 130 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 131 sessStore.Options.SameSite = http.SameSiteLaxMode 132 
sessStore.Options.HttpOnly = true 133 sessStore.Options.Secure = !config.Core.Dev 134 135 clientApp := oauth.NewClientApp(&oauthConfig, authStore) 136 clientApp.Dir = res.Directory() 137 // allow non-public transports in dev mode 138 if config.Core.Dev { 139 clientApp.Resolver.Client.Transport = http.DefaultTransport 140 } 141 142 clientName := config.Core.AppviewName 143 144 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential()) 145 return &OAuth{ 146 ClientApp: clientApp, 147 Config: config, 148 SessStore: sessStore, 149 JwksUri: jwksUri, 150 ClientName: clientName, 151 ClientUri: clientUri, 152 Posthog: ph, 153 Db: db, 154 Enforcer: enforcer, 155 IdResolver: res, 156 Logger: logger, 157 sessionCache: expirable.NewLRU[string, *oauth.ClientSession](sessionCacheSize, nil, sessionCacheTTL), 158 }, nil 159} 160 161func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 162 userSession, err := o.SessStore.Get(r, SessionName) 163 if err != nil { 164 o.Logger.Warn("failed to decode existing session cookie, will create new", "err", err) 165 } 166 167 userSession.Values[SessionDid] = sessData.AccountDID.String() 168 userSession.Values[SessionPds] = sessData.HostURL 169 userSession.Values[SessionId] = sessData.SessionID 170 userSession.Values[SessionAuthenticated] = true 171 172 if err := userSession.Save(r, w); err != nil { 173 return err 174 } 175 176 handle := "" 177 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String()) 178 if err == nil && resolved.Handle.String() != "" { 179 handle = resolved.Handle.String() 180 } 181 182 registry := o.GetAccounts(r) 183 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil { 184 return err 185 } 186 return o.saveAccounts(w, r, registry) 187} 188 189func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 190 userSession, err := 
o.SessStore.Get(r, SessionName) 191 if err != nil { 192 return nil, fmt.Errorf("error getting user session: %w", err) 193 } 194 if userSession.IsNew { 195 return nil, fmt.Errorf("no session available for user") 196 } 197 198 d := userSession.Values[SessionDid].(string) 199 sessDid, err := syntax.ParseDID(d) 200 if err != nil { 201 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 202 } 203 204 sessId := userSession.Values[SessionId].(string) 205 206 clientSess, err := o.resumeSession(r.Context(), sessDid, sessId) 207 if err != nil { 208 return nil, fmt.Errorf("failed to resume session: %w", err) 209 } 210 211 return clientSess, nil 212} 213 214func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 215 userSession, err := o.SessStore.Get(r, SessionName) 216 if err != nil { 217 return fmt.Errorf("error getting user session: %w", err) 218 } 219 if userSession.IsNew { 220 return fmt.Errorf("no session available for user") 221 } 222 223 d := userSession.Values[SessionDid].(string) 224 sessDid, err := syntax.ParseDID(d) 225 if err != nil { 226 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 227 } 228 229 sessId := userSession.Values[SessionId].(string) 230 231 o.EvictSession(sessDid, sessId) 232 233 // delete the session 234 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 235 if err1 != nil { 236 err1 = fmt.Errorf("failed to logout: %w", err1) 237 } 238 o.EvictSession(sessDid, sessId) 239 240 // remove the cookie 241 userSession.Options.MaxAge = -1 242 err2 := o.SessStore.Save(r, w, userSession) 243 if err2 != nil { 244 err2 = fmt.Errorf("failed to save into session store: %w", err2) 245 } 246 247 return errors.Join(err1, err2) 248} 249 250func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 251 registry := o.GetAccounts(r) 252 account := registry.FindAccount(targetDid) 253 if account == nil { 254 return fmt.Errorf("account not found in registry: %s", 
targetDid) 255 } 256 257 did, err := syntax.ParseDID(targetDid) 258 if err != nil { 259 return fmt.Errorf("invalid DID: %w", err) 260 } 261 262 sess, err := o.resumeSession(r.Context(), did, account.SessionId) 263 if err != nil { 264 registry.RemoveAccount(targetDid) 265 _ = o.saveAccounts(w, r, registry) 266 return fmt.Errorf("session expired for account: %w", err) 267 } 268 269 userSession, err := o.SessStore.Get(r, SessionName) 270 if err != nil { 271 return err 272 } 273 274 userSession.Values[SessionDid] = sess.Data.AccountDID.String() 275 userSession.Values[SessionPds] = sess.Data.HostURL 276 userSession.Values[SessionId] = sess.Data.SessionID 277 userSession.Values[SessionAuthenticated] = true 278 279 return userSession.Save(r, w) 280} 281 282func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 283 registry := o.GetAccounts(r) 284 account := registry.FindAccount(targetDid) 285 if account == nil { 286 return nil 287 } 288 289 did, err := syntax.ParseDID(targetDid) 290 if err == nil { 291 o.EvictSession(did, account.SessionId) 292 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId) 293 o.EvictSession(did, account.SessionId) 294 } 295 296 registry.RemoveAccount(targetDid) 297 return o.saveAccounts(w, r, registry) 298} 299 300func (o *OAuth) GetDid(r *http.Request) string { 301 if u := o.GetMultiAccountUser(r); u != nil { 302 return u.Did 303 } 304 305 return "" 306} 307 308func (o *OAuth) GetDidFromCookie(r *http.Request) syntax.DID { 309 userSession, err := o.SessStore.Get(r, SessionName) 310 if err != nil || userSession.IsNew { 311 return "" 312 } 313 d, ok := userSession.Values[SessionDid].(string) 314 if !ok { 315 return "" 316 } 317 parsed, err := syntax.ParseDID(d) 318 if err != nil { 319 return "" 320 } 321 return parsed 322} 323 324func (o *OAuth) GetSessIdFromCookie(r *http.Request) string { 325 userSession, err := o.SessStore.Get(r, SessionName) 326 if err != nil || userSession.IsNew { 327 return "" 
328 } 329 s, ok := userSession.Values[SessionId].(string) 330 if !ok { 331 return "" 332 } 333 return s 334} 335 336func (o *OAuth) AuthorizedClient(r *http.Request) (*atclient.APIClient, error) { 337 session, err := o.ResumeSession(r) 338 if err != nil { 339 return nil, fmt.Errorf("error getting session: %w", err) 340 } 341 return session.APIClient(), nil 342} 343 344// this is a higher level abstraction on ServerGetServiceAuth 345type ServiceClientOpts struct { 346 service string 347 exp int64 348 lxm string 349 dev bool 350 timeout time.Duration 351} 352 353type ServiceClientOpt func(*ServiceClientOpts) 354 355func DefaultServiceClientOpts() ServiceClientOpts { 356 return ServiceClientOpts{ 357 timeout: time.Second * 5, 358 } 359} 360 361func WithService(service string) ServiceClientOpt { 362 return func(s *ServiceClientOpts) { 363 s.service = service 364 } 365} 366 367// Specify the Duration in seconds for the expiry of this token 368// 369// The time of expiry is calculated as time.Now().Unix() + exp 370func WithExp(exp int64) ServiceClientOpt { 371 return func(s *ServiceClientOpts) { 372 s.exp = time.Now().Unix() + exp 373 } 374} 375 376func WithLxm(lxm string) ServiceClientOpt { 377 return func(s *ServiceClientOpts) { 378 s.lxm = lxm 379 } 380} 381 382func WithDev(dev bool) ServiceClientOpt { 383 return func(s *ServiceClientOpts) { 384 s.dev = dev 385 } 386} 387 388func WithTimeout(timeout time.Duration) ServiceClientOpt { 389 return func(s *ServiceClientOpts) { 390 s.timeout = timeout 391 } 392} 393 394func (s *ServiceClientOpts) Audience() string { 395 return fmt.Sprintf("did:web:%s", s.service) 396} 397 398func (s *ServiceClientOpts) Host() string { 399 scheme := "https://" 400 if s.dev { 401 scheme = "http://" 402 } 403 404 return scheme + s.service 405} 406 407func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 408 opts := DefaultServiceClientOpts() 409 for _, o := range os { 410 o(&opts) 411 } 412 413 client, 
err := o.AuthorizedClient(r) 414 if err != nil { 415 return nil, err 416 } 417 418 // force expiry to atleast 60 seconds in the future 419 sixty := time.Now().Unix() + 60 420 if opts.exp < sixty { 421 opts.exp = sixty 422 } 423 424 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 425 if err != nil { 426 return nil, err 427 } 428 429 return &xrpc.Client{ 430 Auth: &xrpc.AuthInfo{ 431 AccessJwt: resp.Token, 432 }, 433 Host: opts.Host(), 434 Client: &http.Client{ 435 Timeout: opts.timeout, 436 }, 437 }, nil 438} 439 440func (o *OAuth) StartElevatedAuthFlow(ctx context.Context, w http.ResponseWriter, r *http.Request, did string, extraScopes []string, returnURL string) (string, error) { 441 parsedDid, err := syntax.ParseDID(did) 442 if err != nil { 443 return "", fmt.Errorf("invalid DID: %w", err) 444 } 445 446 ident, err := o.ClientApp.Dir.Lookup(ctx, parsedDid.AtIdentifier()) 447 if err != nil { 448 return "", fmt.Errorf("failed to resolve DID (%s): %w", did, err) 449 } 450 451 host := ident.PDSEndpoint() 452 if host == "" { 453 return "", fmt.Errorf("identity does not link to an atproto host (PDS)") 454 } 455 456 authserverURL, err := o.ClientApp.Resolver.ResolveAuthServerURL(ctx, host) 457 if err != nil { 458 return "", fmt.Errorf("resolving auth server: %w", err) 459 } 460 461 authserverMeta, err := o.ClientApp.Resolver.ResolveAuthServerMetadata(ctx, authserverURL) 462 if err != nil { 463 return "", fmt.Errorf("fetching auth server metadata: %w", err) 464 } 465 466 scopes := make([]string, 0, len(TangledScopes)+len(extraScopes)) 467 scopes = append(scopes, TangledScopes...) 468 scopes = append(scopes, extraScopes...) 
469 470 loginHint := did 471 if ident.Handle != "" && !ident.Handle.IsInvalidHandle() { 472 loginHint = ident.Handle.String() 473 } 474 475 info, err := o.ClientApp.SendAuthRequest(ctx, authserverMeta, scopes, loginHint) 476 if err != nil { 477 return "", fmt.Errorf("auth request failed: %w", err) 478 } 479 480 info.AccountDID = &parsedDid 481 o.ClientApp.Store.SaveAuthRequestInfo(ctx, *info) 482 483 if err := o.SetAuthReturn(w, r, returnURL); err != nil { 484 return "", fmt.Errorf("failed to set auth return: %w", err) 485 } 486 487 redirectURL := fmt.Sprintf("%s?client_id=%s&request_uri=%s", 488 authserverMeta.AuthorizationEndpoint, 489 url.QueryEscape(o.ClientApp.Config.ClientID), 490 url.QueryEscape(info.RequestURI), 491 ) 492 493 return redirectURL, nil 494}