A better Rust ATProto crate
1use crate::{
2 atproto::atproto_client_metadata,
3 authstore::{ClientAuthStore, OAuthSessionMatch, OAuthSessionSelector},
4 dpop::DpopExt,
5 error::{CallbackError, OAuthError, Result},
6 request::{OAuthMetadata, exchange_code, par},
7 resolver::OAuthResolver,
8 scopes::Scopes,
9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry},
10 types::{AuthorizeOptions, CallbackParams},
11};
12#[cfg(feature = "scope-check")]
13use crate::{
14 error::ScopeError,
15 resolver::resolve_permission_set,
16 scopes::{IncludeScope, RepoCollection, RpcLexicon, Scope},
17};
18#[cfg(feature = "websocket")]
19use jacquard_common::CowStr;
20#[cfg(feature = "scope-check")]
21use jacquard_common::types::{nsid::Nsid, string::DidService};
22use jacquard_common::{
23 AuthorizationToken, IntoStatic,
24 bos::BosStr,
25 deps::fluent_uri::Uri,
26 error::{AuthError, ClientError, XrpcResult},
27 http_client::HttpClient,
28 session::{SessionHint, SessionSelector, SessionStoreError},
29 types::{did::Did, string::Handle},
30 xrpc::{
31 CallOptions, Response, XrpcClient, XrpcExt, XrpcRequest, XrpcResp, XrpcResponse,
32 build_http_request, process_response,
33 },
34};
35#[cfg(feature = "scope-check")]
36use jacquard_identity::lexicon_resolver::LexiconSchemaResolver;
37
38#[cfg(feature = "scope-check")]
39use jacquard_common::deps::fluent_uri::pct_enc::{EStr, encoder::Query};
40
41#[cfg(feature = "websocket")]
42use jacquard_common::websocket::{WebSocketClient, WebSocketConnection};
43#[cfg(feature = "websocket")]
44use jacquard_common::xrpc::XrpcSubscription;
45use jacquard_identity::{
46 JacquardResolver,
47 resolver::{DidDocResponse, IdentityError, IdentityResolver, ResolverOptions},
48};
49use jose_jwk::JwkSet;
50use smol_str::{SmolStr, ToSmolStr};
51use std::{str::FromStr, sync::Arc};
52use tokio::sync::RwLock;
53
54/// Result of resuming an OAuth session or starting a new authorization flow.
55pub enum OAuthResumeOrLogin<T, S>
56where
57 T: OAuthResolver,
58 S: ClientAuthStore,
59{
60 /// A stored session was found and restored/refreshed.
61 Resumed(OAuthSession<T, S>),
62 /// No stored session matched; redirect the user to this login URL.
63 LoginUrl(String),
64 /// No stored session matched, and the hint did not contain enough information to start OAuth.
65 NeedsInput,
66}
67
68/// The top-level OAuth client responsible for driving the authorization flow.
69pub struct OAuthClient<T, S>
70where
71 T: OAuthResolver,
72 S: ClientAuthStore,
73{
74 /// Shared session registry that mediates access to the backing auth store.
75 pub registry: Arc<SessionRegistry<T, S, SmolStr>>,
76 /// Default call options applied to every outgoing XRPC request.
77 pub options: RwLock<CallOptions>,
78 /// Override for the XRPC base URI; falls back to the public Bluesky AppView when `None`.
79 pub endpoint: RwLock<Option<Uri<String>>>,
80 /// Underlying HTTP/identity/OAuth resolver used for all network operations.
81 pub client: Arc<T>,
82}
83
84impl<S: ClientAuthStore, C: HttpClient + Sync> OAuthClient<JacquardResolver<C>, S> {
85 /// Create an `OAuthClient` using the default [`JacquardResolver`] for identity and metadata resolution.
86 pub fn new(store: S, client_data: ClientData<SmolStr>, http: C) -> Self {
87 let client = JacquardResolver::new(http, ResolverOptions::default());
88 Self::new_from_resolver(store, client, client_data)
89 }
90}
91
92impl<S: ClientAuthStore> OAuthClient<JacquardResolver<reqwest::Client>, S> {
93 /// Create an OAuth client with the provided store and default localhost client metadata.
94 ///
95 /// This is a convenience constructor for quickly setting up an OAuth client
96 /// with default localhost redirect URIs and "atproto transition:generic" scopes.
97 ///
98 /// # Example
99 ///
100 /// ```no_run
101 /// # use jacquard_oauth::client::OAuthClient;
102 /// # use jacquard_oauth::authstore::MemoryAuthStore;
103 /// # #[tokio::main]
104 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
105 /// let store = MemoryAuthStore::new();
106 /// let oauth = OAuthClient::with_default_config(store);
107 /// # Ok(())
108 /// # }
109 /// ```
110 pub fn with_default_config(store: S) -> Self {
111 let client_data = ClientData {
112 keyset: None,
113 config: crate::atproto::AtprotoClientMetadata::default_localhost(),
114 };
115 Self::new(store, client_data, reqwest::Client::new())
116 }
117}
118
119impl OAuthClient<JacquardResolver<reqwest::Client>, crate::authstore::MemoryAuthStore> {
120 /// Create an OAuth client with an in-memory auth store and default localhost client metadata.
121 ///
122 /// This is a convenience constructor for simple testing and development.
123 /// The session will not persist across restarts.
124 ///
125 /// # Example
126 ///
127 /// ```no_run
128 /// # use jacquard_oauth::client::OAuthClient;
129 /// # #[tokio::main]
130 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
131 /// let oauth = OAuthClient::with_memory_store();
132 /// # Ok(())
133 /// # }
134 /// ```
135 pub fn with_memory_store() -> Self {
136 Self::with_default_config(crate::authstore::MemoryAuthStore::new())
137 }
138}
139
140impl<T, S> OAuthClient<T, S>
141where
142 T: OAuthResolver,
143 S: ClientAuthStore,
144{
145 /// Create an OAuth client from an explicit resolver instance, taking ownership of both.
146 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<SmolStr>) -> Self {
147 // #[cfg(feature = "tracing")]
148 // tracing::info!(
149 // redirect_uris = ?client_data.config.redirect_uris,
150 // scopes = ?client_data.config.scopes,
151 // has_keyset = client_data.keyset.is_some(),
152 // "oauth client created:"
153 // );
154
155 let client = Arc::new(client);
156 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data));
157 Self {
158 registry,
159 client,
160 options: RwLock::new(CallOptions::default()),
161 endpoint: RwLock::new(None),
162 }
163 }
164
165 /// Create an OAuth client from already-`Arc`-wrapped store and resolver.
166 pub fn new_with_shared(
167 store: Arc<S>,
168 client: Arc<T>,
169 client_data: ClientData<SmolStr>,
170 ) -> Self {
171 let registry = Arc::new(SessionRegistry::new_shared(
172 store,
173 client.clone(),
174 client_data,
175 ));
176 Self {
177 registry,
178 client,
179 options: RwLock::new(CallOptions::default()),
180 endpoint: RwLock::new(None),
181 }
182 }
183}
184
185impl<T, S> OAuthClient<T, S>
186where
187 S: ClientAuthStore + Send + Sync + 'static,
188 T: OAuthResolver + DpopExt + Send + Sync + 'static,
189{
190 /// Return the public JWK set for this client's keyset, or an empty set if no keyset is configured.
191 pub fn jwks(&self) -> JwkSet {
192 self.registry
193 .client_data
194 .keyset
195 .as_ref()
196 .map(|keyset| keyset.public_jwks())
197 .unwrap_or_default()
198 }
199 /// Begin an OAuth authorization flow and return the URL to which the user should be redirected.
200 ///
201 /// This resolves OAuth metadata for the given `input` (a handle, DID, or PDS/entryway URL),
202 /// performs a Pushed Authorization Request (PAR) to the authorization server, persists the
203 /// resulting state for later callback verification, and returns a fully-constructed
204 /// authorization endpoint URL.
205 ///
206 /// The caller is responsible for redirecting the user's browser to the returned URL.
207 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip(self, input), fields(input = input.as_ref())))]
208 pub async fn start_auth<Str: BosStr>(
209 &self,
210 input: impl AsRef<str>,
211 options: AuthorizeOptions<Str>,
212 ) -> Result<String>
213 where
214 Str: FromStr + Ord + Clone + core::fmt::Debug,
215 <Str as FromStr>::Err: core::fmt::Debug,
216 {
217 let client_metadata = atproto_client_metadata(
218 &self.registry.client_data.config,
219 &self.registry.client_data.keyset,
220 )?;
221 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?;
222 let login_hint = if identity.is_some() {
223 Some(input.as_ref().into())
224 } else {
225 None
226 };
227 let mut metadata = OAuthMetadata {
228 server_metadata,
229 client_metadata,
230 keyset: self.registry.client_data.keyset.clone(),
231 };
232
233 let auth_req_info = par(
234 self.client.as_ref(),
235 login_hint,
236 options.prompt,
237 &mut metadata,
238 options.state.map(|s| s.as_ref().to_smolstr()),
239 )
240 .await?;
241
242 // Persist state for callback handling
243 self.registry
244 .store
245 .save_auth_req_info(&auth_req_info)
246 .await?;
247
248 #[derive(serde::Serialize)]
249 struct Parameters {
250 client_id: smol_str::SmolStr,
251 request_uri: smol_str::SmolStr,
252 }
253 Ok(metadata.server_metadata.authorization_endpoint.to_string()
254 + "?"
255 + &serde_html_form::to_string(Parameters {
256 client_id: metadata.client_metadata.client_id,
257 request_uri: auth_req_info.request_uri,
258 })
259 .unwrap())
260 }
261
262 /// Complete the OAuth authorization flow after the authorization server redirects back to the client.
263 ///
264 /// Validates the `state` and optional `iss` parameters, exchanges the authorization code for
265 /// tokens via the token endpoint, verifies the `sub` claim against the expected issuer, and
266 /// persists the resulting session. On success returns an [`OAuthSession`] ready for API calls.
267 ///
268 /// When the `scope-check` feature is enabled, this method also eagerly resolves any `include:`
269 /// scopes by fetching the referenced permission sets. `T` must implement
270 /// `LexiconSchemaResolver` in that case.
271 #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", skip_all, fields(state = params.state.as_ref().map(|s| s.as_str()))))]
272 #[cfg(not(feature = "scope-check"))]
273 pub async fn callback(&self, params: CallbackParams) -> Result<OAuthSession<T, S>> {
274 let client_data = self.callback_core(params).await?;
275 self.create_session(client_data).await
276 }
277
278 /// Complete the OAuth authorization flow (scope-check variant).
279 ///
280 /// Same as `callback`, but eagerly resolves `include:` scopes into
281 /// concrete permissions via `LexiconSchemaResolver`.
282 #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", skip_all, fields(state = params.state.as_ref().map(|s| s.as_str()))))]
283 #[cfg(feature = "scope-check")]
284 pub async fn callback(&self, params: CallbackParams) -> Result<OAuthSession<T, S>>
285 where
286 T: LexiconSchemaResolver,
287 {
288 let mut client_data = self.callback_core(params).await?;
289 client_data.resolved_scopes =
290 Some(resolve_include_scopes(self.client.as_ref(), &client_data.scopes).await?);
291 self.create_session(client_data).await
292 }
293
294 /// Shared callback logic: validate state/iss, exchange code, build session data.
295 async fn callback_core(&self, params: CallbackParams) -> Result<ClientSessionData> {
296 let Some(state_key) = params.state else {
297 return Err(CallbackError::MissingState.into());
298 };
299
300 let Some(auth_req_info) = self
301 .registry
302 .store
303 .get_auth_req_info(state_key.as_str())
304 .await?
305 else {
306 return Err(CallbackError::MissingState.into());
307 };
308
309 self.registry
310 .store
311 .delete_auth_req_info(state_key.as_str())
312 .await?;
313
314 let metadata = self
315 .client
316 .get_authorization_server_metadata(auth_req_info.authserver_url.as_str())
317 .await?;
318
319 if let Some(iss) = params.iss {
320 if iss != metadata.issuer {
321 return Err(CallbackError::IssuerMismatch {
322 expected: metadata.issuer.to_string(),
323 got: iss.to_string(),
324 }
325 .into());
326 }
327 } else if metadata.authorization_response_iss_parameter_supported == Some(true) {
328 return Err(CallbackError::MissingIssuer.into());
329 }
330 let metadata = OAuthMetadata {
331 server_metadata: metadata,
332 client_metadata: atproto_client_metadata(
333 &self.registry.client_data.config,
334 &self.registry.client_data.keyset,
335 )?,
336 keyset: self.registry.client_data.keyset.clone(),
337 };
338 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone();
339
340 match exchange_code(
341 self.client.as_ref(),
342 &mut auth_req_info.dpop_data.clone(),
343 params.code.as_str(),
344 auth_req_info.pkce_verifier.as_str(),
345 &metadata,
346 )
347 .await
348 {
349 Ok(token_set) => {
350 let scopes = if let Some(scope) = &token_set.scope {
351 Scopes::new(scope.as_str().to_smolstr())
352 .expect("Failed to parse scopes from token response")
353 } else {
354 Scopes::empty()
355 };
356 Ok(ClientSessionData {
357 account_did: token_set.sub.clone(),
358 session_id: auth_req_info.state,
359 host_url: Uri::parse(token_set.aud.as_str())?.to_owned(),
360 authserver_url: auth_req_info.authserver_url,
361 authserver_token_endpoint: auth_req_info.authserver_token_endpoint,
362 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint,
363 scopes,
364 dpop_data: DpopClientData {
365 dpop_key: auth_req_info.dpop_data.dpop_key.clone(),
366 dpop_authserver_nonce: authserver_nonce.unwrap_or_default(),
367 dpop_host_nonce: auth_req_info
368 .dpop_data
369 .dpop_authserver_nonce
370 .unwrap_or_default(),
371 },
372 token_set,
373 resolved_scopes: None,
374 })
375 }
376 Err(e) => Err(e.into()),
377 }
378 }
379
380 async fn create_session(&self, data: ClientSessionData) -> Result<OAuthSession<T, S>> {
381 self.registry.set(data.clone()).await?;
382 Ok(OAuthSession::new(
383 self.registry.clone(),
384 self.client.clone(),
385 data.into_static(),
386 ))
387 }
388
389 /// Restore a previously created session from the backing store, refreshing tokens if needed.
390 pub async fn restore(
391 &self,
392 did: &Did<impl BosStr + Send + Sync>,
393 session_id: &str,
394 ) -> Result<OAuthSession<T, S>> {
395 self.create_session(self.registry.get(did, session_id, true).await?)
396 .await
397 }
398
399 /// Resume a stored session for `input`, or begin OAuth authorization and return a login URL.
400 pub async fn resume_or_start_auth_for<Str: BosStr>(
401 &self,
402 input: impl AsRef<str>,
403 options: AuthorizeOptions<Str>,
404 ) -> Result<OAuthResumeOrLogin<T, S>>
405 where
406 S: SessionSelector<OAuthSessionMatch, Error = SessionStoreError>,
407 Str: FromStr + Ord + Clone + core::fmt::Debug,
408 <Str as FromStr>::Err: core::fmt::Debug,
409 {
410 let input = input.as_ref();
411 self.resume_or_start_auth(&SessionHint::from_input(input), options)
412 .await
413 }
414
415 /// Resume a stored session for `hint`, or begin OAuth authorization from the hint identity.
416 pub async fn resume_or_start_auth<HintStr, Str>(
417 &self,
418 hint: &SessionHint<HintStr>,
419 options: AuthorizeOptions<Str>,
420 ) -> Result<OAuthResumeOrLogin<T, S>>
421 where
422 S: SessionSelector<OAuthSessionMatch, Error = SessionStoreError>,
423 HintStr: BosStr + Send + Sync,
424 Str: BosStr + FromStr + Ord + Clone + core::fmt::Debug,
425 <Str as FromStr>::Err: core::fmt::Debug,
426 {
427 let input = oauth_start_auth_input_from_hint(hint);
428 match OAuthSessionSelector::new(self.registry.store.as_ref(), self.client.as_ref())
429 .select_session(hint)
430 .await?
431 {
432 Some(matched) => match self
433 .restore(&matched.key.did, matched.key.session_id.as_str())
434 .await
435 {
436 Ok(session) => Ok(OAuthResumeOrLogin::Resumed(session)),
437 Err(err) if should_start_auth_after_restore_error(&err) => {
438 let Some(input) = input else {
439 return Err(err);
440 };
441 Ok(OAuthResumeOrLogin::LoginUrl(
442 self.start_auth(input, options).await?,
443 ))
444 }
445 Err(err) => Err(err),
446 },
447 None => {
448 let Some(input) = input else {
449 return Ok(OAuthResumeOrLogin::NeedsInput);
450 };
451 Ok(OAuthResumeOrLogin::LoginUrl(
452 self.start_auth(input, options).await?,
453 ))
454 }
455 }
456 }
457
458 /// Revoke a session by deleting it from the backing store.
459 ///
460 /// Note: this removes the session from local storage but does **not** call the authorization
461 /// server's revocation endpoint. To also invalidate the token server-side, prefer
462 /// [`OAuthSession::logout`], which calls `revoke` on the token before deleting the session.
463 pub async fn revoke(
464 &self,
465 did: &Did<impl BosStr + Send + Sync>,
466 session_id: &str,
467 ) -> Result<()> {
468 Ok(self.registry.del(did, session_id).await?)
469 }
470}
471
472fn oauth_start_auth_input_from_hint<S: BosStr>(hint: &SessionHint<S>) -> Option<SmolStr> {
473 match hint {
474 SessionHint::Did(did) => Some(did.as_ref().to_smolstr()),
475 SessionHint::Handle(handle) => Some(handle.as_ref().to_smolstr()),
476 SessionHint::Key(key) => Some(key.did.as_str().to_smolstr()),
477 SessionHint::Identifier(identifier) => Some(identifier.as_ref().to_smolstr()),
478 SessionHint::Any => None,
479 }
480}
481
482fn should_start_auth_after_restore_error(err: &OAuthError) -> bool {
483 matches!(err, OAuthError::Session(session_err) if session_err.is_permanent())
484}
485
/// Percent-decode an `include:` scope audience (e.g. `%23` becomes `#`).
///
/// Input that `EStr::<Query>::new` rejects is returned verbatim, which covers audiences that
/// were never percent-encoded.
///
/// # Errors
///
/// Returns `CallbackError::ScopeResolution` when the decoded bytes are not valid UTF-8.
#[cfg(feature = "scope-check")]
fn decode_audience(aud: &str) -> Result<String> {
    let Some(encoded) = EStr::<Query>::new(aud) else {
        return Ok(aud.to_string());
    };
    encoded
        .decode()
        .to_string()
        .map(|decoded| decoded.into_owned())
        .map_err(|bytes| {
            crate::error::CallbackError::ScopeResolution {
                detail: format!(
                    "percent-decoded audience contains invalid UTF-8: {:?}",
                    bytes
                ),
            }
            .into()
        })
}
518
/// Expand every `include:` scope in `scopes` into the concrete permissions it grants.
///
/// Scopes other than `include:` are converted and kept as-is. Each `include:` scope has its
/// optional audience percent-decoded and parsed as a `DidService`, its NSID validated, and is
/// then expanded via `resolve_permission_set`.
///
/// # Errors
///
/// Returns `CallbackError::ScopeResolution` for an undecodable or invalid audience DID or an
/// invalid NSID, and propagates any error from `resolve_permission_set`.
#[cfg(feature = "scope-check")]
async fn resolve_include_scopes<R>(
    resolver: &R,
    scopes: &Scopes<SmolStr>,
) -> Result<Vec<Scope<SmolStr>>>
where
    R: OAuthResolver + LexiconSchemaResolver + Send + Sync,
{
    let mut expanded_scopes = Vec::new();
    for scope in scopes.iter() {
        match scope {
            Scope::Include(IncludeScope { nsid, audience }) => {
                let audience_did = match audience {
                    None => None,
                    Some(aud_str) => {
                        let decoded = decode_audience(aud_str)?;
                        let did = DidService::new_owned(&decoded).map_err(|_| {
                            crate::error::CallbackError::ScopeResolution {
                                detail: format!(
                                    "invalid DID in include scope audience: {}",
                                    decoded
                                ),
                            }
                        })?;
                        Some(did)
                    }
                };

                let nsid_smolstr = Nsid::<SmolStr>::new_owned(nsid.as_str()).map_err(|_| {
                    crate::error::CallbackError::ScopeResolution {
                        detail: format!("invalid NSID in include scope: {}", nsid),
                    }
                })?;

                let permissions =
                    resolve_permission_set(resolver, &nsid_smolstr, audience_did.as_ref()).await?;
                expanded_scopes.extend(permissions);
            }
            other => expanded_scopes.push(other.convert()),
        }
    }
    Ok(expanded_scopes)
}
575
576impl<T, S> HttpClient for OAuthClient<T, S>
577where
578 S: ClientAuthStore + Send + Sync + 'static,
579 T: OAuthResolver + DpopExt + Send + Sync + 'static,
580{
581 type Error = T::Error;
582
583 async fn send_http(
584 &self,
585 request: http::Request<Vec<u8>>,
586 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> {
587 self.client.send_http(request).await
588 }
589}
590
591impl<T, S> IdentityResolver for OAuthClient<T, S>
592where
593 S: ClientAuthStore + Send + Sync + 'static,
594 T: OAuthResolver + DpopExt + Send + Sync + 'static,
595{
596 fn options(&self) -> &ResolverOptions {
597 self.client.options()
598 }
599
600 async fn resolve_handle<Str: BosStr + Sync>(
601 &self,
602 handle: &Handle<Str>,
603 ) -> jacquard_identity::resolver::Result<Did> {
604 self.client.resolve_handle(handle).await
605 }
606
607 async fn resolve_did_doc<Str: BosStr + Sync>(
608 &self,
609 did: &Did<Str>,
610 ) -> jacquard_identity::resolver::Result<DidDocResponse> {
611 self.client.resolve_did_doc(did).await
612 }
613}
614
615impl<T, S> XrpcClient for OAuthClient<T, S>
616where
617 S: ClientAuthStore + Send + Sync + 'static,
618 T: OAuthResolver + DpopExt + Send + Sync + 'static,
619{
620 async fn base_uri(&self) -> Uri<String> {
621 self.endpoint.read().await.clone().unwrap_or_else(|| {
622 Uri::parse("https://public.api.bsky.app")
623 .expect("hardcoded URI is valid")
624 .to_owned()
625 })
626 }
627
628 async fn opts(&self) -> CallOptions {
629 self.options.read().await.clone()
630 }
631
632 async fn set_opts(&self, opts: CallOptions) {
633 let mut guard = self.options.write().await;
634 *guard = opts.into_static();
635 }
636
637 async fn set_base_uri(&self, uri: Uri<String>) {
638 let normalized = jacquard_common::xrpc::normalize_base_uri(uri);
639 let mut guard = self.endpoint.write().await;
640 *guard = Some(normalized);
641 }
642
643 async fn send<R>(&self, request: R) -> XrpcResult<XrpcResponse<R>>
644 where
645 R: XrpcRequest + Send + Sync + serde::Serialize,
646 <R as XrpcRequest>::Response: Send + Sync,
647 {
648 let opts = self.options.read().await.clone();
649 self.send_with_opts(request, opts).await
650 }
651
652 async fn send_with_opts<R>(&self, request: R, opts: CallOptions) -> XrpcResult<XrpcResponse<R>>
653 where
654 R: XrpcRequest + Send + Sync + serde::Serialize,
655 <R as XrpcRequest>::Response: Send + Sync,
656 {
657 let base_uri = self.base_uri().await;
658 let http_request = build_http_request(&base_uri.borrow(), &request, &opts)?;
659 let http_response = self
660 .client
661 .send_http(http_request)
662 .await
663 .map_err(|e| ClientError::transport(e).for_nsid(R::NSID))?;
664 process_response(http_response)
665 }
666}
667
668/// An active OAuth session for a specific account, used to make authenticated API requests.
669///
670/// `OAuthSession` holds the DPoP-bound token set for one account and handles transparent
671/// token refresh on `401 invalid_token` responses. The optional `W` type parameter allows
672/// attaching a WebSocket client (defaults to `()` when WebSocket support is not needed).
673///
674/// Obtain an `OAuthSession` from [`OAuthClient::callback`] or [`OAuthClient::restore`].
675pub struct OAuthSession<T, S, W = ()>
676where
677 T: OAuthResolver,
678 S: ClientAuthStore,
679{
680 /// Shared registry used to persist and retrieve session data across refresh operations.
681 pub registry: Arc<SessionRegistry<T, S, SmolStr>>,
682 /// Underlying HTTP/identity/OAuth resolver shared with the parent `OAuthClient`.
683 pub client: Arc<T>,
684 /// Optional WebSocket client; `()` when WebSocket support is not required.
685 pub ws_client: W,
686 /// Mutable session data including DPoP key, nonces, and token set.
687 pub data: RwLock<ClientSessionData>,
688 /// Default call options applied to every outgoing XRPC request from this session.
689 pub options: RwLock<CallOptions>,
690}
691
692impl<T, S> OAuthSession<T, S, ()>
693where
694 T: OAuthResolver,
695 S: ClientAuthStore,
696{
697 /// Create a new session without a WebSocket client.
698 ///
699 /// This is the standard constructor used by [`OAuthClient::callback`] and
700 /// [`OAuthClient::restore`]. For WebSocket support use [`OAuthSession::new_with_ws`].
701 pub fn new(
702 registry: Arc<SessionRegistry<T, S, SmolStr>>,
703 client: Arc<T>,
704 data: ClientSessionData,
705 ) -> Self {
706 Self {
707 registry,
708 client,
709 ws_client: (),
710 data: RwLock::new(data),
711 options: RwLock::new(CallOptions::default()),
712 }
713 }
714}
715
716impl<T, S, W> OAuthSession<T, S, W>
717where
718 T: OAuthResolver,
719 S: ClientAuthStore,
720{
721 /// Create a new session with an attached WebSocket client.
722 ///
723 /// Use this variant when the session needs to support WebSocket subscriptions in addition
724 /// to standard XRPC calls. The `ws_client` is exposed via [`OAuthSession::ws_client`] and
725 /// is used by the `WebSocketClient` impl when the `websocket` feature is enabled.
726 pub fn new_with_ws(
727 registry: Arc<SessionRegistry<T, S, SmolStr>>,
728 client: Arc<T>,
729 ws_client: W,
730 data: ClientSessionData,
731 ) -> Self {
732 Self {
733 registry,
734 client,
735 ws_client,
736 data: RwLock::new(data),
737 options: RwLock::new(CallOptions::default()),
738 }
739 }
740
741 /// Consume this session and return a new one with the given call options pre-applied.
742 ///
743 /// Useful for setting request-level defaults (e.g., `atproto-proxy` or custom headers) once
744 /// at construction time rather than passing them to every individual XRPC call.
745 pub fn with_options(self, options: CallOptions) -> Self {
746 Self {
747 registry: self.registry,
748 client: self.client,
749 ws_client: self.ws_client,
750 data: self.data,
751 options: RwLock::new(options.into_static()),
752 }
753 }
754
755 /// Get a reference to the WebSocket client.
756 pub fn ws_client(&self) -> &W {
757 &self.ws_client
758 }
759
760 /// Replace the default call options for this session without consuming it.
761 pub async fn set_options(&self, options: CallOptions) {
762 *self.options.write().await = options.into_static();
763 }
764
765 /// Return the DID and session ID for this session.
766 ///
767 /// The session ID is the random `state` token generated during the PAR flow and can
768 /// be used together with the DID to restore the session via [`OAuthClient::restore`].
769 pub async fn session_info(&self) -> (Did, smol_str::SmolStr) {
770 let data = self.data.read().await;
771 (data.account_did.clone(), data.session_id.clone())
772 }
773
774 /// Return the resource server (PDS) base URI for this session.
775 pub async fn endpoint(&self) -> Uri<String> {
776 self.data.read().await.host_url.clone()
777 }
778
779 /// Return the current DPoP-bound access token for this session.
780 ///
781 /// The token may be stale if it has expired; use [`OAuthSession::refresh`] or
782 /// rely on the automatic refresh performed by `send_with_opts` to obtain a fresh one.
783 pub async fn access_token(&self) -> AuthorizationToken<SmolStr> {
784 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone())
785 }
786
787 /// Return the current refresh token for this session, if one is present.
788 ///
789 /// Not all authorization servers issue refresh tokens. When `None` is returned,
790 /// the session cannot be silently renewed and the user must re-authenticate.
791 pub async fn refresh_token(&self) -> Option<AuthorizationToken<SmolStr>> {
792 self.data
793 .read()
794 .await
795 .token_set
796 .refresh_token
797 .clone()
798 .map(|t| AuthorizationToken::Dpop(t))
799 }
800
801 /// Derive an unauthenticated [`OAuthClient`] that shares the same registry and resolver.
802 ///
803 /// Useful when you need to initiate a new authorization flow from within an existing
804 /// session context (e.g., to add a second account) without constructing a fresh client.
805 pub fn to_client(&self) -> OAuthClient<T, S> {
806 OAuthClient::from_session(self)
807 }
808}
809impl<T, S, W> OAuthSession<T, S, W>
810where
811 S: ClientAuthStore + Send + Sync + 'static,
812 T: OAuthResolver + DpopExt + Send + Sync + 'static,
813{
814 /// Revoke the access token at the authorization server and delete the session from the store.
815 ///
816 /// Revocation is best-effort: if the server does not advertise a revocation endpoint, or if
817 /// the revocation call fails, the session is still deleted locally. This prevents a dangling
818 /// session record from blocking future logins for the same account.
819 pub async fn logout(&self) -> Result<()> {
820 use crate::request::{OAuthMetadata, revoke};
821 let mut data = self.data.write().await;
822 let meta =
823 OAuthMetadata::new(self.client.as_ref(), &self.registry.client_data, &data).await?;
824 if meta.server_metadata.revocation_endpoint.is_some() {
825 let token = data.token_set.access_token.clone();
826 revoke(self.client.as_ref(), &mut data.dpop_data, &token, &meta)
827 .await
828 .ok();
829 }
830 // Remove from store
831 self.registry
832 .del(&data.account_did, &data.session_id)
833 .await?;
834 Ok(())
835 }
836}
837
838impl<T, S> OAuthClient<T, S>
839where
840 T: OAuthResolver,
841 S: ClientAuthStore,
842{
843 /// Construct an `OAuthClient` that shares the registry and resolver of an existing session.
844 ///
845 /// Equivalent to [`OAuthSession::to_client`]; provided on `OAuthClient` for symmetry so
846 /// callers can obtain an unauthenticated client without holding a session reference.
847 pub fn from_session<W>(session: &OAuthSession<T, S, W>) -> Self {
848 Self {
849 registry: session.registry.clone(),
850 client: session.client.clone(),
851 options: RwLock::new(CallOptions::default()),
852 endpoint: RwLock::new(None),
853 }
854 }
855}
856impl<T, S, W> OAuthSession<T, S, W>
857where
858 S: ClientAuthStore + Send + Sync + 'static,
859 T: OAuthResolver + DpopExt + Send + Sync + 'static,
860{
861 /// Explicitly refresh the access token using the stored refresh token.
862 ///
863 /// On success the new token set is written back into both the in-memory session data and
864 /// the backing store. The returned `AuthorizationToken` is the new access token, which
865 /// callers can immediately use to retry a failed request.
866 ///
867 /// The actual token exchange is serialized per `(DID, session_id)` pair via a `Mutex` inside
868 /// the registry, so concurrent refresh attempts will not result in duplicate token exchanges.
869 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all))]
870 pub async fn refresh(&self) -> Result<AuthorizationToken<SmolStr>> {
871 // Read identifiers without holding the lock across await
872 let (did, sid) = {
873 let data = self.data.read().await;
874 (data.account_did.clone(), data.session_id.clone())
875 };
876 let refreshed = self.registry.as_ref().get(&did, &sid, true).await?;
877 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone());
878 // Write back updated session
879 *self.data.write().await = refreshed.clone().into_static();
880 // Store in the registry
881 self.registry.set(refreshed).await?;
882 Ok(token)
883 }
884}
885
886impl<T, S, W> HttpClient for OAuthSession<T, S, W>
887where
888 S: ClientAuthStore + Send + Sync + 'static,
889 T: OAuthResolver + DpopExt + Send + Sync + 'static,
890 W: Send + Sync,
891{
892 type Error = T::Error;
893
894 async fn send_http(
895 &self,
896 request: http::Request<Vec<u8>>,
897 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> {
898 self.client.send_http(request).await
899 }
900}
901
902impl<T, S, W> XrpcClient for OAuthSession<T, S, W>
903where
904 S: ClientAuthStore + Send + Sync + 'static,
905 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static,
906 W: Send + Sync,
907{
908 async fn base_uri(&self) -> Uri<String> {
909 self.data.read().await.host_url.clone()
910 }
911
912 async fn opts(&self) -> CallOptions {
913 self.options.read().await.clone()
914 }
915
916 async fn set_opts(&self, opts: CallOptions) {
917 let mut guard = self.options.write().await;
918 *guard = opts.into_static();
919 }
920
921 async fn set_base_uri(&self, uri: Uri<String>) {
922 let normalized = jacquard_common::xrpc::normalize_base_uri(uri);
923 let mut guard = self.data.write().await;
924 guard.host_url = normalized;
925 }
926
927 async fn send<R>(&self, request: R) -> XrpcResult<XrpcResponse<R>>
928 where
929 R: XrpcRequest + Send + Sync + serde::Serialize,
930 <R as XrpcRequest>::Response: Send + Sync,
931 {
932 let opts = self.options.read().await.clone();
933 self.send_with_opts(request, opts).await
934 }
935
936 async fn send_with_opts<R>(
937 &self,
938 request: R,
939 mut opts: CallOptions,
940 ) -> XrpcResult<XrpcResponse<R>>
941 where
942 R: XrpcRequest + Send + Sync + serde::Serialize,
943 <R as XrpcRequest>::Response: Send + Sync,
944 {
945 // Pre-flight scope check: pure in-memory, no HTTP.
946 #[cfg(feature = "scope-check")]
947 {
948 self.check_scope::<R>().await.map_err(|e| {
949 ClientError::invalid_request(format!("scope check failed: {:?}", e))
950 .for_nsid(R::NSID)
951 })?;
952 }
953
954 let base_uri = self.base_uri().await;
955 let original_token = self.access_token().await;
956 opts.auth = Some(original_token.clone());
957 // Clone dpop_data and release read lock before the await point
958 let mut dpop = self.data.read().await.dpop_data.clone();
959 let http_response = self
960 .client
961 .dpop_call(&mut dpop)
962 .send(build_http_request(&base_uri.borrow(), &request, &opts)?)
963 .await
964 .map_err(|e| ClientError::from(e).for_nsid(R::NSID))?;
965 let resp = process_response(http_response);
966
967 // Write back updated nonce to session data (dpop_call may have updated it)
968 {
969 let mut guard = self.data.write().await;
970 guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
971 }
972
973 if is_invalid_token_response(&resp) {
974 // Optimistic refresh: check if another request already refreshed the token
975 let current_token = self.access_token().await;
976 if current_token != original_token {
977 // Token was already refreshed by another concurrent request, use it
978 opts.auth = Some(current_token);
979 } else {
980 // We need to refresh - this will be serialized by the registry's Mutex
981 opts.auth = Some(
982 self.refresh()
983 .await
984 .map_err(|e| ClientError::transport(e))?,
985 );
986 }
987 // Re-read dpop_data after refresh (refresh may have updated it)
988 let mut dpop = self.data.read().await.dpop_data.clone();
989 let http_response = self
990 .client
991 .dpop_call(&mut dpop)
992 .send(build_http_request(&base_uri.borrow(), &request, &opts)?)
993 .await
994 .map_err(|e| {
995 ClientError::from(e)
996 .for_nsid(R::NSID)
997 .append_context("after token refresh")
998 })?;
999 let resp = process_response(http_response);
1000
1001 // Write back updated nonce after retry
1002 {
1003 let mut guard = self.data.write().await;
1004 guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
1005 }
1006
1007 resp
1008 } else {
1009 resp
1010 }
1011 }
1012}
1013
#[cfg(feature = "scope-check")]
impl<T, S, W> OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + Send + Sync + 'static,
    W: Send + Sync,
{
    /// Check whether the session's resolved scopes grant access to
    /// the XRPC method identified by `R::NSID`.
    ///
    /// This is a client-side, best-effort pre-flight check. It errs on the
    /// permissive side; the server remains the authority on authorization.
    ///
    /// # Errors
    /// Returns [`ScopeError`] with the request NSID and a comma-separated list
    /// of granted scopes when no granted scope covers the request.
    async fn check_scope<R: XrpcRequest>(&self) -> core::result::Result<(), ScopeError> {
        // Record-write methods are authorized by `repo:<collection>` scopes. The
        // request NSID here is the *method* (e.g. `com.atproto.repo.createRecord`),
        // never the collection, so comparing it against the granted collection
        // can never match. Any repo grant therefore covers these methods at
        // pre-flight; the server checks the actual collection.
        const REPO_WRITE_METHODS: &[&str] = &[
            "com.atproto.repo.createRecord",
            "com.atproto.repo.putRecord",
            "com.atproto.repo.deleteRecord",
            "com.atproto.repo.applyWrites",
        ];

        let data = self.data.read().await;

        // Resolved scopes are the eagerly-expanded grant set (no include scopes remain).
        let resolved = data.resolved_scopes.as_ref();

        let is_permitted = match resolved {
            Some(scopes) => {
                // R::NSID is a generated constant; failure here is a codegen bug.
                let nsid = Nsid::<SmolStr>::new_static(R::NSID).expect("valid NSID");

                // Note: `atproto` is the minimum base scope (auth only).
                // It does NOT grant rpc/repo access.
                //
                // For rpc: scopes, only the lxm (method) is matched and the
                // audience is ignored: the client does not know the target
                // audience at pre-flight time, and audience enforcement is the
                // server's responsibility.
                let rpc_ok = scopes.iter().any(|s| match s {
                    Scope::Rpc(rpc) => rpc.lxm.iter().any(|l| match l {
                        RpcLexicon::All => true,
                        RpcLexicon::Nsid(granted_nsid) => granted_nsid.as_ref() == nsid.as_ref(),
                    }),
                    _ => false,
                });

                // For repo: scopes, any action suffices for pre-flight. A repo
                // grant covers the record-write methods above, or a request
                // whose NSID equals a granted collection.
                let is_repo_write = REPO_WRITE_METHODS.contains(&R::NSID);
                let repo_ok = scopes.iter().any(|s| match s {
                    Scope::Repo(repo) => {
                        is_repo_write
                            || match &repo.collection {
                                RepoCollection::All => true,
                                RepoCollection::Nsid(col) => col.as_ref() == nsid.as_ref(),
                            }
                    }
                    _ => false,
                });

                rpc_ok || repo_ok
            }
            None => {
                // No resolved scopes means resolution was skipped
                // (e.g., no include scopes were present, or scope-check
                // was enabled after session creation). Allow the request.
                true
            }
        };

        if !is_permitted {
            let granted_summary = resolved
                .map(|scopes| {
                    scopes
                        .iter()
                        .map(|s| s.to_string_normalized())
                        .collect::<Vec<_>>()
                        .join(", ")
                })
                .unwrap_or_default();

            return Err(ScopeError {
                nsid: SmolStr::new_static(R::NSID),
                granted: SmolStr::from(granted_summary),
            });
        }

        Ok(())
    }
}
1095
/// Streaming HTTP passthrough for an OAuth session.
///
/// Both methods forward directly to the wrapped client (`self.client`); this impl
/// does not attach the session's access token or a DPoP proof itself.
#[cfg(feature = "streaming")]
impl<T, S, W> jacquard_common::http_client::HttpClientExt for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver
        + DpopExt
        + XrpcExt
        + jacquard_common::http_client::HttpClientExt
        + Send
        + Sync
        + 'static,
    W: Send + Sync,
{
    /// Forwards `request` to `self.client` and returns a streaming response body.
    async fn send_http_streaming(
        &self,
        request: http::Request<Vec<u8>>,
    ) -> core::result::Result<http::Response<jacquard_common::stream::ByteStream>, Self::Error>
    {
        self.client.send_http_streaming(request).await
    }

    /// Forwards a request whose body is a byte stream (native targets).
    ///
    /// On non-wasm targets the body stream must be `Send`.
    #[cfg(not(target_arch = "wasm32"))]
    async fn send_http_bidirectional<Str>(
        &self,
        parts: http::request::Parts,
        body: Str,
    ) -> core::result::Result<http::Response<jacquard_common::stream::ByteStream>, Self::Error>
    where
        Str: n0_future::Stream<
                Item = core::result::Result<bytes::Bytes, jacquard_common::StreamError>,
            > + Send
            + 'static,
    {
        self.client.send_http_bidirectional(parts, body).await
    }

    /// Forwards a request whose body is a byte stream (wasm32).
    ///
    /// Identical to the native variant except the body stream need not be `Send`.
    #[cfg(target_arch = "wasm32")]
    async fn send_http_bidirectional<Str>(
        &self,
        parts: http::request::Parts,
        body: Str,
    ) -> core::result::Result<http::Response<jacquard_common::stream::ByteStream>, Self::Error>
    where
        Str: n0_future::Stream<
                Item = core::result::Result<bytes::Bytes, jacquard_common::StreamError>,
            > + 'static,
    {
        self.client.send_http_bidirectional(parts, body).await
    }
}
1146
/// Streaming XRPC over an OAuth session.
///
/// Requests are authenticated with the session's DPoP token. As in
/// `XrpcClient::send_with_opts`, session locks are never held across network
/// awaits, and any DPoP nonce update from the server is written back to the session.
#[cfg(feature = "streaming")]
impl<T, S, W> jacquard_common::xrpc::XrpcStreamingClient for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver
        + DpopExt
        + XrpcExt
        + jacquard_common::http_client::HttpClientExt
        + Send
        + Sync
        + 'static,
    W: Send + Sync,
{
    /// Performs an authenticated streaming download of `request`.
    ///
    /// If the first attempt fails, the session is refreshed and the request is
    /// retried once. The first error is not inspected, so any failure
    /// (including non-auth transport errors) triggers the refresh.
    ///
    /// # Errors
    /// `StreamError::protocol` if the request cannot be built;
    /// `StreamError::transport` if the refresh or the retried request fails.
    async fn download<R>(
        &self,
        request: R,
    ) -> core::result::Result<jacquard_common::xrpc::StreamingResponse, jacquard_common::StreamError>
    where
        R: XrpcRequest + Send + Sync + serde::Serialize,
        <R as XrpcRequest>::Response: Send + Sync,
    {
        use jacquard_common::StreamError;

        let base_uri = <Self as XrpcClient>::base_uri(self).await;
        let mut opts = self.options.read().await.clone();
        opts.auth = Some(self.access_token().await);
        let http_request = build_http_request(&base_uri.borrow(), &request, &opts)
            .map_err(|e| StreamError::protocol(e.to_string()))?;

        // Clone the DPoP state and release the read lock *before* awaiting the
        // network call. Holding it across the await would block writers (nonce
        // write-back, refresh) for the whole request.
        let mut dpop = self.data.read().await.dpop_data.clone();
        let result = self
            .client
            .dpop_call(&mut dpop)
            .send_streaming(http_request)
            .await;
        {
            let mut guard = self.data.write().await;
            guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
        }

        match result {
            Ok(response) => Ok(response),
            Err(_e) => {
                opts.auth = Some(
                    self.refresh()
                        .await
                        .map_err(|e| StreamError::transport(e))?,
                );
                let http_request = build_http_request(&base_uri.borrow(), &request, &opts)
                    .map_err(|e| StreamError::protocol(e.to_string()))?;

                // Re-read dpop_data: the refresh may have replaced the session data.
                let mut dpop = self.data.read().await.dpop_data.clone();
                let retried = self
                    .client
                    .dpop_call(&mut dpop)
                    .send_streaming(http_request)
                    .await;
                {
                    let mut guard = self.data.write().await;
                    guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
                }
                retried.map_err(StreamError::transport)
            }
        }
    }

    /// Sends a streaming procedure call and returns the typed response stream.
    ///
    /// The request is a `POST` to `<base>/xrpc/<NSID>`. It carries the session
    /// token as `Authorization` plus the proxy, labeler and extra headers from
    /// the session's call options.
    ///
    /// # Errors
    /// `StreamError::protocol` if a header or the request cannot be built;
    /// `StreamError::transport` if sending fails.
    async fn stream<Str, B>(
        &self,
        stream: jacquard_common::xrpc::streaming::XrpcProcedureSend<Str::Frame<B>>,
    ) -> core::result::Result<
        jacquard_common::xrpc::streaming::XrpcResponseStream<
            <<Str as jacquard_common::xrpc::streaming::XrpcProcedureStream>::Response as jacquard_common::xrpc::streaming::XrpcStreamResp>::Frame<B>,
        >,
        jacquard_common::StreamError,
    >
    where
        Str: jacquard_common::xrpc::streaming::XrpcProcedureStream + 'static,
        <<Str as jacquard_common::xrpc::streaming::XrpcProcedureStream>::Response as jacquard_common::xrpc::streaming::XrpcStreamResp>::Frame<B>: jacquard_common::xrpc::streaming::XrpcStreamResp,
        B: BosStr + 'static,
    {
        use jacquard_common::StreamError;
        use n0_future::TryStreamExt;

        let base_uri = <Self as XrpcClient>::base_uri(self).await;
        let mut opts = self.options.read().await.clone();
        opts.auth = Some(self.access_token().await);

        let mut path = String::from(base_uri.as_str().trim_end_matches('/'));
        path.push_str("/xrpc/");
        path.push_str(<Str::Request as jacquard_common::xrpc::XrpcRequest>::NSID);

        let mut builder = http::Request::post(path);

        if let Some(token) = &opts.auth {
            let hv = match token {
                AuthorizationToken::Bearer(t) => {
                    http::HeaderValue::from_str(&format!("Bearer {}", t.as_str()))
                }
                AuthorizationToken::Dpop(t) => {
                    http::HeaderValue::from_str(&format!("DPoP {}", t.as_str()))
                }
            }
            .map_err(|e| StreamError::protocol(format!("Invalid authorization token: {}", e)))?;
            builder = builder.header(http::header::AUTHORIZATION, hv);
        }

        if let Some(proxy) = &opts.atproto_proxy {
            builder = builder.header("atproto-proxy", proxy.as_str());
        }
        if let Some(labelers) = &opts.atproto_accept_labelers {
            if !labelers.is_empty() {
                let joined = labelers
                    .iter()
                    .map(|s| s.as_ref())
                    .collect::<Vec<_>>()
                    .join(", ");
                builder = builder.header("atproto-accept-labelers", joined);
            }
        }
        for (name, value) in &opts.extra_headers {
            builder = builder.header(name, value);
        }

        let (parts, _) = builder
            .body(())
            .map_err(|e| StreamError::protocol(e.to_string()))?
            .into_parts();

        let body_stream =
            jacquard_common::stream::ByteStream::new(Box::pin(stream.0.map_ok(|f| f.buffer)));

        // Release the read lock before the (potentially long-lived) network await.
        let mut dpop = self.data.read().await.dpop_data.clone();
        let result = self
            .client
            .dpop_call(&mut dpop)
            .send_bidirectional(parts, body_stream)
            .await;
        {
            let mut guard = self.data.write().await;
            guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
        }

        match result {
            Ok(response) => {
                let (resp_parts, resp_body) = response.into_parts();
                Ok(
                    jacquard_common::xrpc::streaming::XrpcResponseStream::from_typed_parts::<B>(
                        resp_parts, resp_body,
                    ),
                )
            }
            // No refresh-and-retry here: the request body stream was moved into
            // the first attempt and cannot be replayed.
            Err(e) => Err(StreamError::transport(e)),
        }
    }
}
1298
1299fn is_invalid_token_response<R: XrpcResp>(response: &XrpcResult<Response<R>>) -> bool {
1300 use jacquard_common::error::ClientErrorKind;
1301
1302 match response {
1303 Err(e) => match e.kind() {
1304 ClientErrorKind::Auth(AuthError::InvalidToken) => true,
1305 ClientErrorKind::Auth(AuthError::Other(value)) => value
1306 .to_str()
1307 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")),
1308 _ => false,
1309 },
1310 // Some servers return 200/401 with an error in the body rather than using
1311 // WWW-Authenticate. Check the raw response bytes for the invalid_token pattern.
1312 Ok(resp) => {
1313 resp.status() == http::StatusCode::UNAUTHORIZED
1314 || serde_json::from_slice::<serde_json::Value>(resp.buffer())
1315 .ok()
1316 .and_then(|v| v.get("error")?.as_str().map(|s| s == "invalid_token"))
1317 .unwrap_or(false)
1318 }
1319 }
1320}
1321
/// Identity resolution for an OAuth session, delegated entirely to the wrapped client.
impl<T, S, W> IdentityResolver for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + IdentityResolver + XrpcExt + Send + Sync + 'static,
    W: Send + Sync,
{
    /// Returns the wrapped client's resolver options.
    fn options(&self) -> &ResolverOptions {
        self.client.options()
    }

    /// Resolves `handle` to a DID via the wrapped client.
    async fn resolve_handle<Str: BosStr + Sync>(
        &self,
        handle: &Handle<Str>,
    ) -> std::result::Result<Did, IdentityError> {
        self.client.resolve_handle(handle).await
    }

    /// Fetches the DID document for `did` via the wrapped client.
    async fn resolve_did_doc<Str: BosStr + Sync>(
        &self,
        did: &Did<Str>,
    ) -> std::result::Result<DidDocResponse, IdentityError> {
        self.client.resolve_did_doc(did).await
    }
}
1346
/// WebSocket connections for an OAuth session, delegated to `self.ws_client`.
///
/// This impl adds no authentication headers itself; callers pass any headers
/// explicitly via `connect_with_headers`.
#[cfg(feature = "websocket")]
impl<T, S, W> WebSocketClient for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + Send + Sync + 'static,
    W: WebSocketClient + Send + Sync,
{
    /// Errors are exactly those of the wrapped WebSocket client `W`.
    type Error = W::Error;

    /// Opens a WebSocket connection to `uri` using the wrapped client.
    async fn connect(
        &self,
        uri: Uri<&str>,
    ) -> std::result::Result<WebSocketConnection, Self::Error> {
        self.ws_client.connect(uri).await
    }

    /// Opens a WebSocket connection to `uri` with the given extra headers.
    async fn connect_with_headers(
        &self,
        uri: Uri<&str>,
        headers: Vec<(CowStr<'_>, CowStr<'_>)>,
    ) -> std::result::Result<WebSocketConnection, Self::Error> {
        self.ws_client.connect_with_headers(uri, headers).await
    }
}
1371
/// XRPC subscriptions (event streams over WebSocket) for an OAuth session.
#[cfg(feature = "websocket")]
impl<T, S, W> jacquard_common::xrpc::SubscriptionClient for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + Send + Sync + 'static,
    W: WebSocketClient + Send + Sync,
{
    /// Returns the host URL currently stored in the session data.
    async fn base_uri(&self) -> Uri<String> {
        self.data.read().await.host_url.clone()
    }

    /// Builds default subscription options carrying the session's access token.
    ///
    /// Only an `Authorization` header (`Bearer <token>` or `DPoP <token>`) is
    /// added. NOTE(review): no `DPoP` proof header is attached here, which
    /// DPoP-bound servers may reject for `DPoP` tokens — confirm expected behavior.
    async fn subscription_opts(&self) -> jacquard_common::xrpc::SubscriptionOptions<'_> {
        let mut opts = jacquard_common::xrpc::SubscriptionOptions::default();
        let token = self.access_token().await;
        let auth_value = match token {
            AuthorizationToken::Bearer(t) => format!("Bearer {}", t.as_str()),
            AuthorizationToken::Dpop(t) => format!("DPoP {}", t.as_str()),
        };
        opts.headers
            .push((CowStr::from("Authorization"), CowStr::from(auth_value)));
        opts
    }

    /// Subscribes to `params` using [`Self::subscription_opts`].
    async fn subscribe<Sub>(
        &self,
        params: &Sub,
    ) -> std::result::Result<jacquard_common::xrpc::SubscriptionStream<Sub::Stream>, Self::Error>
    where
        Sub: XrpcSubscription + Send + Sync + serde::Serialize,
    {
        let opts = self.subscription_opts().await;
        self.subscribe_with_opts(params, opts).await
    }

    /// Subscribes to `params` against the session host with caller-supplied options.
    async fn subscribe_with_opts<Sub>(
        &self,
        params: &Sub,
        opts: jacquard_common::xrpc::SubscriptionOptions<'_>,
    ) -> std::result::Result<jacquard_common::xrpc::SubscriptionStream<Sub::Stream>, Self::Error>
    where
        Sub: XrpcSubscription + Send + Sync + serde::Serialize,
    {
        use jacquard_common::xrpc::SubscriptionExt;
        let base = self.base_uri().await;
        self.subscription(base)
            .with_options(opts)
            .subscribe(params)
            .await
    }
}
1422
#[cfg(all(test, feature = "scope-check"))]
mod tests {
    use super::*;
    use crate::scopes::{RepoAction, RepoScope, RpcAudience, RpcLexicon, RpcScope};
    use std::collections::BTreeSet;

