A better Rust ATProto crate
1

Configure Feed

Select the types of activity you want to include in your feed.

at main 62 kB View raw
1use crate::{ 2 atproto::atproto_client_metadata, 3 authstore::{ClientAuthStore, OAuthSessionMatch, OAuthSessionSelector}, 4 dpop::DpopExt, 5 error::{CallbackError, OAuthError, Result}, 6 request::{OAuthMetadata, exchange_code, par}, 7 resolver::OAuthResolver, 8 scopes::Scopes, 9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry}, 10 types::{AuthorizeOptions, CallbackParams}, 11}; 12#[cfg(feature = "scope-check")] 13use crate::{ 14 error::ScopeError, 15 resolver::resolve_permission_set, 16 scopes::{IncludeScope, RepoCollection, RpcLexicon, Scope}, 17}; 18#[cfg(feature = "websocket")] 19use jacquard_common::CowStr; 20#[cfg(feature = "scope-check")] 21use jacquard_common::types::{nsid::Nsid, string::DidService}; 22use jacquard_common::{ 23 AuthorizationToken, IntoStatic, 24 bos::BosStr, 25 deps::fluent_uri::Uri, 26 error::{AuthError, ClientError, XrpcResult}, 27 http_client::HttpClient, 28 session::{SessionHint, SessionSelector, SessionStoreError}, 29 types::{did::Did, string::Handle}, 30 xrpc::{ 31 CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest, XrpcResp, XrpcResponse, 32 build_http_request, process_response, 33 }, 34}; 35#[cfg(feature = "scope-check")] 36use jacquard_identity::lexicon_resolver::LexiconSchemaResolver; 37 38#[cfg(feature = "scope-check")] 39use jacquard_common::deps::fluent_uri::pct_enc::{EStr, encoder::Query}; 40 41#[cfg(feature = "websocket")] 42use jacquard_common::websocket::{WebSocketClient, WebSocketConnection}; 43#[cfg(feature = "websocket")] 44use jacquard_common::xrpc::XrpcSubscription; 45use jacquard_identity::{ 46 JacquardResolver, 47 resolver::{DidDocResponse, IdentityError, IdentityResolver, ResolverOptions}, 48}; 49use jose_jwk::JwkSet; 50use smol_str::{SmolStr, ToSmolStr}; 51use std::{str::FromStr, sync::Arc}; 52use tokio::sync::RwLock; 53 54/// Result of resuming an OAuth session or starting a new authorization flow. 
55pub enum OAuthResumeOrLogin<T, S> 56where 57 T: OAuthResolver, 58 S: ClientAuthStore, 59{ 60 /// A stored session was found and restored/refreshed. 61 Resumed(OAuthSession<T, S>), 62 /// No stored session matched; redirect the user to this login URL. 63 LoginUrl(String), 64 /// No stored session matched, and the hint did not contain enough information to start OAuth. 65 NeedsInput, 66} 67 68/// The top-level OAuth client responsible for driving the authorization flow. 69pub struct OAuthClient<T, S> 70where 71 T: OAuthResolver, 72 S: ClientAuthStore, 73{ 74 /// Shared session registry that mediates access to the backing auth store. 75 pub registry: Arc<SessionRegistry<T, S, SmolStr>>, 76 /// Default call options applied to every outgoing XRPC request. 77 pub options: RwLock<CallOptions>, 78 /// Override for the XRPC base URI; falls back to the public Bluesky AppView when `None`. 79 pub endpoint: RwLock<Option<Uri<String>>>, 80 /// Underlying HTTP/identity/OAuth resolver used for all network operations. 81 pub client: Arc<T>, 82} 83 84impl<S: ClientAuthStore, C: HttpClient + Sync> OAuthClient<JacquardResolver<C>, S> { 85 /// Create an `OAuthClient` using the default [`JacquardResolver`] for identity and metadata resolution. 86 pub fn new(store: S, client_data: ClientData<SmolStr>, http: C) -> Self { 87 let client = JacquardResolver::new(http, ResolverOptions::default()); 88 Self::new_from_resolver(store, client, client_data) 89 } 90} 91 92impl<S: ClientAuthStore> OAuthClient<JacquardResolver<reqwest::Client>, S> { 93 /// Create an OAuth client with the provided store and default localhost client metadata. 94 /// 95 /// This is a convenience constructor for quickly setting up an OAuth client 96 /// with default localhost redirect URIs and "atproto transition:generic" scopes. 
97 /// 98 /// # Example 99 /// 100 /// ```no_run 101 /// # use jacquard_oauth::client::OAuthClient; 102 /// # use jacquard_oauth::authstore::MemoryAuthStore; 103 /// # #[tokio::main] 104 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { 105 /// let store = MemoryAuthStore::new(); 106 /// let oauth = OAuthClient::with_default_config(store); 107 /// # Ok(()) 108 /// # } 109 /// ``` 110 pub fn with_default_config(store: S) -> Self { 111 let client_data = ClientData { 112 keyset: None, 113 config: crate::atproto::AtprotoClientMetadata::default_localhost(), 114 }; 115 Self::new(store, client_data, reqwest::Client::new()) 116 } 117} 118 119impl OAuthClient<JacquardResolver<reqwest::Client>, crate::authstore::MemoryAuthStore> { 120 /// Create an OAuth client with an in-memory auth store and default localhost client metadata. 121 /// 122 /// This is a convenience constructor for simple testing and development. 123 /// The session will not persist across restarts. 124 /// 125 /// # Example 126 /// 127 /// ```no_run 128 /// # use jacquard_oauth::client::OAuthClient; 129 /// # #[tokio::main] 130 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { 131 /// let oauth = OAuthClient::with_memory_store(); 132 /// # Ok(()) 133 /// # } 134 /// ``` 135 pub fn with_memory_store() -> Self { 136 Self::with_default_config(crate::authstore::MemoryAuthStore::new()) 137 } 138} 139 140impl<T, S> OAuthClient<T, S> 141where 142 T: OAuthResolver, 143 S: ClientAuthStore, 144{ 145 /// Create an OAuth client from an explicit resolver instance, taking ownership of both. 
146 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<SmolStr>) -> Self { 147 // #[cfg(feature = "tracing")] 148 // tracing::info!( 149 // redirect_uris = ?client_data.config.redirect_uris, 150 // scopes = ?client_data.config.scopes, 151 // has_keyset = client_data.keyset.is_some(), 152 // "oauth client created:" 153 // ); 154 155 let client = Arc::new(client); 156 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data)); 157 Self { 158 registry, 159 client, 160 options: RwLock::new(CallOptions::default()), 161 endpoint: RwLock::new(None), 162 } 163 } 164 165 /// Create an OAuth client from already-`Arc`-wrapped store and resolver. 166 pub fn new_with_shared( 167 store: Arc<S>, 168 client: Arc<T>, 169 client_data: ClientData<SmolStr>, 170 ) -> Self { 171 let registry = Arc::new(SessionRegistry::new_shared( 172 store, 173 client.clone(), 174 client_data, 175 )); 176 Self { 177 registry, 178 client, 179 options: RwLock::new(CallOptions::default()), 180 endpoint: RwLock::new(None), 181 } 182 } 183} 184 185impl<T, S> OAuthClient<T, S> 186where 187 S: ClientAuthStore + Send + Sync + 'static, 188 T: OAuthResolver + DpopExt + Send + Sync + 'static, 189{ 190 /// Return the public JWK set for this client's keyset, or an empty set if no keyset is configured. 191 pub fn jwks(&self) -> JwkSet { 192 self.registry 193 .client_data 194 .keyset 195 .as_ref() 196 .map(|keyset| keyset.public_jwks()) 197 .unwrap_or_default() 198 } 199 /// Begin an OAuth authorization flow and return the URL to which the user should be redirected. 200 /// 201 /// This resolves OAuth metadata for the given `input` (a handle, DID, or PDS/entryway URL), 202 /// performs a Pushed Authorization Request (PAR) to the authorization server, persists the 203 /// resulting state for later callback verification, and returns a fully-constructed 204 /// authorization endpoint URL. 
205 /// 206 /// The caller is responsible for redirecting the user's browser to the returned URL. 207 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip(self, input), fields(input = input.as_ref())))] 208 pub async fn start_auth<Str: BosStr>( 209 &self, 210 input: impl AsRef<str>, 211 options: AuthorizeOptions<Str>, 212 ) -> Result<String> 213 where 214 Str: FromStr + Ord + Clone + core::fmt::Debug, 215 <Str as FromStr>::Err: core::fmt::Debug, 216 { 217 let client_metadata = atproto_client_metadata( 218 &self.registry.client_data.config, 219 &self.registry.client_data.keyset, 220 )?; 221 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?; 222 let login_hint = if identity.is_some() { 223 Some(input.as_ref().into()) 224 } else { 225 None 226 }; 227 let mut metadata = OAuthMetadata { 228 server_metadata, 229 client_metadata, 230 keyset: self.registry.client_data.keyset.clone(), 231 }; 232 233 let auth_req_info = par( 234 self.client.as_ref(), 235 login_hint, 236 options.prompt, 237 &mut metadata, 238 options.state.map(|s| s.as_ref().to_smolstr()), 239 ) 240 .await?; 241 242 // Persist state for callback handling 243 self.registry 244 .store 245 .save_auth_req_info(&auth_req_info) 246 .await?; 247 248 #[derive(serde::Serialize)] 249 struct Parameters { 250 client_id: smol_str::SmolStr, 251 request_uri: smol_str::SmolStr, 252 } 253 Ok(metadata.server_metadata.authorization_endpoint.to_string() 254 + "?" 255 + &serde_html_form::to_string(Parameters { 256 client_id: metadata.client_metadata.client_id, 257 request_uri: auth_req_info.request_uri, 258 }) 259 .unwrap()) 260 } 261 262 /// Complete the OAuth authorization flow after the authorization server redirects back to the client. 
263 /// 264 /// Validates the `state` and optional `iss` parameters, exchanges the authorization code for 265 /// tokens via the token endpoint, verifies the `sub` claim against the expected issuer, and 266 /// persists the resulting session. On success returns an [`OAuthSession`] ready for API calls. 267 /// 268 /// When the `scope-check` feature is enabled, this method also eagerly resolves any `include:` 269 /// scopes by fetching the referenced permission sets. `T` must implement 270 /// `LexiconSchemaResolver` in that case. 271 #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", skip_all, fields(state = params.state.as_ref().map(|s| s.as_str()))))] 272 #[cfg(not(feature = "scope-check"))] 273 pub async fn callback(&self, params: CallbackParams) -> Result<OAuthSession<T, S>> { 274 let client_data = self.callback_core(params).await?; 275 self.create_session(client_data).await 276 } 277 278 /// Complete the OAuth authorization flow (scope-check variant). 279 /// 280 /// Same as `callback`, but eagerly resolves `include:` scopes into 281 /// concrete permissions via `LexiconSchemaResolver`. 282 #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", skip_all, fields(state = params.state.as_ref().map(|s| s.as_str()))))] 283 #[cfg(feature = "scope-check")] 284 pub async fn callback(&self, params: CallbackParams) -> Result<OAuthSession<T, S>> 285 where 286 T: LexiconSchemaResolver, 287 { 288 let mut client_data = self.callback_core(params).await?; 289 client_data.resolved_scopes = 290 Some(resolve_include_scopes(self.client.as_ref(), &client_data.scopes).await?); 291 self.create_session(client_data).await 292 } 293 294 /// Shared callback logic: validate state/iss, exchange code, build session data. 
295 async fn callback_core(&self, params: CallbackParams) -> Result<ClientSessionData> { 296 let Some(state_key) = params.state else { 297 return Err(CallbackError::MissingState.into()); 298 }; 299 300 let Some(auth_req_info) = self 301 .registry 302 .store 303 .get_auth_req_info(state_key.as_str()) 304 .await? 305 else { 306 return Err(CallbackError::MissingState.into()); 307 }; 308 309 self.registry 310 .store 311 .delete_auth_req_info(state_key.as_str()) 312 .await?; 313 314 let metadata = self 315 .client 316 .get_authorization_server_metadata(auth_req_info.authserver_url.as_str()) 317 .await?; 318 319 if let Some(iss) = params.iss { 320 if iss != metadata.issuer { 321 return Err(CallbackError::IssuerMismatch { 322 expected: metadata.issuer.to_string(), 323 got: iss.to_string(), 324 } 325 .into()); 326 } 327 } else if metadata.authorization_response_iss_parameter_supported == Some(true) { 328 return Err(CallbackError::MissingIssuer.into()); 329 } 330 let metadata = OAuthMetadata { 331 server_metadata: metadata, 332 client_metadata: atproto_client_metadata( 333 &self.registry.client_data.config, 334 &self.registry.client_data.keyset, 335 )?, 336 keyset: self.registry.client_data.keyset.clone(), 337 }; 338 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone(); 339 340 match exchange_code( 341 self.client.as_ref(), 342 &mut auth_req_info.dpop_data.clone(), 343 params.code.as_str(), 344 auth_req_info.pkce_verifier.as_str(), 345 &metadata, 346 ) 347 .await 348 { 349 Ok(token_set) => { 350 let scopes = if let Some(scope) = &token_set.scope { 351 Scopes::new(scope.as_str().to_smolstr()) 352 .expect("Failed to parse scopes from token response") 353 } else { 354 Scopes::empty() 355 }; 356 Ok(ClientSessionData { 357 account_did: token_set.sub.clone(), 358 session_id: auth_req_info.state, 359 host_url: Uri::parse(token_set.aud.as_str())?.to_owned(), 360 authserver_url: auth_req_info.authserver_url, 361 authserver_token_endpoint: 
auth_req_info.authserver_token_endpoint, 362 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint, 363 scopes, 364 dpop_data: DpopClientData { 365 dpop_key: auth_req_info.dpop_data.dpop_key.clone(), 366 dpop_authserver_nonce: authserver_nonce.unwrap_or_default(), 367 dpop_host_nonce: auth_req_info 368 .dpop_data 369 .dpop_authserver_nonce 370 .unwrap_or_default(), 371 }, 372 token_set, 373 resolved_scopes: None, 374 }) 375 } 376 Err(e) => Err(e.into()), 377 } 378 } 379 380 async fn create_session(&self, data: ClientSessionData) -> Result<OAuthSession<T, S>> { 381 self.registry.set(data.clone()).await?; 382 Ok(OAuthSession::new( 383 self.registry.clone(), 384 self.client.clone(), 385 data.into_static(), 386 )) 387 } 388 389 /// Restore a previously created session from the backing store, refreshing tokens if needed. 390 pub async fn restore( 391 &self, 392 did: &Did<impl BosStr + Send + Sync>, 393 session_id: &str, 394 ) -> Result<OAuthSession<T, S>> { 395 self.create_session(self.registry.get(did, session_id, true).await?) 396 .await 397 } 398 399 /// Resume a stored session for `input`, or begin OAuth authorization and return a login URL. 400 pub async fn resume_or_start_auth_for<Str: BosStr>( 401 &self, 402 input: impl AsRef<str>, 403 options: AuthorizeOptions<Str>, 404 ) -> Result<OAuthResumeOrLogin<T, S>> 405 where 406 S: SessionSelector<OAuthSessionMatch, Error = SessionStoreError>, 407 Str: FromStr + Ord + Clone + core::fmt::Debug, 408 <Str as FromStr>::Err: core::fmt::Debug, 409 { 410 let input = input.as_ref(); 411 self.resume_or_start_auth(&SessionHint::from_input(input), options) 412 .await 413 } 414 415 /// Resume a stored session for `hint`, or begin OAuth authorization from the hint identity. 
416 pub async fn resume_or_start_auth<HintStr, Str>( 417 &self, 418 hint: &SessionHint<HintStr>, 419 options: AuthorizeOptions<Str>, 420 ) -> Result<OAuthResumeOrLogin<T, S>> 421 where 422 S: SessionSelector<OAuthSessionMatch, Error = SessionStoreError>, 423 HintStr: BosStr + Send + Sync, 424 Str: BosStr + FromStr + Ord + Clone + core::fmt::Debug, 425 <Str as FromStr>::Err: core::fmt::Debug, 426 { 427 let input = oauth_start_auth_input_from_hint(hint); 428 match OAuthSessionSelector::new(self.registry.store.as_ref(), self.client.as_ref()) 429 .select_session(hint) 430 .await? 431 { 432 Some(matched) => match self 433 .restore(&matched.key.did, matched.key.session_id.as_str()) 434 .await 435 { 436 Ok(session) => Ok(OAuthResumeOrLogin::Resumed(session)), 437 Err(err) if should_start_auth_after_restore_error(&err) => { 438 let Some(input) = input else { 439 return Err(err); 440 }; 441 Ok(OAuthResumeOrLogin::LoginUrl( 442 self.start_auth(input, options).await?, 443 )) 444 } 445 Err(err) => Err(err), 446 }, 447 None => { 448 let Some(input) = input else { 449 return Ok(OAuthResumeOrLogin::NeedsInput); 450 }; 451 Ok(OAuthResumeOrLogin::LoginUrl( 452 self.start_auth(input, options).await?, 453 )) 454 } 455 } 456 } 457 458 /// Revoke a session by deleting it from the backing store. 459 /// 460 /// Note: this removes the session from local storage but does **not** call the authorization 461 /// server's revocation endpoint. To also invalidate the token server-side, prefer 462 /// [`OAuthSession::logout`], which calls `revoke` on the token before deleting the session. 463 pub async fn revoke( 464 &self, 465 did: &Did<impl BosStr + Send + Sync>, 466 session_id: &str, 467 ) -> Result<()> { 468 Ok(self.registry.del(did, session_id).await?) 
469 } 470} 471 472fn oauth_start_auth_input_from_hint<S: BosStr>(hint: &SessionHint<S>) -> Option<SmolStr> { 473 match hint { 474 SessionHint::Did(did) => Some(did.as_ref().to_smolstr()), 475 SessionHint::Handle(handle) => Some(handle.as_ref().to_smolstr()), 476 SessionHint::Key(key) => Some(key.did.as_str().to_smolstr()), 477 SessionHint::Identifier(identifier) => Some(identifier.as_ref().to_smolstr()), 478 SessionHint::Any => None, 479 } 480} 481 482fn should_start_auth_after_restore_error(err: &OAuthError) -> bool { 483 matches!(err, OAuthError::Session(session_err) if session_err.is_permanent()) 484} 485 486/// Decode a percent-encoded audience string. 487/// 488/// The audience may contain percent-encoded characters like `%23` for `#`. 489/// This function decodes those and returns the decoded string. 490#[cfg(feature = "scope-check")] 491fn decode_audience(aud: &str) -> Result<String> { 492 // Use fluent_uri's percent-decoding to handle encoded characters. 493 // The audience is typically a DID, possibly with a fragment. 494 // EStr::new returns Option<&EStr>, so we match on that. 495 match EStr::<Query>::new(aud) { 496 Some(estr) => { 497 // estr.decode() returns a Decode struct 498 // The Decode type has a to_string() method that returns Result<Cow<str>, Vec<u8>> 499 let decoded = estr.decode(); 500 match decoded.to_string() { 501 Ok(cow) => Ok(cow.into_owned()), 502 Err(bytes) => Err(crate::error::CallbackError::ScopeResolution { 503 detail: format!( 504 "percent-decoded audience contains invalid UTF-8: {:?}", 505 bytes 506 ), 507 } 508 .into()), 509 } 510 } 511 None => { 512 // If it's not a valid percent-encoded string, use it as-is. 513 // This handles cases where no encoding was applied. 514 Ok(aud.to_string()) 515 } 516 } 517} 518 519/// Resolve all `include:` scopes in the given scope set into concrete permissions. 520/// 521/// Non-include scopes are passed through unchanged. 
Each `include:` scope is 522/// resolved via `resolve_permission_set`, which fetches the permission set 523/// lexicon and expands it into concrete `Scope<SmolStr>` values. 524#[cfg(feature = "scope-check")] 525async fn resolve_include_scopes<R>( 526 resolver: &R, 527 scopes: &Scopes<SmolStr>, 528) -> Result<Vec<Scope<SmolStr>>> 529where 530 R: OAuthResolver + LexiconSchemaResolver + Send + Sync, 531{ 532 let mut resolved = Vec::new(); 533 for scope in scopes.iter() { 534 match scope { 535 Scope::Include(IncludeScope { nsid, audience }) => { 536 let audience_did = if let Some(aud_str) = audience { 537 let decoded = decode_audience(aud_str)?; 538 match DidService::new_owned(&decoded) { 539 Ok(did) => Some(did), 540 Err(_) => { 541 return Err(crate::error::CallbackError::ScopeResolution { 542 detail: format!( 543 "invalid DID in include scope audience: {}", 544 decoded 545 ), 546 } 547 .into()); 548 } 549 } 550 } else { 551 None 552 }; 553 554 let nsid_smolstr = match Nsid::<SmolStr>::new_owned(nsid.as_str()) { 555 Ok(n) => n, 556 Err(_) => { 557 return Err(crate::error::CallbackError::ScopeResolution { 558 detail: format!("invalid NSID in include scope: {}", nsid), 559 } 560 .into()); 561 } 562 }; 563 564 let expanded = 565 resolve_permission_set(resolver, &nsid_smolstr, audience_did.as_ref()).await?; 566 resolved.extend(expanded); 567 } 568 other => { 569 resolved.push(other.convert()); 570 } 571 } 572 } 573 Ok(resolved) 574} 575 576impl<T, S> HttpClient for OAuthClient<T, S> 577where 578 S: ClientAuthStore + Send + Sync + 'static, 579 T: OAuthResolver + DpopExt + Send + Sync + 'static, 580{ 581 type Error = T::Error; 582 583 async fn send_http( 584 &self, 585 request: http::Request<Vec<u8>>, 586 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> { 587 self.client.send_http(request).await 588 } 589} 590 591impl<T, S> IdentityResolver for OAuthClient<T, S> 592where 593 S: ClientAuthStore + Send + Sync + 'static, 594 T: OAuthResolver + DpopExt + Send + 
Sync + 'static, 595{ 596 fn options(&self) -> &ResolverOptions { 597 self.client.options() 598 } 599 600 async fn resolve_handle<Str: BosStr + Sync>( 601 &self, 602 handle: &Handle<Str>, 603 ) -> jacquard_identity::resolver::Result<Did> { 604 self.client.resolve_handle(handle).await 605 } 606 607 async fn resolve_did_doc<Str: BosStr + Sync>( 608 &self, 609 did: &Did<Str>, 610 ) -> jacquard_identity::resolver::Result<DidDocResponse> { 611 self.client.resolve_did_doc(did).await 612 } 613} 614 615impl<T, S> XrpcClient for OAuthClient<T, S> 616where 617 S: ClientAuthStore + Send + Sync + 'static, 618 T: OAuthResolver + DpopExt + Send + Sync + 'static, 619{ 620 async fn base_uri(&self) -> Uri<String> { 621 self.endpoint.read().await.clone().unwrap_or_else(|| { 622 Uri::parse("https://public.api.bsky.app") 623 .expect("hardcoded URI is valid") 624 .to_owned() 625 }) 626 } 627 628 async fn opts(&self) -> CallOptions { 629 self.options.read().await.clone() 630 } 631 632 async fn set_opts(&self, opts: CallOptions) { 633 let mut guard = self.options.write().await; 634 *guard = opts.into_static(); 635 } 636 637 async fn set_base_uri(&self, uri: Uri<String>) { 638 let normalized = jacquard_common::xrpc::normalize_base_uri(uri); 639 let mut guard = self.endpoint.write().await; 640 *guard = Some(normalized); 641 } 642 643 async fn send<R>(&self, request: R) -> XrpcResult<XrpcResponse<R>> 644 where 645 R: XrpcRequest + Send + Sync + serde::Serialize, 646 <R as XrpcRequest>::Response: Send + Sync, 647 { 648 let opts = self.options.read().await.clone(); 649 self.send_with_opts(request, opts).await 650 } 651 652 async fn send_with_opts<R>(&self, request: R, opts: CallOptions) -> XrpcResult<XrpcResponse<R>> 653 where 654 R: XrpcRequest + Send + Sync + serde::Serialize, 655 <R as XrpcRequest>::Response: Send + Sync, 656 { 657 let base_uri = self.base_uri().await; 658 let http_request = build_http_request(&base_uri.borrow(), &request, &opts)?; 659 let http_response = self 660 .client 
661 .send_http(http_request) 662 .await 663 .map_err(|e| ClientError::transport(e).for_nsid(R::NSID))?; 664 process_response(http_response) 665 } 666} 667 668/// An active OAuth session for a specific account, used to make authenticated API requests. 669/// 670/// `OAuthSession` holds the DPoP-bound token set for one account and handles transparent 671/// token refresh on `401 invalid_token` responses. The optional `W` type parameter allows 672/// attaching a WebSocket client (defaults to `()` when WebSocket support is not needed). 673/// 674/// Obtain an `OAuthSession` from [`OAuthClient::callback`] or [`OAuthClient::restore`]. 675pub struct OAuthSession<T, S, W = ()> 676where 677 T: OAuthResolver, 678 S: ClientAuthStore, 679{ 680 /// Shared registry used to persist and retrieve session data across refresh operations. 681 pub registry: Arc<SessionRegistry<T, S, SmolStr>>, 682 /// Underlying HTTP/identity/OAuth resolver shared with the parent `OAuthClient`. 683 pub client: Arc<T>, 684 /// Optional WebSocket client; `()` when WebSocket support is not required. 685 pub ws_client: W, 686 /// Mutable session data including DPoP key, nonces, and token set. 687 pub data: RwLock<ClientSessionData>, 688 /// Default call options applied to every outgoing XRPC request from this session. 689 pub options: RwLock<CallOptions>, 690} 691 692impl<T, S> OAuthSession<T, S, ()> 693where 694 T: OAuthResolver, 695 S: ClientAuthStore, 696{ 697 /// Create a new session without a WebSocket client. 698 /// 699 /// This is the standard constructor used by [`OAuthClient::callback`] and 700 /// [`OAuthClient::restore`]. For WebSocket support use [`OAuthSession::new_with_ws`]. 
701 pub fn new( 702 registry: Arc<SessionRegistry<T, S, SmolStr>>, 703 client: Arc<T>, 704 data: ClientSessionData, 705 ) -> Self { 706 Self { 707 registry, 708 client, 709 ws_client: (), 710 data: RwLock::new(data), 711 options: RwLock::new(CallOptions::default()), 712 } 713 } 714} 715 716impl<T, S, W> OAuthSession<T, S, W> 717where 718 T: OAuthResolver, 719 S: ClientAuthStore, 720{ 721 /// Create a new session with an attached WebSocket client. 722 /// 723 /// Use this variant when the session needs to support WebSocket subscriptions in addition 724 /// to standard XRPC calls. The `ws_client` is exposed via [`OAuthSession::ws_client`] and 725 /// is used by the `WebSocketClient` impl when the `websocket` feature is enabled. 726 pub fn new_with_ws( 727 registry: Arc<SessionRegistry<T, S, SmolStr>>, 728 client: Arc<T>, 729 ws_client: W, 730 data: ClientSessionData, 731 ) -> Self { 732 Self { 733 registry, 734 client, 735 ws_client, 736 data: RwLock::new(data), 737 options: RwLock::new(CallOptions::default()), 738 } 739 } 740 741 /// Consume this session and return a new one with the given call options pre-applied. 742 /// 743 /// Useful for setting request-level defaults (e.g., `atproto-proxy` or custom headers) once 744 /// at construction time rather than passing them to every individual XRPC call. 745 pub fn with_options(self, options: CallOptions) -> Self { 746 Self { 747 registry: self.registry, 748 client: self.client, 749 ws_client: self.ws_client, 750 data: self.data, 751 options: RwLock::new(options.into_static()), 752 } 753 } 754 755 /// Get a reference to the WebSocket client. 756 pub fn ws_client(&self) -> &W { 757 &self.ws_client 758 } 759 760 /// Replace the default call options for this session without consuming it. 761 pub async fn set_options(&self, options: CallOptions) { 762 *self.options.write().await = options.into_static(); 763 } 764 765 /// Return the DID and session ID for this session. 
766 /// 767 /// The session ID is the random `state` token generated during the PAR flow and can 768 /// be used together with the DID to restore the session via [`OAuthClient::restore`]. 769 pub async fn session_info(&self) -> (Did, smol_str::SmolStr) { 770 let data = self.data.read().await; 771 (data.account_did.clone(), data.session_id.clone()) 772 } 773 774 /// Return the resource server (PDS) base URI for this session. 775 pub async fn endpoint(&self) -> Uri<String> { 776 self.data.read().await.host_url.clone() 777 } 778 779 /// Return the current DPoP-bound access token for this session. 780 /// 781 /// The token may be stale if it has expired; use [`OAuthSession::refresh`] or 782 /// rely on the automatic refresh performed by `send_with_opts` to obtain a fresh one. 783 pub async fn access_token(&self) -> AuthorizationToken<SmolStr> { 784 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone()) 785 } 786 787 /// Return the current refresh token for this session, if one is present. 788 /// 789 /// Not all authorization servers issue refresh tokens. When `None` is returned, 790 /// the session cannot be silently renewed and the user must re-authenticate. 791 pub async fn refresh_token(&self) -> Option<AuthorizationToken<SmolStr>> { 792 self.data 793 .read() 794 .await 795 .token_set 796 .refresh_token 797 .clone() 798 .map(|t| AuthorizationToken::Dpop(t)) 799 } 800 801 /// Derive an unauthenticated [`OAuthClient`] that shares the same registry and resolver. 802 /// 803 /// Useful when you need to initiate a new authorization flow from within an existing 804 /// session context (e.g., to add a second account) without constructing a fresh client. 
805 pub fn to_client(&self) -> OAuthClient<T, S> { 806 OAuthClient::from_session(self) 807 } 808} 809impl<T, S, W> OAuthSession<T, S, W> 810where 811 S: ClientAuthStore + Send + Sync + 'static, 812 T: OAuthResolver + DpopExt + Send + Sync + 'static, 813{ 814 /// Revoke the access token at the authorization server and delete the session from the store. 815 /// 816 /// Revocation is best-effort: if the server does not advertise a revocation endpoint, or if 817 /// the revocation call fails, the session is still deleted locally. This prevents a dangling 818 /// session record from blocking future logins for the same account. 819 pub async fn logout(&self) -> Result<()> { 820 use crate::request::{OAuthMetadata, revoke}; 821 let mut data = self.data.write().await; 822 let meta = 823 OAuthMetadata::new(self.client.as_ref(), &self.registry.client_data, &data).await?; 824 if meta.server_metadata.revocation_endpoint.is_some() { 825 let token = data.token_set.access_token.clone(); 826 revoke(self.client.as_ref(), &mut data.dpop_data, &token, &meta) 827 .await 828 .ok(); 829 } 830 // Remove from store 831 self.registry 832 .del(&data.account_did, &data.session_id) 833 .await?; 834 Ok(()) 835 } 836} 837 838impl<T, S> OAuthClient<T, S> 839where 840 T: OAuthResolver, 841 S: ClientAuthStore, 842{ 843 /// Construct an `OAuthClient` that shares the registry and resolver of an existing session. 844 /// 845 /// Equivalent to [`OAuthSession::to_client`]; provided on `OAuthClient` for symmetry so 846 /// callers can obtain an unauthenticated client without holding a session reference. 
pub fn from_session<W>(session: &OAuthSession<T, S, W>) -> Self {
    Self {
        registry: session.registry.clone(),
        client: session.client.clone(),
        options: RwLock::new(CallOptions::default()),
        endpoint: RwLock::new(None),
    }
}
}

