forked from
tangled.org/core
Monorepo for Tangled
1package oauth
2
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"net/url"
	"slices"
	"strings"
	"time"

	comatproto "github.com/bluesky-social/indigo/api/atproto"
	atpclient "github.com/bluesky-social/indigo/atproto/atclient"
	"github.com/bluesky-social/indigo/atproto/auth/oauth"
	lexutil "github.com/bluesky-social/indigo/lex/util"
	xrpc "github.com/bluesky-social/indigo/xrpc"
	"github.com/go-chi/chi/v5"
	"github.com/posthog/posthog-go"
	"tangled.org/core/api/tangled"
	"tangled.org/core/appview/db"
	"tangled.org/core/appview/models"
	"tangled.org/core/consts"
	"tangled.org/core/idresolver"
	"tangled.org/core/orm"
	"tangled.org/core/tid"
)
30
// Router returns the HTTP handler for the OAuth client endpoints: the client
// metadata document, the public JWKS, and the authorization callback. All three
// are registered as GET routes under /oauth.
func (o *OAuth) Router() http.Handler {
	router := chi.NewRouter()

	router.Get("/oauth/client-metadata.json", o.clientMetadata)
	router.Get("/oauth/jwks.json", o.jwks)
	router.Get("/oauth/callback", o.callback)

	return router
}
39
// clientMetadata serves the OAuth client metadata as JSON. It starts from the
// metadata generated by the client app's config and overrides the JWKS URI,
// client name, and client URI with the values configured on o. It also appends
// the extra "identity:handle" scope to the requested scopes.
func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) {
	meta := o.ClientApp.Config.ClientMetadata()
	meta.JWKSURI, meta.ClientName, meta.ClientURI = &o.JwksUri, &o.ClientName, &o.ClientUri
	meta.Scope += " identity:handle"

	w.Header().Set("Content-Type", "application/json")
	err := json.NewEncoder(w).Encode(meta)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
}
53
// jwks serves the client's public JSON Web Key Set, as produced by the client
// app's config, encoded as JSON.
func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if encErr := json.NewEncoder(w).Encode(o.ClientApp.Config.PublicJWKS()); encErr != nil {
		http.Error(w, encErr.Error(), http.StatusInternalServerError)
	}
}
62
// callback handles the OAuth authorization redirect. In order, it:
//   - exchanges the callback query parameters for a session and stores it;
//   - starts background onboarding: default knot/spindle membership, an empty
//     Tangled profile, the PDS-domain claim, and pending PDS rewrites;
//   - records a "signin" analytics event outside dev mode;
//   - redirects to the saved return URL ("/" if none), or to /settings/profile
//     when the account's PDS reports the repo as deactivated.
//
// Failures redirect to /login with an escaped "error" query parameter.
func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	l := o.Logger.With("query", r.URL.Query())

	authReturn := o.GetAuthReturn(r)
	// Best effort: failing to clear the stored return target must not block login.
	_ = o.ClearAuthReturn(w, r)

	sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query())
	if err != nil {
		var callbackErr *oauth.AuthRequestCallbackError
		if errors.As(err, &callbackErr) {
			l.Debug("callback error", "err", callbackErr)
			// The error code originates from the callback's query parameters,
			// so escape it instead of splicing it raw into the redirect URL.
			http.Redirect(w, r, loginErrorRedirectURL(callbackErr.ErrorCode), http.StatusFound)
			return
		}
		l.Error("failed to process callback", "err", err)
		http.Redirect(w, r, loginErrorRedirectURL("oauth"), http.StatusFound)
		return
	}

	did := sessData.AccountDID.String()

	if err := o.SaveSession(w, r, sessData); err != nil {
		// Log the DID only: sessData presumably carries OAuth token material
		// that must not end up in logs.
		l.Error("failed to save session", "did", did, "err", err)
		errorCode := "session"
		if errors.Is(err, ErrMaxAccountsReached) {
			errorCode = "max_accounts"
		}
		http.Redirect(w, r, loginErrorRedirectURL(errorCode), http.StatusFound)
		return
	}

	o.Logger.Debug("session saved successfully")

	// Fire-and-forget onboarding; each helper logs its own failures.
	go o.addToDefaultKnot(did)
	go o.addToDefaultSpindle(did)
	go o.ensureTangledProfile(sessData)
	go o.autoClaimTnglShDomain(did)
	go o.drainPdsRewrites(sessData)

	if !o.Config.Core.Dev {
		err = o.Posthog.Enqueue(posthog.Capture{
			DistinctId: did,
			Event:      "signin",
		})
		if err != nil {
			o.Logger.Error("failed to enqueue posthog event", "err", err)
		}
	}

	redirectURL := "/"
	if authReturn.ReturnURL != "" {
		// NOTE(review): ReturnURL is used verbatim; confirm GetAuthReturn only
		// yields same-site paths, otherwise this is an open redirect.
		redirectURL = authReturn.ReturnURL
	}

	// Deactivated accounts are sent to their profile settings instead.
	if o.isAccountDeactivated(sessData) {
		redirectURL = "/settings/profile"
	}

	http.Redirect(w, r, redirectURL, http.StatusFound)
}

