Monorepo for Tangled tangled.org
6

Configure Feed

Select the types of activity you want to include in your feed.

at icy/ytnwlw 14 kB View raw
1package oauth 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "log/slog" 8 "net/http" 9 "net/url" 10 "sync" 11 "time" 12 13 comatproto "github.com/bluesky-social/indigo/api/atproto" 14 "github.com/bluesky-social/indigo/atproto/atclient" 15 "github.com/bluesky-social/indigo/atproto/atcrypto" 16 "github.com/bluesky-social/indigo/atproto/auth/oauth" 17 "github.com/bluesky-social/indigo/atproto/syntax" 18 xrpc "github.com/bluesky-social/indigo/xrpc" 19 "github.com/gorilla/sessions" 20 "github.com/hashicorp/golang-lru/v2/expirable" 21 "github.com/posthog/posthog-go" 22 "golang.org/x/sync/singleflight" 23 "tangled.org/core/appview/config" 24 "tangled.org/core/appview/db" 25 "tangled.org/core/idresolver" 26 "tangled.org/core/rbac" 27 "tangled.org/core/xrpc/serviceauth" 28) 29 30const ( 31 sessionCacheSize = 10000 32 sessionCacheTTL = time.Hour 33) 34 35type KnotMembership interface { 36 IsKnotMember(ctx context.Context, host, userDid string) bool 37 InvalidateMembers(host string) 38} 39 40type OAuth struct { 41 ClientApp *oauth.ClientApp 42 SessStore *sessions.CookieStore 43 Config *config.Config 44 JwksUri string 45 ClientName string 46 ClientUri string 47 Posthog posthog.Client 48 Db *db.DB 49 Enforcer *rbac.Enforcer 50 Acl KnotMembership 51 IdResolver *idresolver.Resolver 52 Logger *slog.Logger 53 54 appPasswordSession *AppPasswordSession 55 appPasswordSessionMu sync.Mutex 56 57 sessionCache *expirable.LRU[string, *oauth.ClientSession] 58 sessionSF singleflight.Group 59} 60 61func sessionCacheKey(did syntax.DID, sessionId string) string { 62 return string(did) + ":" + sessionId 63} 64 65func (o *OAuth) resumeSession(ctx context.Context, did syntax.DID, sessionId string) (*oauth.ClientSession, error) { 66 key := sessionCacheKey(did, sessionId) 67 if v, ok := o.sessionCache.Get(key); ok { 68 return v, nil 69 } 70 v, err, _ := o.sessionSF.Do(key, func() (any, error) { 71 if v, ok := o.sessionCache.Get(key); ok { 72 return v, nil 73 } 74 sess, err := o.ClientApp.ResumeSession(ctx, did, sessionId) 75 if err != nil { 76 return nil, err 77 } 78 o.sessionCache.Add(key, sess) 79 return sess, nil 80 }) 81 if err != nil { 82 return nil, err 83 } 84 return v.(*oauth.ClientSession), nil 85} 86 87func (o *OAuth) EvictSession(did syntax.DID, sessionId string) { 88 o.sessionCache.Remove(sessionCacheKey(did, sessionId)) 89} 90 91func (o *OAuth) HandlePermanentAuthErr(ctx context.Context, did syntax.DID, sessionId string, err error) bool { 92 if !IsPermanentAuthErr(err) { 93 return false 94 } 95 o.EvictSession(did, sessionId) 96 if logoutErr := o.ClientApp.Logout(ctx, did, sessionId); logoutErr != nil { 97 o.Logger.Warn("store logout after permanent auth error failed", "did", did, "err", logoutErr) 98 } 99 return true 100} 101 102func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, acl KnotMembership, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) { 103 var oauthConfig oauth.ClientConfig 104 var clientUri string 105 if config.Core.Dev { 106 clientUri = "http://127.0.0.1:3000" 107 callbackUri := clientUri + "/oauth/callback" 108 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes) 109 } else { 110 clientUri = "https://" + config.Core.AppviewHost 111 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 112 callbackUri := clientUri + "/oauth/callback" 113 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes) 114 } 115 116 // configure client secret 117 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret) 118 if err != nil { 119 return nil, err 120 } 121 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil { 122 return nil, err 123 } 124 125 jwksUri := clientUri + "/oauth/jwks.json" 126 127 authStore, err := NewRedisStore(&RedisStoreConfig{ 128 RedisURL: config.Redis.ToURL(), 129 SessionExpiryDuration: time.Hour * 24 * 90, 130 SessionInactivityDuration: time.Hour * 24 * 14, 131 AuthRequestExpiryDuration: time.Minute * 30, 132 }) 133 if err != nil { 134 return nil, err 135 } 136 137 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 138 sessStore.Options.SameSite = http.SameSiteLaxMode 139 sessStore.Options.HttpOnly = true 140 sessStore.Options.Secure = !config.Core.Dev 141 142 clientApp := oauth.NewClientApp(&oauthConfig, authStore) 143 clientApp.Dir = res.Directory() 144 // allow non-public transports in dev mode 145 if config.Core.Dev { 146 clientApp.Resolver.Client.Transport = http.DefaultTransport 147 } 148 149 clientName := config.Core.AppviewName 150 151 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential()) 152 return &OAuth{ 153 ClientApp: clientApp, 154 Config: config, 155 SessStore: sessStore, 156 JwksUri: jwksUri, 157 ClientName: clientName, 158 ClientUri: clientUri, 159 Posthog: ph, 160 Db: db, 161 Enforcer: enforcer, 162 Acl: acl, 163 IdResolver: res, 164 Logger: logger, 165 sessionCache: expirable.NewLRU[string, *oauth.ClientSession](sessionCacheSize, nil, sessionCacheTTL), 166 }, nil 167} 168 169func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 170 userSession, err := o.SessStore.Get(r, SessionName) 171 if err != nil { 172 o.Logger.Warn("failed to decode existing session cookie, will create new", "err", err) 173 } 174 175 userSession.Values[SessionDid] = sessData.AccountDID.String() 176 userSession.Values[SessionPds] = sessData.HostURL 177 userSession.Values[SessionId] = sessData.SessionID 178 userSession.Values[SessionAuthenticated] = true 179 180 if err := userSession.Save(r, w); err != nil { 181 return err 182 } 183 184 handle := "" 185 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String()) 186 if err == nil && resolved.Handle.String() != "" { 187 handle = resolved.Handle.String() 188 } 189 190 registry := o.GetAccounts(r) 191 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil { 192 return err 193 } 194 return o.saveAccounts(w, r, registry) 195} 196 197func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 198 userSession, err := o.SessStore.Get(r, SessionName) 199 if err != nil { 200 return nil, fmt.Errorf("error getting user session: %w", err) 201 } 202 if userSession.IsNew { 203 return nil, fmt.Errorf("no session available for user") 204 } 205 206 d := userSession.Values[SessionDid].(string) 207 sessDid, err := syntax.ParseDID(d) 208 if err != nil { 209 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 210 } 211 212 sessId := userSession.Values[SessionId].(string) 213 214 clientSess, err := o.resumeSession(r.Context(), sessDid, sessId) 215 if err != nil { 216 return nil, fmt.Errorf("failed to resume session: %w", err) 217 } 218 219 return clientSess, nil 220} 221 222func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 223 userSession, err := o.SessStore.Get(r, SessionName) 224 if err != nil { 225 return fmt.Errorf("error getting user session: %w", err) 226 } 227 if userSession.IsNew { 228 return fmt.Errorf("no session available for user") 229 } 230 231 d := userSession.Values[SessionDid].(string) 232 sessDid, err := syntax.ParseDID(d) 233 if err != nil { 234 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 235 } 236 237 sessId := userSession.Values[SessionId].(string) 238 239 o.EvictSession(sessDid, sessId) 240 241 // delete the session 242 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 243 if err1 != nil { 244 err1 = fmt.Errorf("failed to logout: %w", err1) 245 } 246 o.EvictSession(sessDid, sessId) 247 248 // remove the cookie 249 userSession.Options.MaxAge = -1 250 err2 := o.SessStore.Save(r, w, userSession) 251 if err2 != nil { 252 err2 = fmt.Errorf("failed to save into session store: %w", err2) 253 } 254 255 return errors.Join(err1, err2) 256} 257 258func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 259 registry := o.GetAccounts(r) 260 account := registry.FindAccount(targetDid) 261 if account == nil { 262 return fmt.Errorf("account not found in registry: %s", targetDid) 263 } 264 265 did, err := syntax.ParseDID(targetDid) 266 if err != nil { 267 return fmt.Errorf("invalid DID: %w", err) 268 } 269 270 sess, err := o.resumeSession(r.Context(), did, account.SessionId) 271 if err != nil { 272 registry.RemoveAccount(targetDid) 273 _ = o.saveAccounts(w, r, registry) 274 return fmt.Errorf("session expired for account: %w", err) 275 } 276 277 userSession, err := o.SessStore.Get(r, SessionName) 278 if err != nil { 279 return err 280 } 281 282 userSession.Values[SessionDid] = sess.Data.AccountDID.String() 283 userSession.Values[SessionPds] = sess.Data.HostURL 284 userSession.Values[SessionId] = sess.Data.SessionID 285 userSession.Values[SessionAuthenticated] = true 286 287 return userSession.Save(r, w) 288} 289 290func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 291 registry := o.GetAccounts(r) 292 account := registry.FindAccount(targetDid) 293 if account == nil { 294 return nil 295 } 296 297 did, err := syntax.ParseDID(targetDid) 298 if err == nil { 299 o.EvictSession(did, account.SessionId) 300 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId) 301 o.EvictSession(did, account.SessionId) 302 } 303 304 registry.RemoveAccount(targetDid) 305 return o.saveAccounts(w, r, registry) 306} 307 308func (o *OAuth) GetDid(r *http.Request) string { 309 if u := o.GetMultiAccountUser(r); u != nil { 310 return u.Did 311 } 312 313 return "" 314} 315 316func (o *OAuth) GetDidFromCookie(r *http.Request) syntax.DID { 317 userSession, err := o.SessStore.Get(r, SessionName) 318 if err != nil || userSession.IsNew { 319 return "" 320 } 321 d, ok := userSession.Values[SessionDid].(string) 322 if !ok { 323 return "" 324 } 325 parsed, err := syntax.ParseDID(d) 326 if err != nil { 327 return "" 328 } 329 return parsed 330} 331 332func (o *OAuth) GetSessIdFromCookie(r *http.Request) string { 333 userSession, err := o.SessStore.Get(r, SessionName) 334 if err != nil || userSession.IsNew { 335 return "" 336 } 337 s, ok := userSession.Values[SessionId].(string) 338 if !ok { 339 return "" 340 } 341 return s 342} 343 344func (o *OAuth) AuthorizedClient(r *http.Request) (*atclient.APIClient, error) { 345 session, err := o.ResumeSession(r) 346 if err != nil { 347 return nil, fmt.Errorf("error getting session: %w", err) 348 } 349 return session.APIClient(), nil 350} 351 352// this is a higher level abstraction on ServerGetServiceAuth 353type ServiceClientOpts struct { 354 service string 355 exp int64 356 lxm string 357 dev bool 358 timeout time.Duration 359} 360 361type ServiceClientOpt func(*ServiceClientOpts) 362 363func DefaultServiceClientOpts() ServiceClientOpts { 364 return ServiceClientOpts{ 365 timeout: time.Second * 5, 366 } 367} 368 369func WithService(service string) ServiceClientOpt { 370 return func(s *ServiceClientOpts) { 371 s.service = service 372 } 373} 374 375// Specify the Duration in seconds for the expiry of this token 376// 377// The time of expiry is calculated as time.Now().Unix() + exp 378func WithExp(exp int64) ServiceClientOpt { 379 return func(s *ServiceClientOpts) { 380 s.exp = time.Now().Unix() + exp 381 } 382} 383 384func WithLxm(lxm string) ServiceClientOpt { 385 return func(s *ServiceClientOpts) { 386 s.lxm = lxm 387 } 388} 389 390func WithDev(dev bool) ServiceClientOpt { 391 return func(s *ServiceClientOpts) { 392 s.dev = dev 393 } 394} 395 396func WithTimeout(timeout time.Duration) ServiceClientOpt { 397 return func(s *ServiceClientOpts) { 398 s.timeout = timeout 399 } 400} 401 402func (s *ServiceClientOpts) Audience() string { 403 return serviceauth.DidWeb(s.service).String() 404} 405 406func (s *ServiceClientOpts) Host() string { 407 scheme := "https://" 408 if s.dev { 409 scheme = "http://" 410 } 411 412 return scheme + s.service 413} 414 415func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 416 client, err := o.AuthorizedClient(r) 417 if err != nil { 418 return nil, err 419 } 420 421 opts := DefaultServiceClientOpts() 422 for _, o := range os { 423 o(&opts) 424 } 425 426 // force expiry to atleast 60 seconds in the future 427 sixty := time.Now().Unix() + 60 428 if opts.exp < sixty { 429 opts.exp = sixty 430 } 431 432 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 433 if err != nil { 434 return nil, err 435 } 436 437 return &xrpc.Client{ 438 Auth: &xrpc.AuthInfo{ 439 AccessJwt: resp.Token, 440 }, 441 Host: opts.Host(), 442 Client: &http.Client{ 443 Timeout: opts.timeout, 444 }, 445 }, nil 446} 447 448func (o *OAuth) StartElevatedAuthFlow(ctx context.Context, w http.ResponseWriter, r *http.Request, did string, extraScopes []string, returnURL string) (string, error) { 449 parsedDid, err := syntax.ParseDID(did) 450 if err != nil { 451 return "", fmt.Errorf("invalid DID: %w", err) 452 } 453 454 ident, err := o.ClientApp.Dir.Lookup(ctx, parsedDid.AtIdentifier()) 455 if err != nil { 456 return "", fmt.Errorf("failed to resolve DID (%s): %w", did, err) 457 } 458 459 host := ident.PDSEndpoint() 460 if host == "" { 461 return "", fmt.Errorf("identity does not link to an atproto host (PDS)") 462 } 463 464 authserverURL, err := o.ClientApp.Resolver.ResolveAuthServerURL(ctx, host) 465 if err != nil { 466 return "", fmt.Errorf("resolving auth server: %w", err) 467 } 468 469 authserverMeta, err := o.ClientApp.Resolver.ResolveAuthServerMetadata(ctx, authserverURL) 470 if err != nil { 471 return "", fmt.Errorf("fetching auth server metadata: %w", err) 472 } 473 474 scopes := make([]string, 0, len(TangledScopes)+len(extraScopes)) 475 scopes = append(scopes, TangledScopes...) 476 scopes = append(scopes, extraScopes...) 477 478 loginHint := did 479 if ident.Handle != "" && !ident.Handle.IsInvalidHandle() { 480 loginHint = ident.Handle.String() 481 } 482 483 info, err := o.ClientApp.SendAuthRequest(ctx, authserverMeta, scopes, loginHint) 484 if err != nil { 485 return "", fmt.Errorf("auth request failed: %w", err) 486 } 487 488 info.AccountDID = &parsedDid 489 o.ClientApp.Store.SaveAuthRequestInfo(ctx, *info) 490 491 if err := o.SetAuthReturn(w, r, returnURL); err != nil { 492 return "", fmt.Errorf("failed to set auth return: %w", err) 493 } 494 495 redirectURL := fmt.Sprintf("%s?client_id=%s&request_uri=%s", 496 authserverMeta.AuthorizationEndpoint, 497 url.QueryEscape(o.ClientApp.Config.ClientID), 498 url.QueryEscape(info.RequestURI), 499 ) 500 501 return redirectURL, nil 502}