impl<T, S, W> OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + DpopExt + Send + Sync + 'static,
{
    /// Explicitly refresh the access token using the stored refresh token.
    ///
    /// On success the new token set is written back into both the in-memory session data and
    /// the backing store. The returned `AuthorizationToken` is the new access token, which
    /// callers can immediately use to retry a failed request.
    ///
    /// The actual token exchange is serialized per `(DID, session_id)` pair via a `Mutex` inside
    /// the registry, so concurrent refresh attempts will not result in duplicate token exchanges.
    ///
    /// # Errors
    ///
    /// Returns an error if the registry lookup/refresh or the final `set` into the registry fails.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all))]
    pub async fn refresh(&self) -> Result<AuthorizationToken<SmolStr>> {
        // Read identifiers without holding the lock across await
        let (did, sid) = {
            let data = self.data.read().await;
            (data.account_did.clone(), data.session_id.clone())
        };
        // NOTE(review): the `true` flag presumably asks the registry to refresh the token
        // set; confirm against `SessionRegistry::get`.
        let refreshed = self.registry.as_ref().get(&did, &sid, true).await?;
        let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone());
        // Write back updated session
        *self.data.write().await = refreshed.clone().into_static();
        // Store in the registry
        self.registry.set(refreshed).await?;
        Ok(token)
    }
}

