A better Rust ATProto crate
1

Configure Feed

Select the types of activity you want to include in your feed.

at main 16 kB View raw
1use std::future::Future; 2use std::sync::Arc; 3 4use dashmap::DashMap; 5use jacquard_common::{ 6 bos::BosStr, 7 session::{SessionHint, SessionKey, SessionSelector, SessionStore, SessionStoreError}, 8 types::did::Did, 9}; 10use jacquard_identity::resolver::IdentityResolver; 11use smol_str::{SmolStr, format_smolstr}; 12 13use crate::session::{AuthRequestData, ClientSessionData}; 14 15/// OAuth session lookup result with the matched key and session data. 16#[derive(Clone, Debug, PartialEq, Eq)] 17pub struct OAuthSessionMatch { 18 /// Matched session key. 19 pub key: SessionKey, 20 /// Stored OAuth client session data for the matched key. 21 pub session: ClientSessionData, 22} 23 24/// Resolver-backed OAuth session selector. 25/// 26/// This adapter keeps selection pluggable: callers can depend on [`SessionSelector`] while stores 27/// with better indexing can provide their own selector implementation. 28pub struct OAuthSessionSelector<'a, S, R> { 29 store: &'a S, 30 resolver: &'a R, 31} 32 33impl<'a, S, R> OAuthSessionSelector<'a, S, R> { 34 /// Create a selector over an OAuth auth store and identity resolver. 35 pub fn new(store: &'a S, resolver: &'a R) -> Self { 36 Self { store, resolver } 37 } 38} 39 40impl<S, R> SessionSelector<OAuthSessionMatch> for OAuthSessionSelector<'_, S, R> 41where 42 S: ClientAuthStore + SessionSelector<OAuthSessionMatch, Error = SessionStoreError> + Sync, 43 R: IdentityResolver + Sync, 44{ 45 type Error = SessionStoreError; 46 47 async fn select_session<Str: BosStr + Send + Sync>( 48 &self, 49 hint: &SessionHint<Str>, 50 ) -> Result<Option<OAuthSessionMatch>, Self::Error> { 51 if let Some(matched) = self.store.select_session(hint).await? { 52 return Ok(Some(matched)); 53 } 54 55 let SessionHint::Handle(handle) = hint else { 56 return Ok(None); 57 }; 58 59 let did = self 60 .resolver 61 .resolve_handle(handle) 62 .await 63 .map_err(|e| SessionStoreError::Other(Box::new(e)))?; 64 self.store.select_session(&SessionHint::Did(did)).await 65 } 66} 67 68/// Resolve a [`SessionHint`] against an OAuth [`ClientAuthStore`]. 69/// 70/// Exact key lookup avoids enumeration. `Any`, `Did`, and `Handle` use 71/// [`ClientAuthStore::list_session_keys`] as the generic fallback; stores that need more efficient 72/// indexed lookup can add specialized APIs later without changing the common key type. 73pub async fn resolve_oauth_session_hint<S, R, Str>( 74 store: &S, 75 resolver: &R, 76 hint: &SessionHint<Str>, 77) -> Result<Option<OAuthSessionMatch>, SessionStoreError> 78where 79 S: ClientAuthStore + SessionSelector<OAuthSessionMatch, Error = SessionStoreError> + Sync, 80 R: IdentityResolver + Sync, 81 Str: BosStr + Send + Sync, 82{ 83 OAuthSessionSelector::new(store, resolver) 84 .select_session(hint) 85 .await 86} 87 88async fn oauth_match_for_did<S, D>( 89 store: &S, 90 did: &Did<D>, 91) -> Result<Option<OAuthSessionMatch>, SessionStoreError> 92where 93 S: ClientAuthStore, 94 D: BosStr + Send + Sync, 95{ 96 for key in store.list_session_keys().await? { 97 if key.did.as_str() == did.as_ref() { 98 if let Some(matched) = oauth_match_for_key(store, key).await? { 99 return Ok(Some(matched)); 100 } 101 } 102 } 103 Ok(None) 104} 105 106async fn oauth_match_for_key<S>( 107 store: &S, 108 key: SessionKey, 109) -> Result<Option<OAuthSessionMatch>, SessionStoreError> 110where 111 S: ClientAuthStore, 112{ 113 Ok(store 114 .get_session(&key.did, key.session_id.as_str()) 115 .await? 116 .map(|session| OAuthSessionMatch { key, session })) 117} 118 119/// Persistent storage backend for OAuth client sessions and in-flight authorization requests. 120/// 121/// Implementors are responsible for durably storing two categories of data: 122/// - Active client sessions (access tokens, refresh tokens, nonces) keyed by DID + session ID. 123/// - Pending authorization request state, keyed by the OAuth `state` parameter, which must 124/// survive the round-trip to the authorization server and be cleaned up after use. 125#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] 126pub trait ClientAuthStore { 127 /// Retrieve an active session for the given DID and session identifier, if one exists. 128 fn get_session<D: BosStr + Send + Sync>( 129 &self, 130 did: &Did<D>, 131 session_id: &str, 132 ) -> impl Future<Output = Result<Option<ClientSessionData>, SessionStoreError>>; 133 134 /// Insert or update a session, replacing any existing entry for the same DID and session ID. 135 fn upsert_session( 136 &self, 137 session: ClientSessionData, 138 ) -> impl Future<Output = Result<(), SessionStoreError>>; 139 140 /// Delete the session for the given DID and session identifier. 141 fn delete_session<D: BosStr + Send + Sync>( 142 &self, 143 did: &Did<D>, 144 session_id: &str, 145 ) -> impl Future<Output = Result<(), SessionStoreError>>; 146 147 /// Retrieve the authorization request data associated with the given OAuth `state` value. 148 fn get_auth_req_info( 149 &self, 150 state: &str, 151 ) -> impl Future<Output = Result<Option<AuthRequestData>, SessionStoreError>>; 152 153 /// Persist authorization request data so it can be retrieved after the OAuth redirect. 154 fn save_auth_req_info( 155 &self, 156 auth_req_info: &AuthRequestData, 157 ) -> impl Future<Output = Result<(), SessionStoreError>>; 158 159 /// Remove authorization request data after the callback has been handled. 160 fn delete_auth_req_info( 161 &self, 162 state: &str, 163 ) -> impl Future<Output = Result<(), SessionStoreError>>; 164 165 /// List active OAuth session keys when the backend supports enumeration. 166 fn list_session_keys( 167 &self, 168 ) -> impl Future<Output = Result<Vec<SessionKey>, SessionStoreError>> { 169 async { Ok(Vec::new()) } 170 } 171} 172 173/// An in-memory implementation of [`ClientAuthStore`], suitable for testing and single-process 174/// deployments where session persistence across restarts is not required. 175pub struct MemoryAuthStore { 176 sessions: DashMap<SmolStr, ClientSessionData>, 177 auth_reqs: DashMap<SmolStr, AuthRequestData>, 178} 179 180impl MemoryAuthStore { 181 /// Create a new, empty in-memory auth store. 182 pub fn new() -> Self { 183 Self { 184 sessions: DashMap::new(), 185 auth_reqs: DashMap::new(), 186 } 187 } 188} 189 190impl ClientAuthStore for MemoryAuthStore { 191 async fn get_session<D: BosStr + Send + Sync>( 192 &self, 193 did: &Did<D>, 194 session_id: &str, 195 ) -> Result<Option<ClientSessionData>, SessionStoreError> { 196 let key = format_smolstr!("{}/{}", did, session_id); 197 Ok(self.sessions.get(&key).map(|v| v.clone())) 198 } 199 200 async fn upsert_session(&self, session: ClientSessionData) -> Result<(), SessionStoreError> { 201 let key = format_smolstr!("{}/{}", session.account_did, session.session_id); 202 self.sessions.insert(key, session); 203 Ok(()) 204 } 205 206 async fn delete_session<D: BosStr + Send + Sync>( 207 &self, 208 did: &Did<D>, 209 session_id: &str, 210 ) -> Result<(), SessionStoreError> { 211 let key = format_smolstr!("{}/{}", did, session_id); 212 self.sessions.remove(&key); 213 Ok(()) 214 } 215 216 async fn get_auth_req_info( 217 &self, 218 state: &str, 219 ) -> Result<Option<AuthRequestData>, SessionStoreError> { 220 Ok(self.auth_reqs.get(state).map(|v| v.clone())) 221 } 222 223 async fn save_auth_req_info( 224 &self, 225 auth_req_info: &AuthRequestData, 226 ) -> Result<(), SessionStoreError> { 227 self.auth_reqs 228 .insert(auth_req_info.state.clone(), auth_req_info.clone()); 229 Ok(()) 230 } 231 232 async fn delete_auth_req_info(&self, state: &str) -> Result<(), SessionStoreError> { 233 self.auth_reqs.remove(state); 234 Ok(()) 235 } 236 237 async fn list_session_keys(&self) -> Result<Vec<SessionKey>, SessionStoreError> { 238 let mut sessions = self 239 .sessions 240 .iter() 241 .map(|entry| { 242 let session = entry.value(); 243 SessionKey::new(session.account_did.clone(), session.session_id.clone()) 244 }) 245 .collect::<Vec<_>>(); 246 sessions.sort(); 247 Ok(sessions) 248 } 249} 250 251impl SessionSelector<OAuthSessionMatch> for MemoryAuthStore { 252 type Error = SessionStoreError; 253 254 async fn select_session<Str: BosStr + Send + Sync>( 255 &self, 256 hint: &SessionHint<Str>, 257 ) -> Result<Option<OAuthSessionMatch>, Self::Error> { 258 match hint { 259 SessionHint::Any => { 260 let Some(key) = self.list_session_keys().await?.into_iter().next() else { 261 return Ok(None); 262 }; 263 oauth_match_for_key(self, key).await 264 } 265 SessionHint::Did(did) => oauth_match_for_did(self, did).await, 266 SessionHint::Handle(_) | SessionHint::Identifier(_) => Ok(None), 267 SessionHint::Key(key) => oauth_match_for_key(self, key.clone()).await, 268 } 269 } 270} 271 272impl<T: ClientAuthStore + Send + Sync> SessionStore<SessionKey, ClientSessionData> for Arc<T> { 273 /// Get the current session if present. 274 async fn get(&self, key: &SessionKey) -> Option<ClientSessionData> { 275 self.as_ref() 276 .get_session(&key.did, key.session_id.as_str()) 277 .await 278 .ok() 279 .flatten() 280 } 281 /// Persist the given session. 282 async fn set( 283 &self, 284 _key: SessionKey, 285 session: ClientSessionData, 286 ) -> Result<(), SessionStoreError> { 287 self.as_ref().upsert_session(session).await 288 } 289 /// Delete the given session. 290 async fn del(&self, key: &SessionKey) -> Result<(), SessionStoreError> { 291 self.as_ref() 292 .delete_session(&key.did, key.session_id.as_str()) 293 .await 294 } 295 296 async fn list_keys(&self) -> Result<Vec<SessionKey>, SessionStoreError> { 297 self.as_ref().list_session_keys().await 298 } 299} 300 301#[cfg(test)] 302mod tests { 303 use super::*; 304 use jacquard_common::deps::fluent_uri::Uri; 305 306 use crate::scopes::Scopes; 307 use crate::session::DpopClientData; 308 use crate::types::{OAuthTokenType, TokenSet}; 309 310 fn client_session(did: &'static str, session_id: &'static str) -> ClientSessionData { 311 let account_did = Did::new_static(did).unwrap(); 312 ClientSessionData { 313 account_did: account_did.clone(), 314 session_id: SmolStr::new_static(session_id), 315 host_url: Uri::parse("https://pds.example.com").unwrap().to_owned(), 316 authserver_url: SmolStr::new_static("https://issuer.example.com"), 317 authserver_token_endpoint: SmolStr::new_static("https://issuer.example.com/token"), 318 authserver_revocation_endpoint: None, 319 scopes: Scopes::empty(), 320 dpop_data: DpopClientData { 321 dpop_key: crate::utils::generate_key(&[SmolStr::new_static("ES256")]).unwrap(), 322 dpop_authserver_nonce: SmolStr::default(), 323 dpop_host_nonce: SmolStr::default(), 324 }, 325 token_set: TokenSet { 326 iss: SmolStr::new_static("https://issuer.example.com"), 327 sub: account_did, 328 aud: SmolStr::new_static("https://pds.example.com"), 329 scope: None, 330 refresh_token: None, 331 access_token: SmolStr::new_static("access"), 332 token_type: OAuthTokenType::DPoP, 333 expires_at: None, 334 }, 335 resolved_scopes: None, 336 } 337 } 338 339 #[tokio::test] 340 async fn memory_auth_store_lists_session_keys() { 341 let store = MemoryAuthStore::new(); 342 let session = client_session("did:plc:alice", "state"); 343 store.upsert_session(session).await.unwrap(); 344 345 assert_eq!( 346 store.list_session_keys().await.unwrap(), 347 vec![SessionKey::new( 348 Did::new_static("did:plc:alice").unwrap(), 349 "state" 350 )] 351 ); 352 } 353 354 #[tokio::test] 355 async fn memory_auth_store_selects_sessions_without_identifier_fallback() { 356 let store = MemoryAuthStore::new(); 357 let alice = client_session("did:plc:alice", "state-a"); 358 let alice_key = SessionKey::new(Did::new_static("did:plc:alice").unwrap(), "state-a"); 359 store.upsert_session(alice.clone()).await.unwrap(); 360 store 361 .upsert_session(client_session("did:plc:bob", "state-b")) 362 .await 363 .unwrap(); 364 365 let matched = store 366 .select_session(&SessionHint::any()) 367 .await 368 .unwrap() 369 .expect("any match"); 370 assert_eq!(matched.key, alice_key); 371 assert_eq!(matched.session, alice); 372 373 let matched = store 374 .select_session(&SessionHint::did(Did::new_static("did:plc:alice").unwrap())) 375 .await 376 .unwrap() 377 .expect("did match"); 378 assert_eq!(matched.key, alice_key); 379 380 let matched = store 381 .select_session(&SessionHint::key(alice_key.clone())) 382 .await 383 .unwrap() 384 .expect("key match"); 385 assert_eq!(matched.key, alice_key); 386 387 assert!( 388 store 389 .select_session(&SessionHint::identifier("alice@example.com".into())) 390 .await 391 .unwrap() 392 .is_none(), 393 "identifier hints must not fall back to Any" 394 ); 395 } 396 397 #[derive(Clone, Default)] 398 struct CountingResolver { 399 handle_calls: Arc<tokio::sync::RwLock<usize>>, 400 } 401 402 impl IdentityResolver for CountingResolver { 403 fn options(&self) -> &jacquard_identity::resolver::ResolverOptions { 404 use std::sync::LazyLock; 405 static OPTS: LazyLock<jacquard_identity::resolver::ResolverOptions> = 406 LazyLock::new(jacquard_identity::resolver::ResolverOptions::default); 407 &OPTS 408 } 409 410 async fn resolve_handle<S: BosStr + Sync>( 411 &self, 412 _handle: &jacquard_common::types::string::Handle<S>, 413 ) -> Result<Did, jacquard_identity::resolver::IdentityError> { 414 *self.handle_calls.write().await += 1; 415 Ok(Did::new_static("did:plc:alice").unwrap()) 416 } 417 418 async fn resolve_did_doc<S: BosStr + Sync>( 419 &self, 420 _did: &Did<S>, 421 ) -> Result< 422 jacquard_identity::resolver::DidDocResponse, 423 jacquard_identity::resolver::IdentityError, 424 > { 425 unreachable!("OAuth selector tests do not resolve DID documents") 426 } 427 } 428 429 #[tokio::test] 430 async fn oauth_session_selector_uses_store_before_handle_resolution() { 431 let store = MemoryAuthStore::new(); 432 let resolver = CountingResolver::default(); 433 let alice = client_session("did:plc:alice", "state"); 434 store.upsert_session(alice.clone()).await.unwrap(); 435 436 assert!( 437 OAuthSessionSelector::new(&store, &resolver) 438 .select_session(&SessionHint::identifier("alice@example.com".into())) 439 .await 440 .unwrap() 441 .is_none(), 442 "identifier hints should not trigger resolver fallback" 443 ); 444 assert_eq!(*resolver.handle_calls.read().await, 0); 445 446 let matched = OAuthSessionSelector::new(&store, &resolver) 447 .select_session(&SessionHint::handle( 448 jacquard_common::types::string::Handle::new_static("alice.bsky.social").unwrap(), 449 )) 450 .await 451 .unwrap() 452 .expect("resolver fallback DID match"); 453 assert_eq!(matched.session, alice); 454 assert_eq!(*resolver.handle_calls.read().await, 1); 455 } 456 457 #[tokio::test] 458 async fn arc_memory_auth_store_is_session_store() { 459 let store = Arc::new(MemoryAuthStore::new()); 460 let session = client_session("did:plc:alice", "state"); 461 let key = SessionKey::new(Did::new_static("did:plc:alice").unwrap(), "state"); 462 463 SessionStore::set(&store, key.clone(), session.clone()) 464 .await 465 .unwrap(); 466 assert_eq!(SessionStore::get(&store, &key).await, Some(session)); 467 assert_eq!( 468 SessionStore::list_keys(&store).await.unwrap(), 469 vec![key.clone()] 470 ); 471 SessionStore::del(&store, &key).await.unwrap(); 472 assert_eq!(SessionStore::get(&store, &key).await, None); 473 } 474}