    /// Builds an NSID from a literal known to be valid in these tests.
    fn nsid(raw: &'static str) -> Nsid<SmolStr> {
        Nsid::<SmolStr>::new_static(raw).unwrap()
    }

    /// A scope granting access to an RPC method permits that method.
    #[test]
    fn test_scope_check_permits_matching_rpc() {
        // AC7.1: Session with rpc:com.example.test grants access to com.example.test.
        let granted_scope = Scope::Rpc(RpcScope {
            lxm: BTreeSet::from([RpcLexicon::Nsid(nsid("com.example.test"))]),
            aud: BTreeSet::from([RpcAudience::All]),
        });
        let target_scope = Scope::Rpc(RpcScope {
            lxm: BTreeSet::from([RpcLexicon::Nsid(nsid("com.example.test"))]),
            aud: BTreeSet::from([RpcAudience::All]),
        });

        assert!(
            granted_scope.grants(&target_scope),
            "rpc:com.example.test should grant access to com.example.test"
        );
    }

    /// The rpc:* wildcard permits every RPC method.
    #[test]
    fn test_scope_check_permits_rpc_wildcard() {
        // AC7.1: Session with rpc:* (wildcard) grants access to any RPC method.
        let any_method: BTreeSet<RpcLexicon<SmolStr>> = BTreeSet::from([RpcLexicon::All]);
        let any_audience: BTreeSet<RpcAudience<SmolStr>> = BTreeSet::from([RpcAudience::All]);
        let wildcard_scope = Scope::Rpc(RpcScope {
            lxm: any_method,
            aud: any_audience,
        });
        let target_scope = Scope::Rpc(RpcScope {
            lxm: BTreeSet::from([RpcLexicon::Nsid(nsid("com.example.test"))]),
            aud: BTreeSet::from([RpcAudience::All]),
        });

        assert!(
            wildcard_scope.grants(&target_scope),
            "rpc:* should grant access to any RPC method"
        );
    }