// Raw HTTP is forwarded unchanged to the wrapped client: no Authorization header and no
// DPoP proof are attached here (that happens in the XRPC paths below).
impl<T, S, W> HttpClient for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + DpopExt + Send + Sync + 'static,
    W: Send + Sync,
{
    type Error = T::Error;

    async fn send_http(
        &self,
        request: http::Request<Vec<u8>>,
    ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> {
        self.client.send_http(request).await
    }
}

impl<T, S, W> XrpcClient for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static,
    W: Send + Sync,
{
    async fn base_uri(&self) -> Uri<String> {
        self.data.read().await.host_url.clone()
    }

    async fn opts(&self) -> CallOptions {
        self.options.read().await.clone()
    }

    async fn set_opts(&self, opts: CallOptions) {
        let mut guard = self.options.write().await;
        *guard = opts.into_static();
    }

    async fn set_base_uri(&self, uri: Uri<String>) {
        let normalized = jacquard_common::xrpc::normalize_base_uri(uri);
        let mut guard = self.data.write().await;
        guard.host_url = normalized;
    }

    async fn send<R>(&self, request: R) -> XrpcResult<XrpcResponse<R>>
    where
        R: XrpcRequest + Send + Sync + serde::Serialize,
        <R as XrpcRequest>::Response: Send + Sync,
    {
        let opts = self.options.read().await.clone();
        self.send_with_opts(request, opts).await
    }

    /// Send an XRPC request authenticated with the session's DPoP-bound access token.
    ///
    /// Flow:
    /// 1. (feature `scope-check`) reject the call up front if the granted scopes do not
    ///    cover `R::NSID`;
    /// 2. send the request with the current access token through the DPoP-wrapped client,
    ///    then persist any updated DPoP host nonce into the session;
    /// 3. if the response looks like an invalid-token rejection, either reuse a token that
    ///    a concurrent request already refreshed, or call [`OAuthSession::refresh`], and
    ///    retry exactly once (again persisting the DPoP nonce).
    async fn send_with_opts<R>(
        &self,
        request: R,
        mut opts: CallOptions,
    ) -> XrpcResult<XrpcResponse<R>>
    where
        R: XrpcRequest + Send + Sync + serde::Serialize,
        <R as XrpcRequest>::Response: Send + Sync,
    {
        // Pre-flight scope check: pure in-memory, no HTTP.
        #[cfg(feature = "scope-check")]
        {
            self.check_scope::<R>().await.map_err(|e| {
                ClientError::invalid_request(format!("scope check failed: {:?}", e))
                    .for_nsid(R::NSID)
            })?;
        }

        let base_uri = self.base_uri().await;
        let original_token = self.access_token().await;
        opts.auth = Some(original_token.clone());
        // Clone dpop_data and release read lock before the await point
        let mut dpop = self.data.read().await.dpop_data.clone();
        let http_response = self
            .client
            .dpop_call(&mut dpop)
            .send(build_http_request(&base_uri.borrow(), &request, &opts)?)
            .await
            .map_err(|e| ClientError::from(e).for_nsid(R::NSID))?;
        let resp = process_response(http_response);

        // Write back updated nonce to session data (dpop_call may have updated it)
        {
            let mut guard = self.data.write().await;
            guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
        }

        if is_invalid_token_response(&resp) {
            // Optimistic refresh: check if another request already refreshed the token
            let current_token = self.access_token().await;
            if current_token != original_token {
                // Token was already refreshed by another concurrent request, use it
                opts.auth = Some(current_token);
            } else {
                // We need to refresh - this will be serialized by the registry's Mutex
                opts.auth = Some(
                    self.refresh()
                        .await
                        .map_err(|e| ClientError::transport(e))?,
                );
            }
            // Re-read dpop_data after refresh (refresh may have updated it)
            let mut dpop = self.data.read().await.dpop_data.clone();
            let http_response = self
                .client
                .dpop_call(&mut dpop)
                .send(build_http_request(&base_uri.borrow(), &request, &opts)?)
                .await
                .map_err(|e| {
                    ClientError::from(e)
                        .for_nsid(R::NSID)
                        .append_context("after token refresh")
                })?;
            let resp = process_response(http_response);

            // Write back updated nonce after retry
            {
                let mut guard = self.data.write().await;
                guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
            }

            resp
        } else {
            resp
        }
    }
}

#[cfg(feature = "scope-check")]
impl<T, S, W> OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + Send + Sync + 'static,
    W: Send + Sync,
{
    /// Check whether the session's resolved scopes grant access to
    /// the XRPC method identified by `R::NSID`.
    ///
    /// A call is permitted when any granted `rpc:` scope matches the method (audience is
    /// ignored) or any granted `repo:` scope matches the collection (any action). When the
    /// session has no resolved scopes at all, the call is permitted.
    ///
    /// # Errors
    ///
    /// Returns a [`ScopeError`] carrying the NSID and a comma-separated summary of the
    /// granted scopes when no scope covers the request.
    ///
    /// # Panics
    ///
    /// Panics if `R::NSID` is not a syntactically valid NSID.
    async fn check_scope<R: XrpcRequest>(&self) -> core::result::Result<(), ScopeError> {
        let data = self.data.read().await;

        // Use the resolved scopes from Phase 5's eager resolution.
        // These are fully expanded — no include scopes remain.
        let resolved = data.resolved_scopes.as_ref();

        let is_permitted = match resolved {
            Some(scopes) => {
                let nsid = Nsid::<SmolStr>::new_static(R::NSID).expect("valid NSID");

                // Check if any granted scope covers this NSID. A request
                // may be covered by rpc: scopes (method access) or repo:
                // scopes (record operations).
                //
                // Note: `atproto` is the minimum base scope (auth only).
                // It does NOT grant rpc/repo access.
                //
                // For rpc: scopes, we check only the lxm (method) match
                // and ignore audience. At pre-flight time the client does
                // not know the target audience — audience enforcement is
                // the server's responsibility. A granted scope with a
                // specific aud (e.g., did:web:api.bsky.app) still permits
                // calling the method from the client's perspective.
                let rpc_ok = scopes.iter().any(|s| match s {
                    Scope::Rpc(rpc) => rpc.lxm.iter().any(|l| match l {
                        RpcLexicon::All => true,
                        RpcLexicon::Nsid(granted_nsid) => granted_nsid.as_ref() == nsid.as_ref(),
                    }),
                    _ => false,
                });

                // For repo: scopes, check if the NSID matches a granted
                // collection. Any action suffices for pre-flight.
                let repo_ok = scopes.iter().any(|s| match s {
                    Scope::Repo(repo) => match &repo.collection {
                        RepoCollection::All => true,
                        RepoCollection::Nsid(col) => col.as_ref() == nsid.as_ref(),
                    },
                    _ => false,
                });

                rpc_ok || repo_ok
            }
            None => {
                // No resolved scopes means resolution was skipped
                // (e.g., no include scopes were present, or scope-check
                // was enabled after session creation). Allow the request.
                true
            }
        };

        if !is_permitted {
            let granted_summary = resolved
                .map(|scopes| {
                    scopes
                        .iter()
                        .map(|s| s.to_string_normalized())
                        .collect::<Vec<_>>()
                        .join(", ")
                })
                .unwrap_or_default();

            return Err(ScopeError {
                nsid: SmolStr::new_static(R::NSID),
                granted: SmolStr::from(granted_summary),
            });
        }

        Ok(())
    }
}

// Streaming HTTP is forwarded unchanged to the wrapped client (no auth/DPoP added here).
#[cfg(feature = "streaming")]
impl<T, S, W> jacquard_common::http_client::HttpClientExt for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver
        + DpopExt
        + XrpcExt
        + jacquard_common::http_client::HttpClientExt
        + Send
        + Sync
        + 'static,
    W: Send + Sync,
{
    async fn send_http_streaming(
        &self,
        request: http::Request<Vec<u8>>,
    ) -> core::result::Result<http::Response<jacquard_common::stream::ByteStream>, Self::Error>
    {
        self.client.send_http_streaming(request).await
    }

    #[cfg(not(target_arch = "wasm32"))]
    async fn send_http_bidirectional<Str>(
        &self,
        parts: http::request::Parts,
        body: Str,
    ) -> core::result::Result<http::Response<jacquard_common::stream::ByteStream>, Self::Error>
    where
        Str: n0_future::Stream<
                Item = core::result::Result<bytes::Bytes, jacquard_common::StreamError>,
            > + Send
            + 'static,
    {
        self.client.send_http_bidirectional(parts, body).await
    }

    #[cfg(target_arch = "wasm32")]
    async fn send_http_bidirectional<Str>(
        &self,
        parts: http::request::Parts,
        body: Str,
    ) -> core::result::Result<http::Response<jacquard_common::stream::ByteStream>, Self::Error>
    where
        Str: n0_future::Stream<
                Item = core::result::Result<bytes::Bytes, jacquard_common::StreamError>,
            > + 'static,
    {
        self.client.send_http_bidirectional(parts, body).await
    }
}

#[cfg(feature = "streaming")]
impl<T, S, W> jacquard_common::xrpc::XrpcStreamingClient for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver
        + DpopExt
        + XrpcExt
        + jacquard_common::http_client::HttpClientExt
        + Send
        + Sync
        + 'static,
    W: Send + Sync,
{
    /// Send an XRPC request and return the response body as a stream.
    ///
    /// The session's DPoP state is cloned and the session lock released before any network
    /// await; any DPoP host nonce updated by the call is written back to the session, matching
    /// [`XrpcClient::send_with_opts`]. If the first attempt fails, the access token is
    /// refreshed and the request is retried once.
    async fn download<R>(
        &self,
        request: R,
    ) -> core::result::Result<jacquard_common::xrpc::StreamingResponse, jacquard_common::StreamError>
    where
        R: XrpcRequest + Send + Sync + serde::Serialize,
        <R as XrpcRequest>::Response: Send + Sync,
    {
        use jacquard_common::StreamError;

        let base_uri = <Self as XrpcClient>::base_uri(self).await;
        let mut opts = self.options.read().await.clone();
        opts.auth = Some(self.access_token().await);
        let http_request = build_http_request(&base_uri.borrow(), &request, &opts)
            .map_err(|e| StreamError::protocol(e.to_string()))?;
        // Clone dpop_data and release the read lock before the await point, so a concurrent
        // refresh or nonce write-back is not blocked for the whole round trip.
        let mut dpop = self.data.read().await.dpop_data.clone();
        let result = self
            .client
            .dpop_call(&mut dpop)
            .send_streaming(http_request)
            .await;

        // Write back updated nonce to session data (dpop_call may have updated it)
        {
            let mut guard = self.data.write().await;
            guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
        }

        match result {
            Ok(response) => Ok(response),
            Err(_e) => {
                // NOTE(review): the error is not inspected, so *any* failure (including
                // plain transport errors) triggers a token refresh and one retry.
                opts.auth = Some(
                    self.refresh()
                        .await
                        .map_err(|e| StreamError::transport(e))?,
                );
                let http_request = build_http_request(&base_uri.borrow(), &request, &opts)
                    .map_err(|e| StreamError::protocol(e.to_string()))?;
                // Re-read dpop_data after refresh (refresh may have replaced it).
                let mut dpop = self.data.read().await.dpop_data.clone();
                let retried = self
                    .client
                    .dpop_call(&mut dpop)
                    .send_streaming(http_request)
                    .await;

                // Write back updated nonce after retry
                {
                    let mut guard = self.data.write().await;
                    guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
                }

                retried.map_err(StreamError::transport)
            }
        }
    }

    /// Open a bidirectional streaming XRPC procedure call.
    ///
    /// Builds a `POST {base}/xrpc/{NSID}` request carrying the session's Authorization header,
    /// the `atproto-proxy` / `atproto-accept-labelers` headers and any extra headers from the
    /// current call options. The outgoing frames are streamed as the body through the
    /// DPoP-wrapped client, and any updated DPoP host nonce is written back to the session.
    async fn stream<Str, B>(
        &self,
        stream: jacquard_common::xrpc::streaming::XrpcProcedureSend<Str::Frame<B>>,
    ) -> core::result::Result<
        jacquard_common::xrpc::streaming::XrpcResponseStream<
            <<Str as jacquard_common::xrpc::streaming::XrpcProcedureStream>::Response as jacquard_common::xrpc::streaming::XrpcStreamResp>::Frame<B>,
        >,
        jacquard_common::StreamError,
    >
    where
        Str: jacquard_common::xrpc::streaming::XrpcProcedureStream + 'static,
        <<Str as jacquard_common::xrpc::streaming::XrpcProcedureStream>::Response as jacquard_common::xrpc::streaming::XrpcStreamResp>::Frame<B>: jacquard_common::xrpc::streaming::XrpcStreamResp,
        B: BosStr + 'static,
    {
        use jacquard_common::StreamError;
        use n0_future::TryStreamExt;

        let base_uri = <Self as XrpcClient>::base_uri(self).await;
        let mut opts = self.options.read().await.clone();
        opts.auth = Some(self.access_token().await);

        let mut path = String::from(base_uri.as_str().trim_end_matches('/'));
        path.push_str("/xrpc/");
        path.push_str(<Str::Request as jacquard_common::xrpc::XrpcRequest>::NSID);

        let mut builder = http::Request::post(path);

        if let Some(token) = &opts.auth {
            let hv = match token {
                AuthorizationToken::Bearer(t) => {
                    http::HeaderValue::from_str(&format!("Bearer {}", t.as_str()))
                }
                AuthorizationToken::Dpop(t) => {
                    http::HeaderValue::from_str(&format!("DPoP {}", t.as_str()))
                }
            }
            .map_err(|e| StreamError::protocol(format!("Invalid authorization token: {}", e)))?;
            builder = builder.header(http::header::AUTHORIZATION, hv);
        }

        if let Some(proxy) = &opts.atproto_proxy {
            builder = builder.header("atproto-proxy", proxy.as_str());
        }
        if let Some(labelers) = &opts.atproto_accept_labelers {
            if !labelers.is_empty() {
                let joined = labelers
                    .iter()
                    .map(|s| s.as_ref())
                    .collect::<Vec<_>>()
                    .join(", ");
                builder = builder.header("atproto-accept-labelers", joined);
            }
        }
        for (name, value) in &opts.extra_headers {
            builder = builder.header(name, value);
        }

        let (parts, _) = builder
            .body(())
            .map_err(|e| StreamError::protocol(e.to_string()))?
            .into_parts();

        let body_stream =
            jacquard_common::stream::ByteStream::new(Box::pin(stream.0.map_ok(|f| f.buffer)));

        // Clone dpop_data and release the read lock before the await point.
        let mut dpop = self.data.read().await.dpop_data.clone();
        let result = self
            .client
            .dpop_call(&mut dpop)
            .send_bidirectional(parts, body_stream)
            .await;

        // Write back updated nonce to session data (dpop_call may have updated it)
        {
            let mut guard = self.data.write().await;
            guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
        }

        match result {
            Ok(response) => {
                let (resp_parts, resp_body) = response.into_parts();
                Ok(
                    jacquard_common::xrpc::streaming::XrpcResponseStream::from_typed_parts::<B>(
                        resp_parts, resp_body,
                    ),
                )
            }
            Err(e) => {
                // No token-refresh retry here: the request body stream was moved into the
                // first attempt and cannot be replayed.
                Err(StreamError::transport(e))
            }
        }
    }
}

/// Decide whether an XRPC result indicates that the access token was rejected.
///
/// Returns `true` when:
/// - the error kind is `AuthError::InvalidToken`;
/// - the error kind is `AuthError::Other` whose value (presumably the raw
///   `WWW-Authenticate` header — confirm) starts with `DPoP ` and contains
///   `error="invalid_token"`;
/// - the response succeeded at the transport level but has status 401; or
/// - the response body is JSON whose `"error"` field is `"invalid_token"`.
fn is_invalid_token_response<R: XrpcResp>(response: &XrpcResult<Response<R>>) -> bool {
    use jacquard_common::error::ClientErrorKind;

    match response {
        Err(e) => match e.kind() {
            ClientErrorKind::Auth(AuthError::InvalidToken) => true,
            ClientErrorKind::Auth(AuthError::Other(value)) => value
                .to_str()
                .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")),
            _ => false,
        },
        // Some servers signal an expired/invalid token with a bare 401 or with an error in
        // the JSON body rather than via WWW-Authenticate; treat either as invalid_token.
        Ok(resp) => {
            resp.status() == http::StatusCode::UNAUTHORIZED
                || serde_json::from_slice::<serde_json::Value>(resp.buffer())
                    .ok()
                    .and_then(|v| v.get("error")?.as_str().map(|s| s == "invalid_token"))
                    .unwrap_or(false)
        }
    }
}

// Identity resolution is delegated unchanged to the wrapped client.
impl<T, S, W> IdentityResolver for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + IdentityResolver + XrpcExt + Send + Sync + 'static,
    W: Send + Sync,
{
    fn options(&self) -> &ResolverOptions {
        self.client.options()
    }

    async fn resolve_handle<Str: BosStr + Sync>(
        &self,
        handle: &Handle<Str>,
    ) -> std::result::Result<Did, IdentityError> {
        self.client.resolve_handle(handle).await
    }

    async fn resolve_did_doc<Str: BosStr + Sync>(
        &self,
        did: &Did<Str>,
    ) -> std::result::Result<DidDocResponse, IdentityError> {
        self.client.resolve_did_doc(did).await
    }
}

// WebSocket connections are delegated unchanged to the wrapped `ws_client`.
#[cfg(feature = "websocket")]
impl<T, S, W> WebSocketClient for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + Send + Sync + 'static,
    W: WebSocketClient + Send + Sync,
{
    type Error = W::Error;

    async fn connect(
        &self,
        uri: Uri<&str>,
    ) -> std::result::Result<WebSocketConnection, Self::Error> {
        self.ws_client.connect(uri).await
    }

    async fn connect_with_headers(
        &self,
        uri: Uri<&str>,
        headers: Vec<(CowStr<'_>, CowStr<'_>)>,
    ) -> std::result::Result<WebSocketConnection, Self::Error> {
        self.ws_client.connect_with_headers(uri, headers).await
    }
}

#[cfg(feature = "websocket")]
impl<T, S, W> jacquard_common::xrpc::SubscriptionClient for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + Send + Sync + 'static,
    W: WebSocketClient + Send + Sync,
{
    async fn base_uri(&self) -> Uri<String> {
        self.data.read().await.host_url.clone()
    }

    /// Default subscription options: an `Authorization` header carrying the session's
    /// current access token (`Bearer …` or `DPoP …`).
    async fn subscription_opts(&self) -> jacquard_common::xrpc::SubscriptionOptions<'_> {
        let mut opts = jacquard_common::xrpc::SubscriptionOptions::default();
        let token = self.access_token().await;
        let auth_value = match token {
            AuthorizationToken::Bearer(t) => format!("Bearer {}", t.as_str()),
            AuthorizationToken::Dpop(t) => format!("DPoP {}", t.as_str()),
        };
        opts.headers
            .push((CowStr::from("Authorization"), CowStr::from(auth_value)));
        opts
    }

    async fn subscribe<Sub>(
        &self,
        params: &Sub,
    ) -> std::result::Result<jacquard_common::xrpc::SubscriptionStream<Sub::Stream>, Self::Error>
    where
        Sub: XrpcSubscription + Send + Sync + serde::Serialize,
    {
        let opts = self.subscription_opts().await;
        self.subscribe_with_opts(params, opts).await
    }

    async fn subscribe_with_opts<Sub>(
        &self,
        params: &Sub,
        opts: jacquard_common::xrpc::SubscriptionOptions<'_>,
    ) -> std::result::Result<jacquard_common::xrpc::SubscriptionStream<Sub::Stream>, Self::Error>
    where
        Sub: XrpcSubscription + Send + Sync + serde::Serialize,
    {
        use jacquard_common::xrpc::SubscriptionExt;
        let base = self.base_uri().await;
        self.subscription(base)
            .with_options(opts)
            .subscribe(params)
            .await
    }
}

