A better Rust ATProto crate
1use std::future::Future;
2use std::sync::Arc;
3
4use dashmap::DashMap;
5use jacquard_common::{
6 bos::BosStr,
7 session::{SessionHint, SessionKey, SessionSelector, SessionStore, SessionStoreError},
8 types::did::Did,
9};
10use jacquard_identity::resolver::IdentityResolver;
11use smol_str::{SmolStr, format_smolstr};
12
13use crate::session::{AuthRequestData, ClientSessionData};
14
15/// OAuth session lookup result with the matched key and session data.
16#[derive(Clone, Debug, PartialEq, Eq)]
17pub struct OAuthSessionMatch {
18 /// Matched session key.
19 pub key: SessionKey,
20 /// Stored OAuth client session data for the matched key.
21 pub session: ClientSessionData,
22}
23
/// [`SessionSelector`] that asks an OAuth store first and falls back to handle resolution.
///
/// The wrapped store answers every hint it can on its own. When it has nothing for a
/// `SessionHint::Handle`, the handle is resolved to a DID through the identity resolver and the
/// store is queried again with that DID. Callers depend only on [`SessionSelector`], so stores
/// with better indexing remain free to provide their own selector implementation instead.
pub struct OAuthSessionSelector<'a, S, R> {
    /// Auth store queried for matching sessions.
    store: &'a S,
    /// Resolver used to turn handles into DIDs when the store has no direct match.
    resolver: &'a R,
}
32
33impl<'a, S, R> OAuthSessionSelector<'a, S, R> {
34 /// Create a selector over an OAuth auth store and identity resolver.
35 pub fn new(store: &'a S, resolver: &'a R) -> Self {
36 Self { store, resolver }
37 }
38}
39
40impl<S, R> SessionSelector<OAuthSessionMatch> for OAuthSessionSelector<'_, S, R>
41where
42 S: ClientAuthStore + SessionSelector<OAuthSessionMatch, Error = SessionStoreError> + Sync,
43 R: IdentityResolver + Sync,
44{
45 type Error = SessionStoreError;
46
47 async fn select_session<Str: BosStr + Send + Sync>(
48 &self,
49 hint: &SessionHint<Str>,
50 ) -> Result<Option<OAuthSessionMatch>, Self::Error> {
51 if let Some(matched) = self.store.select_session(hint).await? {
52 return Ok(Some(matched));
53 }
54
55 let SessionHint::Handle(handle) = hint else {
56 return Ok(None);
57 };
58
59 let did = self
60 .resolver
61 .resolve_handle(handle)
62 .await
63 .map_err(|e| SessionStoreError::Other(Box::new(e)))?;
64 self.store.select_session(&SessionHint::Did(did)).await
65 }
66}
67
68/// Resolve a [`SessionHint`] against an OAuth [`ClientAuthStore`].
69///
70/// Exact key lookup avoids enumeration. `Any`, `Did`, and `Handle` use
71/// [`ClientAuthStore::list_session_keys`] as the generic fallback; stores that need more efficient
72/// indexed lookup can add specialized APIs later without changing the common key type.
73pub async fn resolve_oauth_session_hint<S, R, Str>(
74 store: &S,
75 resolver: &R,
76 hint: &SessionHint<Str>,
77) -> Result<Option<OAuthSessionMatch>, SessionStoreError>
78where
79 S: ClientAuthStore + SessionSelector<OAuthSessionMatch, Error = SessionStoreError> + Sync,
80 R: IdentityResolver + Sync,
81 Str: BosStr + Send + Sync,
82{
83 OAuthSessionSelector::new(store, resolver)
84 .select_session(hint)
85 .await
86}
87
88async fn oauth_match_for_did<S, D>(
89 store: &S,
90 did: &Did<D>,
91) -> Result<Option<OAuthSessionMatch>, SessionStoreError>
92where
93 S: ClientAuthStore,
94 D: BosStr + Send + Sync,
95{
96 for key in store.list_session_keys().await? {
97 if key.did.as_str() == did.as_ref() {
98 if let Some(matched) = oauth_match_for_key(store, key).await? {
99 return Ok(Some(matched));
100 }
101 }
102 }
103 Ok(None)
104}
105
106async fn oauth_match_for_key<S>(
107 store: &S,
108 key: SessionKey,
109) -> Result<Option<OAuthSessionMatch>, SessionStoreError>
110where
111 S: ClientAuthStore,
112{
113 Ok(store
114 .get_session(&key.did, key.session_id.as_str())
115 .await?
116 .map(|session| OAuthSessionMatch { key, session }))
117}
118
119/// Persistent storage backend for OAuth client sessions and in-flight authorization requests.
120///
121/// Implementors are responsible for durably storing two categories of data:
122/// - Active client sessions (access tokens, refresh tokens, nonces) keyed by DID + session ID.
123/// - Pending authorization request state, keyed by the OAuth `state` parameter, which must
124/// survive the round-trip to the authorization server and be cleaned up after use.
125#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
126pub trait ClientAuthStore {
127 /// Retrieve an active session for the given DID and session identifier, if one exists.
128 fn get_session<D: BosStr + Send + Sync>(
129 &self,
130 did: &Did<D>,
131 session_id: &str,
132 ) -> impl Future<Output = Result<Option<ClientSessionData>, SessionStoreError>>;
133
134 /// Insert or update a session, replacing any existing entry for the same DID and session ID.
135 fn upsert_session(
136 &self,
137 session: ClientSessionData,
138 ) -> impl Future<Output = Result<(), SessionStoreError>>;
139
140 /// Delete the session for the given DID and session identifier.
141 fn delete_session<D: BosStr + Send + Sync>(
142 &self,
143 did: &Did<D>,
144 session_id: &str,
145 ) -> impl Future<Output = Result<(), SessionStoreError>>;
146
147 /// Retrieve the authorization request data associated with the given OAuth `state` value.
148 fn get_auth_req_info(
149 &self,
150 state: &str,
151 ) -> impl Future<Output = Result<Option<AuthRequestData>, SessionStoreError>>;
152
153 /// Persist authorization request data so it can be retrieved after the OAuth redirect.
154 fn save_auth_req_info(
155 &self,
156 auth_req_info: &AuthRequestData,
157 ) -> impl Future<Output = Result<(), SessionStoreError>>;
158
159 /// Remove authorization request data after the callback has been handled.
160 fn delete_auth_req_info(
161 &self,
162 state: &str,
163 ) -> impl Future<Output = Result<(), SessionStoreError>>;
164
165 /// List active OAuth session keys when the backend supports enumeration.
166 fn list_session_keys(
167 &self,
168 ) -> impl Future<Output = Result<Vec<SessionKey>, SessionStoreError>> {
169 async { Ok(Vec::new()) }
170 }
171}
172
173/// An in-memory implementation of [`ClientAuthStore`], suitable for testing and single-process
174/// deployments where session persistence across restarts is not required.
175pub struct MemoryAuthStore {
176 sessions: DashMap<SmolStr, ClientSessionData>,
177 auth_reqs: DashMap<SmolStr, AuthRequestData>,
178}
179
180impl MemoryAuthStore {
181 /// Create a new, empty in-memory auth store.
182 pub fn new() -> Self {
183 Self {
184 sessions: DashMap::new(),
185 auth_reqs: DashMap::new(),
186 }
187 }
188}
189
190impl ClientAuthStore for MemoryAuthStore {
191 async fn get_session<D: BosStr + Send + Sync>(
192 &self,
193 did: &Did<D>,
194 session_id: &str,
195 ) -> Result<Option<ClientSessionData>, SessionStoreError> {
196 let key = format_smolstr!("{}/{}", did, session_id);
197 Ok(self.sessions.get(&key).map(|v| v.clone()))
198 }
199
200 async fn upsert_session(&self, session: ClientSessionData) -> Result<(), SessionStoreError> {
201 let key = format_smolstr!("{}/{}", session.account_did, session.session_id);
202 self.sessions.insert(key, session);
203 Ok(())
204 }
205
206 async fn delete_session<D: BosStr + Send + Sync>(
207 &self,
208 did: &Did<D>,
209 session_id: &str,
210 ) -> Result<(), SessionStoreError> {
211 let key = format_smolstr!("{}/{}", did, session_id);
212 self.sessions.remove(&key);
213 Ok(())
214 }
215
216 async fn get_auth_req_info(
217 &self,
218 state: &str,
219 ) -> Result<Option<AuthRequestData>, SessionStoreError> {
220 Ok(self.auth_reqs.get(state).map(|v| v.clone()))
221 }
222
223 async fn save_auth_req_info(
224 &self,
225 auth_req_info: &AuthRequestData,
226 ) -> Result<(), SessionStoreError> {
227 self.auth_reqs
228 .insert(auth_req_info.state.clone(), auth_req_info.clone());
229 Ok(())
230 }
231
232 async fn delete_auth_req_info(&self, state: &str) -> Result<(), SessionStoreError> {
233 self.auth_reqs.remove(state);
234 Ok(())
235 }
236
237 async fn list_session_keys(&self) -> Result<Vec<SessionKey>, SessionStoreError> {
238 let mut sessions = self
239 .sessions
240 .iter()
241 .map(|entry| {
242 let session = entry.value();
243 SessionKey::new(session.account_did.clone(), session.session_id.clone())
244 })
245 .collect::<Vec<_>>();
246 sessions.sort();
247 Ok(sessions)
248 }
249}
250
251impl SessionSelector<OAuthSessionMatch> for MemoryAuthStore {
252 type Error = SessionStoreError;
253
254 async fn select_session<Str: BosStr + Send + Sync>(
255 &self,
256 hint: &SessionHint<Str>,
257 ) -> Result<Option<OAuthSessionMatch>, Self::Error> {
258 match hint {
259 SessionHint::Any => {
260 let Some(key) = self.list_session_keys().await?.into_iter().next() else {
261 return Ok(None);
262 };
263 oauth_match_for_key(self, key).await
264 }
265 SessionHint::Did(did) => oauth_match_for_did(self, did).await,
266 SessionHint::Handle(_) | SessionHint::Identifier(_) => Ok(None),
267 SessionHint::Key(key) => oauth_match_for_key(self, key.clone()).await,
268 }
269 }
270}
271
272impl<T: ClientAuthStore + Send + Sync> SessionStore<SessionKey, ClientSessionData> for Arc<T> {
273 /// Get the current session if present.
274 async fn get(&self, key: &SessionKey) -> Option<ClientSessionData> {
275 self.as_ref()
276 .get_session(&key.did, key.session_id.as_str())
277 .await
278 .ok()
279 .flatten()
280 }
281 /// Persist the given session.
282 async fn set(
283 &self,
284 _key: SessionKey,
285 session: ClientSessionData,
286 ) -> Result<(), SessionStoreError> {
287 self.as_ref().upsert_session(session).await
288 }
289 /// Delete the given session.
290 async fn del(&self, key: &SessionKey) -> Result<(), SessionStoreError> {
291 self.as_ref()
292 .delete_session(&key.did, key.session_id.as_str())
293 .await
294 }
295
296 async fn list_keys(&self) -> Result<Vec<SessionKey>, SessionStoreError> {
297 self.as_ref().list_session_keys().await
298 }
299}
300
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use jacquard_common::deps::fluent_uri::Uri;

    use crate::scopes::Scopes;
    use crate::session::DpopClientData;
    use crate::types::{OAuthTokenType, TokenSet};

    /// Build a minimal OAuth client session for `did` / `session_id` with fixed test endpoints
    /// and a freshly generated ES256 DPoP key.
    fn client_session(did: &'static str, session_id: &'static str) -> ClientSessionData {
        let account_did = Did::new_static(did).unwrap();
        ClientSessionData {
            account_did: account_did.clone(),
            session_id: SmolStr::new_static(session_id),
            host_url: Uri::parse("https://pds.example.com").unwrap().to_owned(),
            authserver_url: SmolStr::new_static("https://issuer.example.com"),
            authserver_token_endpoint: SmolStr::new_static("https://issuer.example.com/token"),
            authserver_revocation_endpoint: None,
            scopes: Scopes::empty(),
            dpop_data: DpopClientData {
                dpop_key: crate::utils::generate_key(&[SmolStr::new_static("ES256")]).unwrap(),
                dpop_authserver_nonce: SmolStr::default(),
                dpop_host_nonce: SmolStr::default(),
            },
            token_set: TokenSet {
                iss: SmolStr::new_static("https://issuer.example.com"),
                sub: account_did,
                aud: SmolStr::new_static("https://pds.example.com"),
                scope: None,
                refresh_token: None,
                access_token: SmolStr::new_static("access"),
                token_type: OAuthTokenType::DPoP,
                expires_at: None,
            },
            resolved_scopes: None,
        }
    }

    #[tokio::test]
    async fn memory_auth_store_lists_session_keys() {
        let store = MemoryAuthStore::new();
        let session = client_session("did:plc:alice", "state");
        store.upsert_session(session).await.unwrap();

        assert_eq!(
            store.list_session_keys().await.unwrap(),
            vec![SessionKey::new(
                Did::new_static("did:plc:alice").unwrap(),
                "state"
            )]
        );
    }

    #[tokio::test]
    async fn memory_auth_store_any_on_empty_store_returns_none() {
        let store = MemoryAuthStore::new();

        assert!(
            store
                .select_session(&SessionHint::any())
                .await
                .unwrap()
                .is_none(),
            "an empty store has nothing to select"
        );
    }

    #[tokio::test]
    async fn memory_auth_store_selects_sessions_without_identifier_fallback() {
        let store = MemoryAuthStore::new();
        let alice = client_session("did:plc:alice", "state-a");
        let alice_key = SessionKey::new(Did::new_static("did:plc:alice").unwrap(), "state-a");
        store.upsert_session(alice.clone()).await.unwrap();
        store
            .upsert_session(client_session("did:plc:bob", "state-b"))
            .await
            .unwrap();

        let matched = store
            .select_session(&SessionHint::any())
            .await
            .unwrap()
            .expect("any match");
        assert_eq!(matched.key, alice_key);
        assert_eq!(matched.session, alice);

        let matched = store
            .select_session(&SessionHint::did(Did::new_static("did:plc:alice").unwrap()))
            .await
            .unwrap()
            .expect("did match");
        assert_eq!(matched.key, alice_key);

        let matched = store
            .select_session(&SessionHint::key(alice_key.clone()))
            .await
            .unwrap()
            .expect("key match");
        assert_eq!(matched.key, alice_key);

        assert!(
            store
                .select_session(&SessionHint::identifier("alice@example.com".into()))
                .await
                .unwrap()
                .is_none(),
            "identifier hints must not fall back to Any"
        );
    }

    /// Identity resolver stub that counts handle resolutions and always resolves to
    /// `did:plc:alice`. A plain atomic is enough for the counter; no async lock is needed.
    #[derive(Clone, Default)]
    struct CountingResolver {
        handle_calls: Arc<AtomicUsize>,
    }

    impl IdentityResolver for CountingResolver {
        fn options(&self) -> &jacquard_identity::resolver::ResolverOptions {
            use std::sync::LazyLock;
            static OPTS: LazyLock<jacquard_identity::resolver::ResolverOptions> =
                LazyLock::new(jacquard_identity::resolver::ResolverOptions::default);
            &OPTS
        }

        async fn resolve_handle<S: BosStr + Sync>(
            &self,
            _handle: &jacquard_common::types::string::Handle<S>,
        ) -> Result<Did, jacquard_identity::resolver::IdentityError> {
            self.handle_calls.fetch_add(1, Ordering::SeqCst);
            Ok(Did::new_static("did:plc:alice").unwrap())
        }

        async fn resolve_did_doc<S: BosStr + Sync>(
            &self,
            _did: &Did<S>,
        ) -> Result<
            jacquard_identity::resolver::DidDocResponse,
            jacquard_identity::resolver::IdentityError,
        > {
            unreachable!("OAuth selector tests do not resolve DID documents")
        }
    }

    #[tokio::test]
    async fn oauth_session_selector_uses_store_before_handle_resolution() {
        let store = MemoryAuthStore::new();
        let resolver = CountingResolver::default();
        let alice = client_session("did:plc:alice", "state");
        store.upsert_session(alice.clone()).await.unwrap();

        assert!(
            OAuthSessionSelector::new(&store, &resolver)
                .select_session(&SessionHint::identifier("alice@example.com".into()))
                .await
                .unwrap()
                .is_none(),
            "identifier hints should not trigger resolver fallback"
        );
        assert_eq!(resolver.handle_calls.load(Ordering::SeqCst), 0);

        let matched = OAuthSessionSelector::new(&store, &resolver)
            .select_session(&SessionHint::handle(
                jacquard_common::types::string::Handle::new_static("alice.bsky.social").unwrap(),
            ))
            .await
            .unwrap()
            .expect("resolver fallback DID match");
        assert_eq!(matched.session, alice);
        assert_eq!(resolver.handle_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn oauth_session_selector_handle_without_session_returns_none() {
        let store = MemoryAuthStore::new();
        let resolver = CountingResolver::default();

        let matched = OAuthSessionSelector::new(&store, &resolver)
            .select_session(&SessionHint::handle(
                jacquard_common::types::string::Handle::new_static("alice.bsky.social").unwrap(),
            ))
            .await
            .unwrap();
        assert!(matched.is_none(), "resolved DID has no stored session");
        assert_eq!(resolver.handle_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn arc_memory_auth_store_is_session_store() {
        let store = Arc::new(MemoryAuthStore::new());
        let session = client_session("did:plc:alice", "state");
        let key = SessionKey::new(Did::new_static("did:plc:alice").unwrap(), "state");

        SessionStore::set(&store, key.clone(), session.clone())
            .await
            .unwrap();
        assert_eq!(SessionStore::get(&store, &key).await, Some(session));
        assert_eq!(
            SessionStore::list_keys(&store).await.unwrap(),
            vec![key.clone()]
        );
        SessionStore::del(&store, &key).await.unwrap();
        assert_eq!(SessionStore::get(&store, &key).await, None);
    }
}