Monorepo for Tangled
tangled.org
1package oauth
2
3import (
4 "context"
5 "errors"
6 "io"
7 "log/slog"
8 "sync"
9 "sync/atomic"
10 "testing"
11
12 "github.com/bluesky-social/indigo/atproto/atcrypto"
13 "github.com/bluesky-social/indigo/atproto/auth/oauth"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "github.com/hashicorp/golang-lru/v2/expirable"
16)
17
// discardLogger returns a text-handler logger whose output is thrown away,
// keeping test output free of log noise. It is marked as a test helper so any
// failure is attributed to the caller.
func discardLogger(t *testing.T) *slog.Logger {
	t.Helper()
	sink := slog.NewTextHandler(io.Discard, nil)
	return slog.New(sink)
}
22
23type stubStore struct {
24 mu sync.Mutex
25 data map[string]oauth.ClientSessionData
26 getSessionCalls atomic.Int32
27 deleteCalls atomic.Int32
28}
29
30func (s *stubStore) key(did syntax.DID, sessId string) string {
31 return string(did) + ":" + sessId
32}
33
34func (s *stubStore) GetSession(_ context.Context, did syntax.DID, sessId string) (*oauth.ClientSessionData, error) {
35 s.getSessionCalls.Add(1)
36 s.mu.Lock()
37 defer s.mu.Unlock()
38 v, ok := s.data[s.key(did, sessId)]
39 if !ok {
40 return nil, errors.New("no such session")
41 }
42 clone := v
43 return &clone, nil
44}
45
46func (s *stubStore) SaveSession(_ context.Context, sess oauth.ClientSessionData) error {
47 s.mu.Lock()
48 defer s.mu.Unlock()
49 s.data[s.key(sess.AccountDID, sess.SessionID)] = sess
50 return nil
51}
52
53func (s *stubStore) DeleteSession(_ context.Context, did syntax.DID, sessId string) error {
54 s.deleteCalls.Add(1)
55 s.mu.Lock()
56 defer s.mu.Unlock()
57 delete(s.data, s.key(did, sessId))
58 return nil
59}
60
61func (s *stubStore) GetAuthRequestInfo(context.Context, string) (*oauth.AuthRequestData, error) {
62 return nil, errors.New("not used")
63}
64func (s *stubStore) SaveAuthRequestInfo(context.Context, oauth.AuthRequestData) error {
65 return nil
66}
67func (s *stubStore) DeleteAuthRequestInfo(context.Context, string) error { return nil }
68
69func newTestOAuth(t *testing.T) (*OAuth, *stubStore) {
70 t.Helper()
71 priv, err := atcrypto.GeneratePrivateKeyP256()
72 if err != nil {
73 t.Fatalf("generate key: %v", err)
74 }
75 store := &stubStore{data: map[string]oauth.ClientSessionData{}}
76 store.data[store.key("did:plc:boltless", "sess1")] = oauth.ClientSessionData{
77 AccountDID: "did:plc:boltless",
78 SessionID: "sess1",
79 HostURL: "https://pds.example",
80 AuthServerURL: "https://pds.example",
81 AuthServerTokenEndpoint: "https://pds.example/oauth/token",
82 DPoPPrivateKeyMultibase: priv.Multibase(),
83 }
84
85 cfg := oauth.NewLocalhostConfig("http://127.0.0.1/cb", []string{"atproto"})
86 app := oauth.NewClientApp(&cfg, store)
87 o := &OAuth{
88 ClientApp: app,
89 Logger: discardLogger(t),
90 sessionCache: expirable.NewLRU[string, *oauth.ClientSession](sessionCacheSize, nil, sessionCacheTTL),
91 }
92 return o, store
93}
94
95func TestResumeSessionSingleflightDedupes(t *testing.T) {
96 o, store := newTestOAuth(t)
97
98 const n = 32
99 var wg sync.WaitGroup
100 results := make([]*oauth.ClientSession, n)
101 errs := make([]error, n)
102 wg.Add(n)
103 for i := range n {
104 go func() {
105 defer wg.Done()
106 sess, err := o.resumeSession(context.Background(), "did:plc:boltless", "sess1")
107 results[i] = sess
108 errs[i] = err
109 }()
110 }
111 wg.Wait()
112
113 for i, err := range errs {
114 if err != nil {
115 t.Fatalf("goroutine %d: %v", i, err)
116 }
117 }
118 first := results[0]
119 if first == nil {
120 t.Fatal("first session is nil")
121 }
122 for i, s := range results {
123 if s != first {
124 t.Fatalf("goroutine %d got different *ClientSession (%p vs %p)", i, s, first)
125 }
126 }
127 calls := store.getSessionCalls.Load()
128 if calls > 1 {
129 t.Fatalf("GetSession called %d times, want 1", calls)
130 }
131}
132
133func TestResumeSessionReuseAfterCache(t *testing.T) {
134 o, store := newTestOAuth(t)
135
136 a, err := o.resumeSession(context.Background(), "did:plc:boltless", "sess1")
137 if err != nil {
138 t.Fatalf("first: %v", err)
139 }
140 b, err := o.resumeSession(context.Background(), "did:plc:boltless", "sess1")
141 if err != nil {
142 t.Fatalf("second: %v", err)
143 }
144 if a != b {
145 t.Fatalf("expected same pointer across cache hit")
146 }
147 if got := store.getSessionCalls.Load(); got != 1 {
148 t.Fatalf("GetSession called %d times, want 1", got)
149 }
150}
151
152func TestHandlePermanentAuthErrEvictsAndLogsOut(t *testing.T) {
153 o, store := newTestOAuth(t)
154
155 if _, err := o.resumeSession(context.Background(), "did:plc:boltless", "sess1"); err != nil {
156 t.Fatalf("seed: %v", err)
157 }
158 if _, ok := o.sessionCache.Get(sessionCacheKey("did:plc:boltless", "sess1")); !ok {
159 t.Fatal("cache missing after resume")
160 }
161
162 handled := o.HandlePermanentAuthErr(
163 context.Background(), "did:plc:boltless", "sess1",
164 errors.New("auth server request failed (HTTP 400): invalid_grant"),
165 )
166 if !handled {
167 t.Fatal("HandlePermanentAuthErr returned false")
168 }
169 if _, ok := o.sessionCache.Get(sessionCacheKey("did:plc:boltless", "sess1")); ok {
170 t.Fatal("cache still holds entry after HandlePermanentAuthErr")
171 }
172 if got := store.deleteCalls.Load(); got != 1 {
173 t.Fatalf("store.DeleteSession called %d times, want 1", got)
174 }
175 if _, ok := store.data[store.key("did:plc:boltless", "sess1")]; ok {
176 t.Fatal("store still holds session after Logout")
177 }
178}
179
180func TestHandlePermanentAuthErrIgnoresTransient(t *testing.T) {
181 o, store := newTestOAuth(t)
182 if _, err := o.resumeSession(context.Background(), "did:plc:boltless", "sess1"); err != nil {
183 t.Fatalf("seed: %v", err)
184 }
185 handled := o.HandlePermanentAuthErr(
186 context.Background(), "did:plc:boltless", "sess1",
187 errors.New("token refresh failed (HTTP 429): rate_limited"),
188 )
189 if handled {
190 t.Fatal("HandlePermanentAuthErr matched a transient error")
191 }
192 if _, ok := o.sessionCache.Get(sessionCacheKey("did:plc:boltless", "sess1")); !ok {
193 t.Fatal("transient error evicted cache")
194 }
195 if got := store.deleteCalls.Load(); got != 0 {
196 t.Fatalf("store.DeleteSession called %d times, want 0", got)
197 }
198}