Monorepo for Tangled
tangled.org
1package migration
2
3import (
4 "context"
5 "fmt"
6 "log/slog"
7 "net/http"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/bluesky-social/indigo/atproto/atclient"
13 "github.com/bluesky-social/indigo/atproto/identity"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15
16 "tangled.org/core/appview/db"
17 "tangled.org/core/appview/models"
18 "tangled.org/core/appview/oauth"
19)
20
const (
	// maxConcurrentMigrations is the capacity of a Migration's semaphore: the
	// most background migration runs one Migration executes at the same time.
	maxConcurrentMigrations = 8
)
22
23type migrator func(ctx context.Context, client *atclient.APIClient, did syntax.DID, aturi syntax.ATURI) error
24
25type permAuthErrHandler func(ctx context.Context, did syntax.DID, sessId string, err error) bool
26
27type Migration struct {
28 db *db.DB
29 oauth *oauth.OAuth
30 dir identity.Directory
31 logger *slog.Logger
32 inflight sync.Map
33 sem chan struct{}
34 migrators map[string]migrator
35 onPermAuthErr permAuthErrHandler
36}
37
38func NewMigration(db *db.DB, oauth *oauth.OAuth, dir identity.Directory, logger *slog.Logger) *Migration {
39 m := &Migration{
40 db: db,
41 oauth: oauth,
42 dir: dir,
43 logger: logger,
44 sem: make(chan struct{}, maxConcurrentMigrations),
45 onPermAuthErr: oauth.HandlePermanentAuthErr,
46 }
47 m.migrators = map[string]migrator{
48 "add-repo-did": m.migrateAddRepoDid,
49 "use-feed-comment": m.migrateUseFeedComment,
50 }
51 return m
52}
53
54func (s *Migration) BackgroundMigrationMiddleware(next http.Handler) http.Handler {
55 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
56 defer next.ServeHTTP(w, r)
57
58 did := s.oauth.GetDidFromCookie(r)
59 if did == "" {
60 return
61 }
62
63 hasPending, err := db.HasPendingPdsRecordMigration(r.Context(), s.db, did)
64 if err != nil || !hasPending {
65 return
66 }
67
68 if _, loaded := s.inflight.LoadOrStore(did, struct{}{}); loaded {
69 return
70 }
71
72 select {
73 case s.sem <- struct{}{}:
74 default:
75 s.inflight.Delete(did)
76 return
77 }
78
79 sessId := s.oauth.GetSessIdFromCookie(r)
80 client, err := s.oauth.AuthorizedClient(r)
81 if err != nil || client.AccountDID == nil {
82 <-s.sem
83 s.inflight.Delete(did)
84 return
85 }
86
87 go func() {
88 defer s.inflight.Delete(did)
89 defer func() { <-s.sem }()
90 s.runPendingMigrations(context.Background(), *client.AccountDID, sessId, client)
91 }()
92 })
93}
94
95func (s *Migration) runPendingMigrations(ctx context.Context, did syntax.DID, sessId string, client *atclient.APIClient) {
96 l := s.logger.With("did", did)
97 migrations, err := db.ListPendingPdsRecordMigrations(ctx, s.db, did)
98 if err != nil {
99 l.Error("failed to query pending migrations", "err", err)
100 return
101 }
102
103 for _, migration := range migrations {
104 if err := s.migrate(ctx, client, sessId, migration); err != nil {
105 l.Error("migration failed", "err", err)
106 }
107 }
108}
109
110func (s *Migration) migrate(ctx context.Context, client *atclient.APIClient, sessId string, migration *models.PDSMigration) error {
111 l := s.logger.With(
112 "name", migration.Name,
113 "aturi", migration.RecordAtUri(),
114 )
115
116 mig, ok := s.migrators[migration.Name]
117 if !ok {
118 return fmt.Errorf("unexpected migration name %s", migration.Name)
119 }
120 err := mig(ctx, client, migration.Did, migration.RecordAtUri())
121
122 if err == nil {
123 l.Info("migrated")
124 migration.Status = models.PDSMigrationStatusDone
125 migration.ErrorMsg = nil
126 migration.RetryCount = 0
127 migration.RetryAfter = 0
128 } else {
129 l.Warn("failed to migrate", "err", err)
130
131 errMsg := strings.ReplaceAll(err.Error(), "\x00", "")
132 migration.ErrorMsg = &errMsg
133 migration.RetryCount++
134
135 if s.onPermAuthErr(ctx, migration.Did, sessId, err) {
136 migration.Status = models.PDSMigrationStatusFailed
137 migration.RetryAfter = 0
138 } else {
139 migration.Status = models.PDSMigrationStatusPending
140 migration.RetryAfter = time.Now().Add(retryBackoff(migration.RetryCount)).Unix()
141 }
142 }
143 if err := db.UpdatePdsRecordMigration(ctx, s.db, migration); err != nil {
144 return fmt.Errorf("failed to update migration status: %w", err)
145 }
146 return nil
147}
148
// retryBackoff returns how long to wait before retrying a migration that has
// failed retries times. The wait grows linearly by 5s per failure and is
// capped at one hour.
func retryBackoff(retries int) time.Duration {
	const (
		step    = 5 * time.Second
		ceiling = time.Hour
	)
	wait := step * time.Duration(retries)
	if wait > ceiling {
		return ceiling
	}
	return wait
}