#[cfg(all(test, feature = "scope-check"))]
mod tests {
    use super::*;
    use crate::scopes::{RepoAction, RepoScope, RpcAudience, RpcLexicon, RpcScope};
    use std::collections::BTreeSet;

    /// Test that a scope granting access to an RPC method works correctly.
    #[test]
    fn test_scope_check_permits_matching_rpc() {
        // AC7.1: Session with rpc:com.example.test grants access to com.example.test.
        let mut rpc_scope_set = BTreeSet::new();
        rpc_scope_set.insert(RpcLexicon::Nsid(
            Nsid::<SmolStr>::new_static("com.example.test").unwrap(),
        ));
        let mut aud_set = BTreeSet::new();
        aud_set.insert(RpcAudience::All);

        let granted_scope = Scope::Rpc(RpcScope {
            lxm: rpc_scope_set,
            aud: aud_set,
        });

        // Target scope for a request to com.example.test.
        let mut target_lxm = BTreeSet::new();
        target_lxm.insert(RpcLexicon::Nsid(
            Nsid::<SmolStr>::new_static("com.example.test").unwrap(),
        ));
        let mut target_aud = BTreeSet::new();
        target_aud.insert(RpcAudience::All);

        let target_scope = Scope::Rpc(RpcScope {
            lxm: target_lxm,
            aud: target_aud,
        });

        // The granted scope should permit the target scope.
        assert!(
            granted_scope.grants(&target_scope),
            "rpc:com.example.test should grant access to com.example.test"
        );
    }