    /// A scope for a different method does not permit the target.
    #[test]
    fn test_scope_check_denies_ungranted() {
        // AC7.4: Session with rpc:com.example.other denies access to com.example.test.
        let granted_scope = Scope::Rpc(RpcScope {
            lxm: BTreeSet::from([RpcLexicon::Nsid(nsid("com.example.other"))]),
            aud: BTreeSet::from([RpcAudience::All]),
        });
        let target_scope = Scope::Rpc(RpcScope {
            lxm: BTreeSet::from([RpcLexicon::Nsid(nsid("com.example.test"))]),
            aud: BTreeSet::from([RpcAudience::All]),
        });

        assert!(
            !granted_scope.grants(&target_scope),
            "rpc:com.example.other should NOT grant access to com.example.test"
        );
    }

    /// A repo scope permits operations on its collection.
    #[test]
    fn test_scope_check_permits_repo_scope() {
        // AC7.1: Session with repo:com.example.test grants access to that collection.
        let granted_repo = Scope::Repo(RepoScope {
            collection: RepoCollection::Nsid(nsid("com.example.test")),
            actions: BTreeSet::from([RepoAction::Create, RepoAction::Update, RepoAction::Delete]),
        });
        let target_repo = Scope::Repo(RepoScope {
            collection: RepoCollection::Nsid(nsid("com.example.test")),
            actions: BTreeSet::from([RepoAction::Create, RepoAction::Update, RepoAction::Delete]),
        });

        assert!(
            granted_repo.grants(&target_repo),
            "repo:com.example.test should grant repo access to com.example.test"
        );
    }

