Monorepo for Tangled tangled.org
5

Configure Feed

Select the types of activity you want to include in your feed.

1package oauth 2 3import ( 4 "context" 5 "errors" 6 "io" 7 "log/slog" 8 "sync" 9 "sync/atomic" 10 "testing" 11 12 "github.com/bluesky-social/indigo/atproto/atcrypto" 13 "github.com/bluesky-social/indigo/atproto/auth/oauth" 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "github.com/hashicorp/golang-lru/v2/expirable" 16) 17 18func discardLogger(t *testing.T) *slog.Logger { 19 t.Helper() 20 return slog.New(slog.NewTextHandler(io.Discard, nil)) 21} 22 23type stubStore struct { 24 mu sync.Mutex 25 data map[string]oauth.ClientSessionData 26 getSessionCalls atomic.Int32 27 deleteCalls atomic.Int32 28} 29 30func (s *stubStore) key(did syntax.DID, sessId string) string { 31 return string(did) + ":" + sessId 32} 33 34func (s *stubStore) GetSession(_ context.Context, did syntax.DID, sessId string) (*oauth.ClientSessionData, error) { 35 s.getSessionCalls.Add(1) 36 s.mu.Lock() 37 defer s.mu.Unlock() 38 v, ok := s.data[s.key(did, sessId)] 39 if !ok { 40 return nil, errors.New("no such session") 41 } 42 clone := v 43 return &clone, nil 44} 45 46func (s *stubStore) SaveSession(_ context.Context, sess oauth.ClientSessionData) error { 47 s.mu.Lock() 48 defer s.mu.Unlock() 49 s.data[s.key(sess.AccountDID, sess.SessionID)] = sess 50 return nil 51} 52 53func (s *stubStore) DeleteSession(_ context.Context, did syntax.DID, sessId string) error { 54 s.deleteCalls.Add(1) 55 s.mu.Lock() 56 defer s.mu.Unlock() 57 delete(s.data, s.key(did, sessId)) 58 return nil 59} 60 61func (s *stubStore) GetAuthRequestInfo(context.Context, string) (*oauth.AuthRequestData, error) { 62 return nil, errors.New("not used") 63} 64func (s *stubStore) SaveAuthRequestInfo(context.Context, oauth.AuthRequestData) error { 65 return nil 66} 67func (s *stubStore) DeleteAuthRequestInfo(context.Context, string) error { return nil } 68 69func newTestOAuth(t *testing.T) (*OAuth, *stubStore) { 70 t.Helper() 71 priv, err := atcrypto.GeneratePrivateKeyP256() 72 if err != nil { 73 t.Fatalf("generate key: %v", err) 74 } 75 store := &stubStore{data: map[string]oauth.ClientSessionData{}} 76 store.data[store.key("did:plc:boltless", "sess1")] = oauth.ClientSessionData{ 77 AccountDID: "did:plc:boltless", 78 SessionID: "sess1", 79 HostURL: "https://pds.example", 80 AuthServerURL: "https://pds.example", 81 AuthServerTokenEndpoint: "https://pds.example/oauth/token", 82 DPoPPrivateKeyMultibase: priv.Multibase(), 83 } 84 85 cfg := oauth.NewLocalhostConfig("http://127.0.0.1/cb", []string{"atproto"}) 86 app := oauth.NewClientApp(&cfg, store) 87 o := &OAuth{ 88 ClientApp: app, 89 Logger: discardLogger(t), 90 sessionCache: expirable.NewLRU[string, *oauth.ClientSession](sessionCacheSize, nil, sessionCacheTTL), 91 } 92 return o, store 93} 94 95func TestResumeSessionSingleflightDedupes(t *testing.T) { 96 o, store := newTestOAuth(t) 97 98 const n = 32 99 var wg sync.WaitGroup 100 results := make([]*oauth.ClientSession, n) 101 errs := make([]error, n) 102 wg.Add(n) 103 for i := range n { 104 go func() { 105 defer wg.Done() 106 sess, err := o.resumeSession(context.Background(), "did:plc:boltless", "sess1") 107 results[i] = sess 108 errs[i] = err 109 }() 110 } 111 wg.Wait() 112 113 for i, err := range errs { 114 if err != nil { 115 t.Fatalf("goroutine %d: %v", i, err) 116 } 117 } 118 first := results[0] 119 if first == nil { 120 t.Fatal("first session is nil") 121 } 122 for i, s := range results { 123 if s != first { 124 t.Fatalf("goroutine %d got different *ClientSession (%p vs %p)", i, s, first) 125 } 126 } 127 calls := store.getSessionCalls.Load() 128 if calls > 1 { 129 t.Fatalf("GetSession called %d times, want 1", calls) 130 } 131} 132 133func TestResumeSessionReuseAfterCache(t *testing.T) { 134 o, store := newTestOAuth(t) 135 136 a, err := o.resumeSession(context.Background(), "did:plc:boltless", "sess1") 137 if err != nil { 138 t.Fatalf("first: %v", err) 139 } 140 b, err := o.resumeSession(context.Background(), "did:plc:boltless", "sess1") 141 if err != nil { 142 t.Fatalf("second: %v", err) 143 } 144 if a != b { 145 t.Fatalf("expected same pointer across cache hit") 146 } 147 if got := store.getSessionCalls.Load(); got != 1 { 148 t.Fatalf("GetSession called %d times, want 1", got) 149 } 150} 151 152func TestHandlePermanentAuthErrEvictsAndLogsOut(t *testing.T) { 153 o, store := newTestOAuth(t) 154 155 if _, err := o.resumeSession(context.Background(), "did:plc:boltless", "sess1"); err != nil { 156 t.Fatalf("seed: %v", err) 157 } 158 if _, ok := o.sessionCache.Get(sessionCacheKey("did:plc:boltless", "sess1")); !ok { 159 t.Fatal("cache missing after resume") 160 } 161 162 handled := o.HandlePermanentAuthErr( 163 context.Background(), "did:plc:boltless", "sess1", 164 errors.New("auth server request failed (HTTP 400): invalid_grant"), 165 ) 166 if !handled { 167 t.Fatal("HandlePermanentAuthErr returned false") 168 } 169 if _, ok := o.sessionCache.Get(sessionCacheKey("did:plc:boltless", "sess1")); ok { 170 t.Fatal("cache still holds entry after HandlePermanentAuthErr") 171 } 172 if got := store.deleteCalls.Load(); got != 1 { 173 t.Fatalf("store.DeleteSession called %d times, want 1", got) 174 } 175 if _, ok := store.data[store.key("did:plc:boltless", "sess1")]; ok { 176 t.Fatal("store still holds session after Logout") 177 } 178} 179 180func TestHandlePermanentAuthErrIgnoresTransient(t *testing.T) { 181 o, store := newTestOAuth(t) 182 if _, err := o.resumeSession(context.Background(), "did:plc:boltless", "sess1"); err != nil { 183 t.Fatalf("seed: %v", err) 184 } 185 handled := o.HandlePermanentAuthErr( 186 context.Background(), "did:plc:boltless", "sess1", 187 errors.New("token refresh failed (HTTP 429): rate_limited"), 188 ) 189 if handled { 190 t.Fatal("HandlePermanentAuthErr matched a transient error") 191 } 192 if _, ok := o.sessionCache.Get(sessionCacheKey("did:plc:boltless", "sess1")); !ok { 193 t.Fatal("transient error evicted cache") 194 } 195 if got := store.deleteCalls.Load(); got != 0 { 196 t.Fatalf("store.DeleteSession called %d times, want 0", got) 197 } 198}