    /// Test that rpc:* wildcard grants access to all RPC methods.
    #[test]
    fn test_scope_check_permits_rpc_wildcard() {
        // AC7.1: Session with rpc:* (wildcard) grants access to any RPC method.
        let mut rpc_scope_set: BTreeSet<RpcLexicon<SmolStr>> = BTreeSet::new();
        rpc_scope_set.insert(RpcLexicon::All);
        let mut aud_set: BTreeSet<RpcAudience<SmolStr>> = BTreeSet::new();
        aud_set.insert(RpcAudience::All);

        let wildcard_scope = Scope::Rpc(RpcScope {
            lxm: rpc_scope_set,
            aud: aud_set,
        });

        // Target scope for any request.
        let mut target_lxm = BTreeSet::new();
        target_lxm.insert(RpcLexicon::Nsid(
            Nsid::<SmolStr>::new_static("com.example.test").unwrap(),
        ));
        let mut target_aud = BTreeSet::new();
        target_aud.insert(RpcAudience::All);

        let target_scope = Scope::Rpc(RpcScope {
            lxm: target_lxm,
            aud: target_aud,
        });

        // Wildcard should grant any target scope.
        assert!(
            wildcard_scope.grants(&target_scope),
            "rpc:* should grant access to any RPC method"
        );
    }