// loginErrorRedirectURL returns the /login URL that reports code in its
// query-escaped "error" parameter.
func loginErrorRedirectURL(code string) string {
	return "/login?error=" + url.QueryEscape(code)
}
122
// isAccountDeactivated probes the account's PDS (sessData.HostURL) with
// comatproto.RepoDescribeRepo using an unauthenticated client with a 5-second
// timeout. It returns true only when the call fails with an XRPC error whose
// error string is "RepoDeactivated". Success, any other error, or a transport
// failure all report false.
func (o *OAuth) isAccountDeactivated(sessData *oauth.ClientSessionData) bool {
	probe := &xrpc.Client{
		Client: &http.Client{Timeout: 5 * time.Second},
		Host:   sessData.HostURL,
	}

	_, describeErr := comatproto.RepoDescribeRepo(context.Background(), probe, sessData.AccountDID.String())
	if describeErr == nil {
		return false
	}

	var wrapper *xrpc.Error
	if !errors.As(describeErr, &wrapper) {
		return false
	}

	var body *xrpc.XRPCError
	if !errors.As(wrapper.Wrapped, &body) {
		return false
	}

	return body.ErrStr == "RepoDeactivated"
}
144
// addToDefaultSpindle makes did a member of the default spindle
// (consts.DefaultSpindle), unless the local DB already records that
// membership. The membership is written as a tangled.SpindleMember record
// through the shared app-password session (see getAppPasswordSession).
// Failures are logged rather than returned because the login callback runs
// this in a background goroutine.
func (o *OAuth) addToDefaultSpindle(did string) {
	l := o.Logger.With("subject", did)

	// Query with the same instance value the record below writes, so the
	// lookup cannot drift out of sync with consts.DefaultSpindle. A mismatch
	// would make every login write a duplicate membership record.
	spindleMembers, err := db.GetSpindleMembers(
		o.Db,
		orm.FilterEq("instance", consts.DefaultSpindle),
		orm.FilterEq("subject", did),
	)
	if err != nil {
		l.Error("failed to get spindle members", "err", err)
		return
	}

	if len(spindleMembers) != 0 {
		l.Warn("already a member of the default spindle")
		return
	}

	l.Debug("adding to default spindle")
	session, err := o.getAppPasswordSession()
	if err != nil {
		l.Error("failed to create session", "err", err)
		return
	}

	record := tangled.SpindleMember{
		LexiconTypeID: tangled.SpindleMemberNSID,
		Subject:       did,
		Instance:      consts.DefaultSpindle,
		CreatedAt:     time.Now().Format(time.RFC3339),
	}

	if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil {
		l.Error("failed to add to default spindle", "err", err)
		return
	}

	l.Debug("successfully added to default spindle", "did", did)
}
186
// addToDefaultKnot makes did a member of the default knot (consts.DefaultKnot).
// Users whom the enforcer already lists for that knot are skipped. Otherwise
// the function writes a tangled.KnotMember record through the shared
// app-password session and then registers the membership with the enforcer.
// Errors are logged rather than returned because the login callback runs this
// in the background.
//
// NOTE(review): if the record write succeeds but AddKnotMember fails, the
// enforcer check will not see the membership on the next login and another
// record will be written.
func (o *OAuth) addToDefaultKnot(did string) {
	logger := o.Logger.With("subject", did)

	knots, err := o.Enforcer.GetKnotsForUser(did)
	if err != nil {
		logger.Error("failed to get knot members for did", "err", err)
		return
	}
	if slices.Contains(knots, consts.DefaultKnot) {
		logger.Warn("already a member of the default knot")
		return
	}

	logger.Debug("adding to default knot")
	session, err := o.getAppPasswordSession()
	if err != nil {
		logger.Error("failed to create session", "err", err)
		return
	}

	member := tangled.KnotMember{
		LexiconTypeID: tangled.KnotMemberNSID,
		Subject:       did,
		Domain:        consts.DefaultKnot,
		CreatedAt:     time.Now().Format(time.RFC3339),
	}
	if err = session.putRecord(member, tangled.KnotMemberNSID); err != nil {
		logger.Error("failed to add to default knot", "err", err)
		return
	}

	if err = o.Enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil {
		logger.Error("failed to set up enforcer rules", "err", err)
		return
	}

	logger.Debug("successfully added to default knot")
}
230
// ensureTangledProfile gives the account an empty Tangled profile if the local
// DB has none. It performs two steps:
//  1. Write an empty tangled.ActorProfile record (rkey "self") to the user's
//     PDS using their resumed OAuth session.
//  2. Upsert an empty profile row locally inside a committed transaction.
//
// It runs in the background after login; errors are logged only.
func (o *OAuth) ensureTangledProfile(sessData *oauth.ClientSessionData) {
	ctx := context.Background()
	did := sessData.AccountDID.String()
	l := o.Logger.With("did", did)

	// NOTE(review): the lookup error is ignored, so a failed lookup falls
	// through to creation; confirm how db.GetProfile reports "not found"
	// before tightening this.
	profile, _ := db.GetProfile(o.Db, did)
	if profile != nil {
		l.Debug("profile already exists in DB")
		return
	}

	l.Debug("creating empty Tangled profile")

	sess, err := o.ClientApp.ResumeSession(ctx, sessData.AccountDID, sessData.SessionID)
	if err != nil {
		l.Error("failed to resume session for profile creation", "err", err)
		return
	}
	client := sess.APIClient()

	_, err = comatproto.RepoPutRecord(ctx, client, &comatproto.RepoPutRecord_Input{
		Collection: tangled.ActorProfileNSID,
		Repo:       did,
		Rkey:       "self",
		Record:     &lexutil.LexiconTypeDecoder{Val: &tangled.ActorProfile{}},
	})
	if err != nil {
		l.Error("failed to create empty profile on PDS", "err", err)
		return
	}

	tx, err := o.Db.BeginTx(ctx, nil)
	if err != nil {
		l.Error("failed to start transaction", "err", err)
		return
	}
	// Release the transaction on every early return. After a successful
	// Commit this Rollback is a no-op returning sql.ErrTxDone.
	defer func() { _ = tx.Rollback() }()

	emptyProfile := &models.Profile{Did: did}
	if err := db.UpsertProfile(tx, emptyProfile); err != nil {
		l.Error("failed to create empty profile in DB", "err", err)
		return
	}

	// Without a commit the upsert is discarded and the transaction is leaked.
	if err := tx.Commit(); err != nil {
		l.Error("failed to commit empty profile", "err", err)
		return
	}

	l.Debug("successfully created empty Tangled profile on PDS and DB")
}
277
// drainPdsRewrites applies every pending PDS record rewrite queued for the
// account, writing through the user's resumed OAuth session. A rewrite is
// marked complete only after rewritePdsRecord succeeds for it. Failed rewrites
// are logged and not marked complete. Nothing is done when no rewrites are
// pending.
func (o *OAuth) drainPdsRewrites(sessData *oauth.ClientSessionData) {
	ctx := context.Background()
	did := sessData.AccountDID.String()
	logger := o.Logger.With("did", did, "handler", "drainPdsRewrites")

	pending, err := db.GetPendingPdsRewrites(o.Db, did)
	switch {
	case err != nil:
		logger.Error("failed to get pending rewrites", "err", err)
		return
	case len(pending) == 0:
		return
	}

	logger.Info("draining pending PDS rewrites", "count", len(pending))

	sess, err := o.ClientApp.ResumeSession(ctx, sessData.AccountDID, sessData.SessionID)
	if err != nil {
		logger.Error("failed to resume session for PDS rewrites", "err", err)
		return
	}
	apiClient := sess.APIClient()

	for _, rw := range pending {
		rewriteErr := o.rewritePdsRecord(ctx, apiClient, did, rw)
		if rewriteErr != nil {
			logger.Error("failed to rewrite PDS record",
				"nsid", rw.RecordNsid,
				"rkey", rw.RecordRkey,
				"repo_did", rw.RepoDid,
				"err", rewriteErr)
			continue
		}

		if completeErr := db.CompletePdsRewrite(o.Db, rw.Id); completeErr != nil {
			logger.Error("failed to mark rewrite complete", "id", rw.Id, "err", completeErr)
		}
	}
}
316
// rewritePdsRecord fetches the record identified by rw (collection
// rw.RecordNsid, rkey rw.RecordRkey) from userDid's repo and sets the
// repo-DID field appropriate to its type to rw.RepoDid. It then writes the
// record back with SwapRecord set to the fetched CID, so the put only applies
// if the record is unchanged since it was read.
//
// Per-type handling:
//   - Pulls: the target always gets the DID; the source gets it only when it
//     points at rw.OldRepoAt.
//   - Actor profiles: pinned AT-URIs whose repo has a known DID are moved
//     into PinnedRepositoryDids (without duplicates); unresolvable pins are
//     kept as URIs.
//   - Any other collection: an error is returned and nothing is written.
func (o *OAuth) rewritePdsRecord(ctx context.Context, client *atpclient.APIClient, userDid string, rw db.PdsRewrite) error {
	ex, err := comatproto.RepoGetRecord(ctx, client, "", rw.RecordNsid, userDid, rw.RecordRkey)
	if err != nil {
		return fmt.Errorf("get record: %w", err)
	}
	// Return an error rather than panicking on an empty reply; this runs in a
	// background goroutine with no recovery.
	if ex == nil || ex.Value == nil || ex.Value.Val == nil {
		return fmt.Errorf("get record: no value for %s/%s", rw.RecordNsid, rw.RecordRkey)
	}

	val := ex.Value.Val
	repoDid := rw.RepoDid

	switch rw.RecordNsid {
	case tangled.RepoNSID:
		rec, ok := val.(*tangled.Repo)
		if !ok {
			return fmt.Errorf("unexpected type %T for repo record", val)
		}
		rec.RepoDid = &repoDid

	case tangled.RepoIssueNSID:
		rec, ok := val.(*tangled.RepoIssue)
		if !ok {
			return fmt.Errorf("unexpected type %T for issue record", val)
		}
		rec.RepoDid = &repoDid

	case tangled.RepoPullNSID:
		rec, ok := val.(*tangled.RepoPull)
		if !ok {
			return fmt.Errorf("unexpected type %T for pull record", val)
		}
		if rec.Target != nil {
			rec.Target.RepoDid = &repoDid
		}
		// Only rewrite the source if it refers to the repo being migrated.
		if rec.Source != nil && rec.Source.Repo != nil && *rec.Source.Repo == rw.OldRepoAt {
			rec.Source.RepoDid = &repoDid
		}

	case tangled.RepoCollaboratorNSID:
		rec, ok := val.(*tangled.RepoCollaborator)
		if !ok {
			return fmt.Errorf("unexpected type %T for collaborator record", val)
		}
		rec.RepoDid = &repoDid

	case tangled.RepoArtifactNSID:
		rec, ok := val.(*tangled.RepoArtifact)
		if !ok {
			return fmt.Errorf("unexpected type %T for artifact record", val)
		}
		rec.RepoDid = &repoDid

	case tangled.FeedStarNSID:
		rec, ok := val.(*tangled.FeedStar)
		if !ok {
			return fmt.Errorf("unexpected type %T for star record", val)
		}
		rec.SubjectDid = &repoDid

	case tangled.ActorProfileNSID:
		rec, ok := val.(*tangled.ActorProfile)
		if !ok {
			return fmt.Errorf("unexpected type %T for profile record", val)
		}
		var dids, remaining []string
		for _, pinURI := range rec.PinnedRepositories {
			repo, repoErr := db.GetRepoByAtUri(o.Db, pinURI)
			if repoErr != nil || repo.RepoDid == "" {
				// Keep pins we cannot resolve so nothing is silently unpinned.
				remaining = append(remaining, pinURI)
				continue
			}
			// A repo pinned by both URI and DID (or pinned twice) must not be
			// listed twice after the migration.
			if slices.Contains(rec.PinnedRepositoryDids, repo.RepoDid) || slices.Contains(dids, repo.RepoDid) {
				continue
			}
			dids = append(dids, repo.RepoDid)
		}
		rec.PinnedRepositoryDids = append(rec.PinnedRepositoryDids, dids...)
		rec.PinnedRepositories = remaining

	default:
		return fmt.Errorf("unsupported NSID for PDS rewrite: %s", rw.RecordNsid)
	}

	_, err = comatproto.RepoPutRecord(ctx, client, &comatproto.RepoPutRecord_Input{
		Collection: rw.RecordNsid,
		Repo:       userDid,
		Rkey:       rw.RecordRkey,
		SwapRecord: ex.Cid,
		Record:     &lexutil.LexiconTypeDecoder{Val: val},
	})
	if err != nil {
		return fmt.Errorf("put record: %w", err)
	}

	return nil
}
409
410// create a AppPasswordSession using apppasswords
411type AppPasswordSession struct {
412 AccessJwt string `json:"accessJwt"`
413 RefreshJwt string `json:"refreshJwt"`
414 PdsEndpoint string
415 Did string
416 Logger *slog.Logger
417 ExpiresAt time.Time
418}
419
// CreateAppPasswordSession logs in as did on its PDS using an app password and
// returns the resulting session.
//
// It resolves did to find its PDS endpoint and calls
// com.atproto.server.createSession with the DID as the identifier. The
// session gets a local expiry 115 minutes from now, the same window
// refreshSession uses. Returned errors wrap their underlying cause with %w. A
// nil logger falls back to slog.Default().
func CreateAppPasswordSession(res *idresolver.Resolver, appPassword, did string, logger *slog.Logger) (*AppPasswordSession, error) {
	if appPassword == "" {
		return nil, fmt.Errorf("no app password configured")
	}
	if logger == nil {
		// The session logs on every refresh and write; avoid a nil-pointer panic there.
		logger = slog.Default()
	}

	resolved, err := res.ResolveIdent(context.Background(), did)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %w", did, err)
	}

	pdsEndpoint := resolved.PDSEndpoint()
	if pdsEndpoint == "" {
		return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did)
	}

	sessionPayload := map[string]string{
		"identifier": did,
		"password":   appPassword,
	}
	sessionBytes, err := json.Marshal(sessionPayload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal session payload: %w", err)
	}

	sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession"
	sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes))
	if err != nil {
		return nil, fmt.Errorf("failed to create session request: %w", err)
	}
	sessionReq.Header.Set("Content-Type", "application/json")

	logger.Debug("creating app password session", "url", sessionURL, "headers", sessionReq.Header)

	client := &http.Client{Timeout: 30 * time.Second}
	sessionResp, err := client.Do(sessionReq)
	if err != nil {
		return nil, fmt.Errorf("failed to create session: %w", err)
	}
	defer sessionResp.Body.Close()

	if sessionResp.StatusCode != http.StatusOK {
		// Surface the PDS error body when it is JSON, matching refreshSession
		// and putRecord, so login failures are diagnosable.
		var errorResponse map[string]any
		if err := json.NewDecoder(sessionResp.Body).Decode(&errorResponse); err != nil {
			return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode)
		}
		return nil, fmt.Errorf("failed to create session: HTTP %d, response: %v", sessionResp.StatusCode, errorResponse)
	}

	var session AppPasswordSession
	if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil {
		return nil, fmt.Errorf("failed to decode session response: %w", err)
	}
	if session.AccessJwt == "" {
		return nil, fmt.Errorf("failed to create session: response contained no access token")
	}

	session.PdsEndpoint = pdsEndpoint
	session.Did = did
	session.Logger = logger
	session.ExpiresAt = time.Now().Add(115 * time.Minute)

	return &session, nil
}
476
// refreshSession exchanges s.RefreshJwt for a new token pair by POSTing to
// com.atproto.server.refreshSession on s.PdsEndpoint. On success it updates
// AccessJwt, RefreshJwt and ExpiresAt on s in place. On a non-200 response
// the decoded JSON error body is included in the returned error.
func (s *AppPasswordSession) refreshSession() error {
	endpoint := s.PdsEndpoint + "/xrpc/com.atproto.server.refreshSession"
	req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, endpoint, nil)
	if err != nil {
		return fmt.Errorf("failed to create refresh request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.RefreshJwt)

	s.Logger.Debug("refreshing app password session", "url", endpoint)

	httpClient := &http.Client{Timeout: 30 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to refresh session: %w", err)
	}
	defer resp.Body.Close()

	dec := json.NewDecoder(resp.Body)

	if resp.StatusCode != http.StatusOK {
		var body map[string]any
		if decodeErr := dec.Decode(&body); decodeErr != nil {
			return fmt.Errorf("failed to refresh session: HTTP %d (failed to decode error response: %w)", resp.StatusCode, decodeErr)
		}
		encoded, _ := json.Marshal(body)
		return fmt.Errorf("failed to refresh session: HTTP %d, response: %s", resp.StatusCode, string(encoded))
	}

	var tokens struct {
		AccessJwt  string `json:"accessJwt"`
		RefreshJwt string `json:"refreshJwt"`
	}
	if decodeErr := dec.Decode(&tokens); decodeErr != nil {
		return fmt.Errorf("failed to decode refresh response: %w", decodeErr)
	}

	s.AccessJwt, s.RefreshJwt = tokens.AccessJwt, tokens.RefreshJwt
	// Same 115-minute local window as CreateAppPasswordSession; presumably a
	// 5-minute margin under a 2-hour server-side token lifetime (confirm).
	s.ExpiresAt = time.Now().Add(115 * time.Minute)

	s.Logger.Debug("successfully refreshed app password session")
	return nil
}
520
// isValid reports whether the session's locally tracked expiry is still in
// the future.
func (s *AppPasswordSession) isValid() bool {
	return time.Until(s.ExpiresAt) > 0
}
524
// putRecord writes record into collection in the session's own repo (s.Did)
// via com.atproto.repo.putRecord, using a freshly generated TID as the rkey.
// If the local expiry has passed, it first refreshes the session. A non-200
// response is returned as an error that includes the decoded response body.
//
// NOTE(review): the refresh mutates s in place, and sessions are shared
// between goroutines via OAuth.getAppPasswordSession; concurrent writers on an
// expired session may both attempt a refresh.
func (s *AppPasswordSession) putRecord(record any, collection string) error {
	if !s.isValid() {
		s.Logger.Debug("access token expired, refreshing session")
		if err := s.refreshSession(); err != nil {
			return fmt.Errorf("failed to refresh session: %w", err)
		}
		s.Logger.Debug("session refreshed")
	}

	recordBytes, err := json.Marshal(record)
	if err != nil {
		// Callers write different record types (knot and spindle members), so
		// name the collection instead of assuming one.
		return fmt.Errorf("failed to marshal %s record: %w", collection, err)
	}

	payload := map[string]any{
		"repo":       s.Did,
		"collection": collection,
		"rkey":       tid.TID(),
		"record":     json.RawMessage(recordBytes),
	}

	payloadBytes, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("failed to marshal request payload: %w", err)
	}

	url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord"
	req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes))
	if err != nil {
		return fmt.Errorf("failed to create HTTP request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+s.AccessJwt)

	s.Logger.Debug("putting record", "url", url, "collection", collection)

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to add user to default service: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		var errorResponse map[string]any
		if err := json.NewDecoder(resp.Body).Decode(&errorResponse); err != nil {
			return fmt.Errorf("failed to add user to default service: HTTP %d (failed to decode error response: %w)", resp.StatusCode, err)
		}
		return fmt.Errorf("failed to add user to default service: HTTP %d, response: %v", resp.StatusCode, errorResponse)
	}

	return nil
}
579
// autoClaimTnglShDomain claims the user's handle as a sites domain when that
// handle is a subdomain of the configured PDS host. The host is taken from
// o.Config.Pds.Host with any http(s) scheme and trailing slash removed; e.g.
// a .tngl.sh handle matches when the PDS host is tngl.sh. Errors are logged,
// not returned.
//
// NOTE(review): this is meant to run on every login, which assumes
// db.ClaimDomain tolerates an existing claim; confirm in the db package.
func (o *OAuth) autoClaimTnglShDomain(did string) {
	l := o.Logger.With("did", did)

	pdsDomain := strings.TrimPrefix(o.Config.Pds.Host, "https://")
	pdsDomain = strings.TrimPrefix(pdsDomain, "http://")
	// A trailing slash in the configured host would otherwise make every
	// suffix comparison below fail.
	pdsDomain = strings.TrimSuffix(pdsDomain, "/")
	if pdsDomain == "" {
		// No PDS host configured: no handle can be one of its subdomains.
		return
	}

	resolved, err := o.IdResolver.ResolveIdent(context.Background(), did)
	if err != nil {
		l.Error("autoClaimTnglShDomain: failed to resolve ident", "err", err)
		return
	}

	handle := resolved.Handle.String()
	if !strings.HasSuffix(handle, "."+pdsDomain) {
		return
	}

	if err := db.ClaimDomain(o.Db, did, handle); err != nil {
		l.Warn("autoClaimTnglShDomain: failed to claim domain", "domain", handle, "err", err)
		return
	}
	l.Info("autoClaimTnglShDomain: claimed domain", "domain", handle)
}
606
607// getAppPasswordSession returns a cached AppPasswordSession, creating one if needed.
608func (o *OAuth) getAppPasswordSession() (*AppPasswordSession, error) {
609 o.appPasswordSessionMu.Lock()
610 defer o.appPasswordSessionMu.Unlock()
611
612 if o.appPasswordSession != nil {
613 return o.appPasswordSession, nil
614 }
615
616 session, err := CreateAppPasswordSession(o.IdResolver, o.Config.Core.AppPassword, consts.TangledDid, o.Logger)
617 if err != nil {
618 return nil, err
619 }
620
621 o.appPasswordSession = session
622 return session, nil
623}