    /// ScopeError exposes the request NSID and granted-scope summary.
    #[test]
    fn test_scope_error_diagnostic_info() {
        // AC7.4: ScopeError includes request NSID and granted scope summary.
        let err = ScopeError {
            nsid: SmolStr::from("com.example.test"),
            granted: SmolStr::from("rpc:com.example.other"),
        };

        assert_eq!(err.nsid, "com.example.test");
        assert_eq!(err.granted, "rpc:com.example.other");

        let rendered = err.to_string();
        assert!(
            rendered.contains("not permitted"),
            "error message should indicate request is not permitted"
        );
    }

    /// With several granted scopes, one match is enough.
    #[test]
    fn test_scope_check_multiple_scopes() {
        // AC7.1: With multiple scopes, request matching one of them is permitted.
        let other_scope = Scope::Rpc(RpcScope {
            lxm: BTreeSet::from([RpcLexicon::Nsid(nsid("com.example.other"))]),
            aud: BTreeSet::from([RpcAudience::All]),
        });
        let test_scope = Scope::Rpc(RpcScope {
            lxm: BTreeSet::from([RpcLexicon::Nsid(nsid("com.example.test"))]),
            aud: BTreeSet::from([RpcAudience::All]),
        });
        let target_scope = Scope::Rpc(RpcScope {
            lxm: BTreeSet::from([RpcLexicon::Nsid(nsid("com.example.test"))]),
            aud: BTreeSet::from([RpcAudience::All]),
        });

        let granted_scopes = [other_scope, test_scope];
        let is_permitted = granted_scopes
            .iter()
            .any(|scope| scope.grants(&target_scope));
        assert!(
            is_permitted,
            "at least one granted scope should permit the target request"
        );
    }

    /// Repo scopes are a valid grant path alongside rpc scopes.
    #[test]
    fn test_scope_check_rpc_and_repo_paths() {
        // AC7.1: A request can be granted via either rpc: or repo: scopes.
        let repo_scope = Scope::Repo(RepoScope {
            collection: RepoCollection::Nsid(nsid("com.example.test")),
            actions: BTreeSet::from([RepoAction::Create, RepoAction::Update, RepoAction::Delete]),
        });
        let target_repo = Scope::Repo(RepoScope {
            collection: RepoCollection::Nsid(nsid("com.example.test")),
            actions: BTreeSet::from([RepoAction::Create, RepoAction::Update, RepoAction::Delete]),
        });

        assert!(
            repo_scope.grants(&target_repo),
            "repo scope should grant repo-based requests"
        );
    }
}