    /// Test that an unmatched scope denies access.
    #[test]
    fn test_scope_check_denies_ungranted() {
        // AC7.4: Session with rpc:com.example.other denies access to com.example.test.
        let mut rpc_scope_set = BTreeSet::new();
        rpc_scope_set.insert(RpcLexicon::Nsid(
            Nsid::<SmolStr>::new_static("com.example.other").unwrap(),
        ));
        let mut aud_set = BTreeSet::new();
        aud_set.insert(RpcAudience::All);

        let granted_scope = Scope::Rpc(RpcScope {
            lxm: rpc_scope_set,
            aud: aud_set,
        });

        // Target scope for a request to com.example.test.
        let mut target_lxm = BTreeSet::new();
        target_lxm.insert(RpcLexicon::Nsid(
            Nsid::<SmolStr>::new_static("com.example.test").unwrap(),
        ));
        let mut target_aud = BTreeSet::new();
        target_aud.insert(RpcAudience::All);

        let target_scope = Scope::Rpc(RpcScope {
            lxm: target_lxm,
            aud: target_aud,
        });

        // rpc:com.example.other should NOT grant access to com.example.test.
        assert!(
            !granted_scope.grants(&target_scope),
            "rpc:com.example.other should NOT grant access to com.example.test"
        );
    }

    /// Test that a repo scope grants access to the specified collection.
    #[test]
    fn test_scope_check_permits_repo_scope() {
        // AC7.1: Session with repo:com.example.test grants access to that collection.
        let mut actions = BTreeSet::new();
        actions.insert(RepoAction::Create);
        actions.insert(RepoAction::Update);
        actions.insert(RepoAction::Delete);

        let granted_repo = Scope::Repo(RepoScope {
            collection: RepoCollection::Nsid(
                Nsid::<SmolStr>::new_static("com.example.test").unwrap(),
            ),
            actions,
        });

        // Target scope for a request to com.example.test.
        let mut target_actions = BTreeSet::new();
        target_actions.insert(RepoAction::Create);
        target_actions.insert(RepoAction::Update);
        target_actions.insert(RepoAction::Delete);

        let target_repo = Scope::Repo(RepoScope {
            collection: RepoCollection::Nsid(
                Nsid::<SmolStr>::new_static("com.example.test").unwrap(),
            ),
            actions: target_actions,
        });

        // The repo scope should grant the target scope.
        assert!(
            granted_repo.grants(&target_repo),
            "repo:com.example.test should grant repo access to com.example.test"
        );
    }

    /// Test that ScopeError provides diagnostic information.
    #[test]
    fn test_scope_error_diagnostic_info() {
        // AC7.4: ScopeError includes request NSID and granted scope summary.
        let err = ScopeError {
            nsid: SmolStr::from("com.example.test"),
            granted: SmolStr::from("rpc:com.example.other"),
        };

        assert_eq!(err.nsid, "com.example.test");
        assert_eq!(err.granted, "rpc:com.example.other");
        let error_msg = err.to_string();
        assert!(
            error_msg.contains("not permitted"),
            "error message should indicate request is not permitted"
        );
    }

    /// Test that multiple granted scopes are checked correctly.
    #[test]
    fn test_scope_check_multiple_scopes() {
        // AC7.1: With multiple scopes, request matching one of them is permitted.
        let mut other_lxm = BTreeSet::new();
        other_lxm.insert(RpcLexicon::Nsid(
            Nsid::<SmolStr>::new_static("com.example.other").unwrap(),
        ));
        let mut other_aud = BTreeSet::new();
        other_aud.insert(RpcAudience::All);

        let other_scope = Scope::Rpc(RpcScope {
            lxm: other_lxm,
            aud: other_aud,
        });

        let mut test_lxm = BTreeSet::new();
        test_lxm.insert(RpcLexicon::Nsid(
            Nsid::<SmolStr>::new_static("com.example.test").unwrap(),
        ));
        let mut test_aud = BTreeSet::new();
        test_aud.insert(RpcAudience::All);

        let test_scope = Scope::Rpc(RpcScope {
            lxm: test_lxm,
            aud: test_aud,
        });

        // Target scope for a request to com.example.test.
        let mut target_lxm = BTreeSet::new();
        target_lxm.insert(RpcLexicon::Nsid(
            Nsid::<SmolStr>::new_static("com.example.test").unwrap(),
        ));
        let mut target_aud = BTreeSet::new();
        target_aud.insert(RpcAudience::All);

        let target_scope = Scope::Rpc(RpcScope {
            lxm: target_lxm,
            aud: target_aud,
        });

        // With multiple scopes, if one matches, the check passes.
        let granted_scopes = vec![other_scope, test_scope];
        let is_permitted = granted_scopes.iter().any(|s| s.grants(&target_scope));
        assert!(
            is_permitted,
            "at least one granted scope should permit the target request"
        );
    }

    /// Test that both RPC and repo scopes are checked when determining permissions.
    #[test]
    fn test_scope_check_rpc_and_repo_paths() {
        // AC7.1: A request can be granted via either rpc: or repo: scopes.

        // Create a repo scope for the collection.
        let mut repo_actions = BTreeSet::new();
        repo_actions.insert(RepoAction::Create);
        repo_actions.insert(RepoAction::Update);
        repo_actions.insert(RepoAction::Delete);

        let repo_scope = Scope::Repo(RepoScope {
            collection: RepoCollection::Nsid(
                Nsid::<SmolStr>::new_static("com.example.test").unwrap(),
            ),
            actions: repo_actions,
        });

        // Target scope for a request to com.example.test (as repo operations).
        let mut target_actions = BTreeSet::new();
        target_actions.insert(RepoAction::Create);
        target_actions.insert(RepoAction::Update);
        target_actions.insert(RepoAction::Delete);

        let target_repo = Scope::Repo(RepoScope {
            collection: RepoCollection::Nsid(
                Nsid::<SmolStr>::new_static("com.example.test").unwrap(),
            ),
            actions: target_actions,
        });

        // The repo scope should satisfy the request.
        assert!(
            repo_scope.grants(&target_repo),
            "repo scope should grant repo-based requests"
        );
    }
}