A better Rust ATProto crate
1

Configure Feed

Select the types of activity you want to include in your feed.

at main 44 kB View raw
1use std::str::FromStr; 2 3use chrono::{TimeDelta, Utc}; 4use http::{Method, Request, StatusCode}; 5use jacquard_common::{ 6 CowStr, 7 bos::{BosStr, DefaultStr}, 8 http_client::HttpClient, 9 session::SessionStoreError, 10 types::{ 11 did::Did, 12 string::{AtStrError, Datetime}, 13 }, 14}; 15use jacquard_identity::resolver::IdentityError; 16use serde::Serialize; 17use serde_json::Value; 18use smol_str::ToSmolStr; 19 20use jose_jwa::Signing; 21 22use crate::{ 23 FALLBACK_ALG, 24 atproto::atproto_client_metadata, 25 dpop::DpopExt, 26 jose::jwt::{RegisteredClaims, RegisteredClaimsAud}, 27 keyset::Keyset, 28 resolver::OAuthResolver, 29 scopes::Scopes, 30 session::{ 31 AuthRequestData, ClientData, ClientSessionData, DpopClientData, DpopDataSource, DpopReqData, 32 }, 33 types::{ 34 AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptionPrompt, 35 OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthParResponse, 36 OAuthTokenResponse, ParParameters, RefreshRequestParameters, RevocationRequestParameters, 37 TokenGrantType, TokenRequestParameters, TokenSet, 38 }, 39 utils::{generate_dpop_key, generate_nonce, generate_pkce}, 40}; 41 42// https://datatracker.ietf.org/doc/html/rfc7523#section-2.2 43const CLIENT_ASSERTION_TYPE_JWT_BEARER: &str = 44 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; 45 46use smol_str::SmolStr; 47 48/// Convenience alias for a heap-allocated, thread-safe, `'static` error value. 
49pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; 50 51/// OAuth request error for token operations and auth flows 52#[derive(Debug, thiserror::Error, miette::Diagnostic)] 53#[error("{kind}")] 54pub struct RequestError { 55 #[diagnostic_source] 56 kind: RequestErrorKind, 57 #[source] 58 source: Option<BoxError>, 59 #[help] 60 help: Option<SmolStr>, 61 context: Option<SmolStr>, 62 url: Option<SmolStr>, 63 details: Option<SmolStr>, 64 location: Option<SmolStr>, 65} 66 67/// Error categories for OAuth request operations 68#[derive(Debug, thiserror::Error, miette::Diagnostic)] 69#[non_exhaustive] 70pub enum RequestErrorKind { 71 /// No endpoint available 72 #[error("no {0} endpoint available")] 73 #[diagnostic( 74 code(jacquard_oauth::request::no_endpoint), 75 help("server does not advertise this endpoint") 76 )] 77 NoEndpoint(SmolStr), 78 79 /// Token response verification failed 80 #[error("token response verification failed")] 81 #[diagnostic(code(jacquard_oauth::request::token_verification))] 82 TokenVerification, 83 84 /// Unsupported authentication method 85 #[error("unsupported authentication method")] 86 #[diagnostic( 87 code(jacquard_oauth::request::unsupported_auth_method), 88 help( 89 "server must support `private_key_jwt` or `none`; configure client metadata accordingly" 90 ) 91 )] 92 UnsupportedAuthMethod, 93 94 /// No refresh token available 95 #[error("no refresh token available")] 96 #[diagnostic(code(jacquard_oauth::request::no_refresh_token))] 97 NoRefreshToken, 98 99 /// Invalid DID 100 #[error("failed to parse DID")] 101 #[diagnostic(code(jacquard_oauth::request::invalid_did))] 102 InvalidDid, 103 104 /// DPoP client error 105 #[error("dpop error")] 106 #[diagnostic(code(jacquard_oauth::request::dpop))] 107 Dpop, 108 109 /// Session storage error 110 #[error("storage error")] 111 #[diagnostic(code(jacquard_oauth::request::storage))] 112 Storage, 113 114 /// Resolver error 115 #[error("resolver error")] 116 
#[diagnostic(code(jacquard_oauth::request::resolver))] 117 Resolver, 118 119 /// HTTP build error 120 #[error("http build error")] 121 #[diagnostic(code(jacquard_oauth::request::http_build))] 122 HttpBuild, 123 124 /// HTTP status error 125 #[error("http status: {0}")] 126 #[diagnostic( 127 code(jacquard_oauth::request::http_status), 128 help("see server response for details") 129 )] 130 HttpStatus(StatusCode), 131 132 /// HTTP status with error body 133 #[error("http status: {status}, body: {body:?}")] 134 #[diagnostic( 135 code(jacquard_oauth::request::http_status_body), 136 help("server returned error JSON; inspect fields like `error`, `error_description`") 137 )] 138 HttpStatusWithBody { 139 /// HTTP status code returned by the server. 140 status: StatusCode, 141 /// Parsed JSON body containing OAuth error fields such as `error` and `error_description`. 142 body: Value, 143 }, 144 145 /// Identity resolution error 146 #[error("identity error")] 147 #[diagnostic(code(jacquard_oauth::request::identity))] 148 Identity, 149 150 /// Keyset error 151 #[error("keyset error")] 152 #[diagnostic(code(jacquard_oauth::request::keyset))] 153 Keyset, 154 155 /// Form serialization error 156 #[error("form serialization error")] 157 #[diagnostic(code(jacquard_oauth::request::serde_form))] 158 SerdeHtmlForm, 159 160 /// JSON error 161 #[error("json error")] 162 #[diagnostic(code(jacquard_oauth::request::serde_json))] 163 SerdeJson, 164 165 /// Atproto OAuth requires pushed authorization requests, but the server does not advertise PAR support. 
166 #[error("atproto OAuth requires pushed authorization requests")] 167 #[diagnostic( 168 code(jacquard_oauth::request::par_required), 169 help( 170 "use an atproto-compatible authorization server that supports pushed authorization requests" 171 ) 172 )] 173 ParRequired, 174 175 /// Atproto metadata error 176 #[error("atproto error")] 177 #[diagnostic(code(jacquard_oauth::request::atproto))] 178 Atproto, 179} 180 181impl RequestError { 182 /// Create a new error with the given kind and optional source 183 pub fn new(kind: RequestErrorKind, source: Option<BoxError>) -> Self { 184 Self { 185 kind, 186 source, 187 help: None, 188 context: None, 189 url: None, 190 details: None, 191 location: None, 192 } 193 } 194 195 /// Get the error kind 196 pub fn kind(&self) -> &RequestErrorKind { 197 &self.kind 198 } 199 200 /// Get the source error if present 201 pub fn source_err(&self) -> Option<&BoxError> { 202 self.source.as_ref() 203 } 204 205 /// Get the context string if present 206 pub fn context(&self) -> Option<&str> { 207 self.context.as_ref().map(|s| s.as_str()) 208 } 209 210 /// Get the URL if present 211 pub fn url(&self) -> Option<&str> { 212 self.url.as_ref().map(|s| s.as_str()) 213 } 214 215 /// Get the details if present 216 pub fn details(&self) -> Option<&str> { 217 self.details.as_ref().map(|s| s.as_str()) 218 } 219 220 /// Get the location if present 221 pub fn location(&self) -> Option<&str> { 222 self.location.as_ref().map(|s| s.as_str()) 223 } 224 225 /// Add help text to this error 226 pub fn with_help(mut self, help: impl Into<SmolStr>) -> Self { 227 self.help = Some(help.into()); 228 self 229 } 230 231 /// Add context to this error 232 pub fn with_context(mut self, context: impl Into<SmolStr>) -> Self { 233 self.context = Some(context.into()); 234 self 235 } 236 237 /// Add URL to this error 238 pub fn with_url(mut self, url: impl Into<SmolStr>) -> Self { 239 self.url = Some(url.into()); 240 self 241 } 242 243 /// Add details to this error 244 pub fn 
with_details(mut self, details: impl Into<SmolStr>) -> Self { 245 self.details = Some(details.into()); 246 self 247 } 248 249 /// Add location to this error 250 pub fn with_location(mut self, location: impl Into<SmolStr>) -> Self { 251 self.location = Some(location.into()); 252 self 253 } 254 255 // Constructors for each kind 256 257 /// Create a no endpoint error 258 pub fn no_endpoint(endpoint: impl Into<SmolStr>) -> Self { 259 Self::new(RequestErrorKind::NoEndpoint(endpoint.into()), None) 260 } 261 262 /// Create a token verification error 263 pub fn token_verification() -> Self { 264 Self::new(RequestErrorKind::TokenVerification, None) 265 } 266 267 /// Create an unsupported authentication method error 268 pub fn unsupported_auth_method() -> Self { 269 Self::new(RequestErrorKind::UnsupportedAuthMethod, None) 270 } 271 272 /// Create a no refresh token error 273 pub fn no_refresh_token() -> Self { 274 Self::new(RequestErrorKind::NoRefreshToken, None) 275 } 276 277 /// Create an invalid DID error 278 pub fn invalid_did(source: impl std::error::Error + Send + Sync + 'static) -> Self { 279 Self::new(RequestErrorKind::InvalidDid, Some(Box::new(source))) 280 } 281 282 /// Create a DPoP error 283 pub fn dpop(source: impl std::error::Error + Send + Sync + 'static) -> Self { 284 Self::new(RequestErrorKind::Dpop, Some(Box::new(source))) 285 } 286 287 /// Create a storage error 288 pub fn storage(source: impl std::error::Error + Send + Sync + 'static) -> Self { 289 Self::new(RequestErrorKind::Storage, Some(Box::new(source))) 290 } 291 292 /// Create a resolver error 293 pub fn resolver(source: impl std::error::Error + Send + Sync + 'static) -> Self { 294 Self::new(RequestErrorKind::Resolver, Some(Box::new(source))) 295 } 296 297 /// Create an HTTP build error 298 pub fn http_build(source: impl std::error::Error + Send + Sync + 'static) -> Self { 299 Self::new(RequestErrorKind::HttpBuild, Some(Box::new(source))) 300 } 301 302 /// Create an HTTP status error 303 pub fn 
http_status(status: StatusCode) -> Self { 304 Self::new(RequestErrorKind::HttpStatus(status), None) 305 } 306 307 /// Create an HTTP status with body error 308 pub fn http_status_with_body(status: StatusCode, body: Value) -> Self { 309 Self::new(RequestErrorKind::HttpStatusWithBody { status, body }, None) 310 } 311 312 /// Create an identity error 313 pub fn identity(source: impl std::error::Error + Send + Sync + 'static) -> Self { 314 Self::new(RequestErrorKind::Identity, Some(Box::new(source))) 315 } 316 317 /// Create a keyset error 318 pub fn keyset(source: impl std::error::Error + Send + Sync + 'static) -> Self { 319 Self::new(RequestErrorKind::Keyset, Some(Box::new(source))) 320 } 321 322 /// Create an atproto metadata error 323 pub fn atproto(source: impl std::error::Error + Send + Sync + 'static) -> Self { 324 Self::new(RequestErrorKind::Atproto, Some(Box::new(source))) 325 } 326 327 /// Create an error for atproto metadata that does not advertise required PAR support. 328 pub fn par_required() -> Self { 329 Self::new(RequestErrorKind::ParRequired, None) 330 .with_context("atproto OAuth requires pushed authorization requests") 331 .with_details("authorization server metadata did not advertise `pushed_authorization_request_endpoint` and did not set `require_pushed_authorization_requests = true`") 332 .with_help("use an atproto-compatible authorization server that supports pushed authorization requests") 333 } 334 335 /// Returns true if this error indicates permanent auth failure 336 /// (token revoked, refresh_token expired, etc.) 337 /// 338 /// When this returns true, the session should be cleared from storage 339 /// rather than retried. 340 pub fn is_permanent(&self) -> bool { 341 match &self.kind { 342 RequestErrorKind::NoRefreshToken => true, 343 RequestErrorKind::HttpStatusWithBody { body, .. 
} => body 344 .get("error") 345 .and_then(|e| e.as_str()) 346 .is_some_and(|e| matches!(e, "invalid_grant" | "access_denied")), 347 _ => false, 348 } 349 } 350} 351 352// From impls for common error types 353 354impl From<AtStrError> for RequestError { 355 fn from(e: AtStrError) -> Self { 356 let msg = smol_str::format_smolstr!("{:?}", e); 357 Self::new(RequestErrorKind::InvalidDid, Some(Box::new(e))) 358 .with_context(msg) 359 .with_help("ensure DID is correctly formatted (e.g., did:plc:abc123)") 360 } 361} 362 363impl From<crate::dpop::DpopError> for RequestError { 364 fn from(e: crate::dpop::DpopError) -> Self { 365 let msg = smol_str::format_smolstr!("{:?}", e); 366 Self::new(RequestErrorKind::Dpop, Some(Box::new(e))) 367 .with_context(msg) 368 .with_help("check DPoP key configuration and nonce handling") 369 } 370} 371 372impl From<SessionStoreError> for RequestError { 373 fn from(e: SessionStoreError) -> Self { 374 let msg = smol_str::format_smolstr!("{:?}", e); 375 Self::new(RequestErrorKind::Storage, Some(Box::new(e))) 376 .with_context(msg) 377 .with_help("verify session store is accessible and writable") 378 } 379} 380 381impl From<crate::resolver::ResolverError> for RequestError { 382 fn from(e: crate::resolver::ResolverError) -> Self { 383 let msg = smol_str::format_smolstr!("{:?}", e); 384 Self::new(RequestErrorKind::Resolver, Some(Box::new(e))) 385 .with_context(msg) 386 .with_help("check identity resolution and OAuth metadata endpoints") 387 } 388} 389 390impl From<http::Error> for RequestError { 391 fn from(e: http::Error) -> Self { 392 let msg = smol_str::format_smolstr!("{:?}", e); 393 Self::new(RequestErrorKind::HttpBuild, Some(Box::new(e))) 394 .with_context(msg) 395 .with_help("verify request URIs and headers are valid") 396 } 397} 398 399impl From<IdentityError> for RequestError { 400 fn from(e: IdentityError) -> Self { 401 let msg = smol_str::format_smolstr!("{:?}", e); 402 Self::new(RequestErrorKind::Identity, Some(Box::new(e))) 403 
.with_context(msg) 404 .with_help("check handle/DID is valid and identity resolver is configured") 405 } 406} 407 408impl From<crate::keyset::Error> for RequestError { 409 fn from(e: crate::keyset::Error) -> Self { 410 let msg = smol_str::format_smolstr!("{:?}", e); 411 Self::new(RequestErrorKind::Keyset, Some(Box::new(e))) 412 .with_context(msg) 413 .with_help("verify keyset configuration and signing algorithm support") 414 } 415} 416 417impl From<serde_html_form::ser::Error> for RequestError { 418 fn from(e: serde_html_form::ser::Error) -> Self { 419 let msg = smol_str::format_smolstr!("{:?}", e); 420 Self::new(RequestErrorKind::SerdeHtmlForm, Some(Box::new(e))) 421 .with_context(msg) 422 .with_help("check OAuth request parameters are serializable") 423 } 424} 425 426impl From<serde_json::Error> for RequestError { 427 fn from(e: serde_json::Error) -> Self { 428 let msg = smol_str::format_smolstr!("{:?}", e); 429 Self::new(RequestErrorKind::SerdeJson, Some(Box::new(e))) 430 .with_context(msg) 431 .with_help("verify OAuth response body is valid JSON") 432 } 433} 434 435impl From<crate::atproto::Error> for RequestError { 436 fn from(e: crate::atproto::Error) -> Self { 437 let msg = smol_str::format_smolstr!("{:?}", e); 438 Self::new(RequestErrorKind::Atproto, Some(Box::new(e))) 439 .with_context(msg) 440 .with_help("ensure client metadata matches atproto requirements") 441 } 442} 443 444/// Convenience `Result` type for OAuth request operations, defaulting to [`RequestError`]. 445pub type Result<T> = core::result::Result<T, RequestError>; 446 447/// Represents the different OAuth token-endpoint request types sent by this crate. 448#[allow(dead_code)] 449#[non_exhaustive] 450pub enum OAuthRequest<'a> { 451 /// Standard authorization-code token exchange. 452 Token(TokenRequestParameters<&'a str>), 453 /// Refresh-token grant to obtain a fresh access token. 454 Refresh(RefreshRequestParameters<&'a str>), 455 /// Token revocation request (RFC 7009). 
456 Revocation(RevocationRequestParameters<&'a str>), 457 /// Reserved for future atproto OAuth token introspection support. 458 /// 459 /// Atproto OAuth does not currently specify or operationally support RFC 7662 460 /// token introspection. This variant is intentionally unconstructible until 461 /// support is added to the atproto OAuth profile. 462 #[doc(hidden)] 463 Introspection(core::convert::Infallible), 464 /// Pushed authorization request (RFC 9126) for pre-registering auth parameters. 465 PushedAuthorizationRequest(ParParameters<&'a str>), 466} 467 468impl OAuthRequest<'_> { 469 /// Return a human-readable name for this request variant, used in error messages. 470 pub fn name(&self) -> &'static str { 471 match self { 472 Self::Token(_) => "token", 473 Self::Refresh(_) => "refresh", 474 Self::Revocation(_) => "revocation", 475 Self::Introspection(never) => match *never {}, 476 Self::PushedAuthorizationRequest(_) => "pushed_authorization_request", 477 } 478 } 479 /// Returns the HTTP status code that a successful response to this request should carry. 480 pub fn expected_status(&self) -> StatusCode { 481 match self { 482 Self::Token(_) | Self::Refresh(_) => StatusCode::OK, 483 Self::PushedAuthorizationRequest(_) => StatusCode::CREATED, 484 // Unlike https://datatracker.ietf.org/doc/html/rfc7009#section-2.2, oauth-provider seems to return `204`. 485 Self::Revocation(_) => StatusCode::NO_CONTENT, 486 Self::Introspection(never) => match *never {}, 487 } 488 } 489} 490 491/// The serialized body of an OAuth token-endpoint request. 492#[derive(Debug, Serialize)] 493pub struct RequestPayload<'a, T> 494where 495 T: Serialize, 496{ 497 /// The OAuth `client_id` advertised in the client metadata document. 498 client_id: CowStr<'a>, 499 /// The assertion type URI; set to `urn:ietf:params:oauth:client-assertion-type:jwt-bearer` 500 /// when using `private_key_jwt` client authentication. 
501 #[serde(skip_serializing_if = "Option::is_none")] 502 client_assertion_type: Option<CowStr<'a>>, 503 /// A JWT signed with the client's private key, proving client identity to the server. 504 #[serde(skip_serializing_if = "Option::is_none")] 505 client_assertion: Option<SmolStr>, 506 /// The grant-specific parameters (token request, refresh, PAR, etc.) flattened into the body. 507 #[serde(flatten)] 508 parameters: T, 509} 510 511/// Bundled OAuth metadata needed to perform token-endpoint operations. 512/// 513/// Aggregates the server's authorization server metadata, the client's own registered metadata, 514/// and the optional signing keyset into a single value that is passed to helper functions such 515/// as [`par`], [`exchange_code`], [`refresh`], and [`revoke`]. 516#[derive(Debug, Clone)] 517pub struct OAuthMetadata<S: BosStr = DefaultStr> { 518 /// Metadata fetched from the authorization server's `/.well-known/oauth-authorization-server` document. 519 pub server_metadata: OAuthAuthorizationServerMetadata, 520 /// This client's registered metadata, derived from [`crate::atproto::AtprotoClientMetadata`]. 521 pub client_metadata: OAuthClientMetadata<S>, 522 /// Optional signing keyset; required for `private_key_jwt` client authentication. 523 pub keyset: Option<Keyset>, 524} 525 526impl<S: BosStr> OAuthMetadata<S> { 527 /// Fetch server metadata and assemble an `OAuthMetadata` from an active session context. 528 /// 529 /// Contacts the authorization server recorded in `session_data` to retrieve its current 530 /// metadata, then combines it with the client configuration. This is the preferred way to 531 /// build an `OAuthMetadata` during token refresh or revocation. 
532 pub async fn new<T: HttpClient + OAuthResolver + Send + Sync>( 533 client: &T, 534 ClientData { keyset, config }: &ClientData<S>, 535 session_data: &ClientSessionData, 536 ) -> Result<Self> 537 where 538 S: Clone + FromStr + Ord, 539 <S as FromStr>::Err: core::fmt::Debug, 540 { 541 Ok(OAuthMetadata { 542 server_metadata: client 543 .get_authorization_server_metadata(session_data.authserver_url.as_ref()) 544 .await?, 545 client_metadata: atproto_client_metadata(&config, &keyset)?, 546 keyset: keyset.clone(), 547 }) 548 } 549} 550 551/// Perform a Pushed Authorization Request (PAR) and return the resulting state for the auth flow. 552/// 553/// Generates a PKCE code challenge, a fresh DPoP key, and a random `state` token, then POSTs 554/// them to the authorization server's PAR endpoint. The returned [`AuthRequestData`] must be 555/// persisted (e.g., in the auth store) so it can be retrieved and verified during 556/// [`crate::client::OAuthClient::callback`]. 557#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all, fields(login_hint = login_hint.as_ref().map(|h| h.as_ref()))))] 558pub async fn par< 559 S: BosStr + Clone + Send + Sync, 560 T: OAuthResolver + DpopExt + Send + Sync + 'static, 561>( 562 client: &T, 563 login_hint: Option<S>, 564 prompt: Option<AuthorizeOptionPrompt>, 565 metadata: &mut OAuthMetadata<S>, 566 state: Option<SmolStr>, 567) -> crate::request::Result<AuthRequestData> { 568 let state = state.unwrap_or_else(generate_nonce); 569 let (code_challenge, verifier) = generate_pkce(); 570 571 let Some(dpop_key) = generate_dpop_key(&mut metadata.server_metadata) else { 572 return Err(RequestError::token_verification()); 573 }; 574 let mut dpop_data = DpopReqData { 575 dpop_key, 576 dpop_authserver_nonce: None, 577 }; 578 let parameters: ParParameters<&str> = ParParameters { 579 response_type: AuthorizationResponseType::Code, 580 redirect_uri: metadata.client_metadata.redirect_uris[0].as_ref(), 581 state: state.as_ref(), 582 
scope: metadata.client_metadata.scope.as_ref().map(|s| s.as_ref()), 583 response_mode: None, 584 code_challenge: code_challenge.as_str(), 585 code_challenge_method: AuthorizationCodeChallengeMethod::S256, 586 login_hint: login_hint.as_ref().map(|h| h.as_ref()), 587 prompt: prompt.map(|p| p.into()), 588 }; 589 590 if metadata 591 .server_metadata 592 .pushed_authorization_request_endpoint 593 .is_some() 594 { 595 let par_response = oauth_request::<OAuthParResponse, T, DpopReqData, _>( 596 &client, 597 &mut dpop_data, 598 OAuthRequest::PushedAuthorizationRequest(parameters), 599 metadata, 600 ) 601 .await?; 602 603 let scopes = if let Some(scope) = &metadata.client_metadata.scope { 604 Scopes::new(scope.as_ref().to_smolstr()).expect("Failed to parse scopes") 605 } else { 606 Scopes::empty() 607 }; 608 let auth_req_data: AuthRequestData = AuthRequestData { 609 state: state.into(), 610 authserver_url: metadata.server_metadata.issuer.to_smolstr(), 611 account_did: None, 612 scopes, 613 request_uri: par_response.request_uri.clone(), 614 authserver_token_endpoint: metadata.server_metadata.token_endpoint.to_smolstr(), 615 authserver_revocation_endpoint: metadata.server_metadata.revocation_endpoint.clone(), 616 pkce_verifier: verifier.into(), 617 dpop_data, 618 }; 619 620 Ok(auth_req_data) 621 } else if metadata 622 .server_metadata 623 .require_pushed_authorization_requests 624 == Some(true) 625 { 626 Err(RequestError::no_endpoint("pushed_authorization_request")) 627 } else { 628 Err(RequestError::par_required()) 629 } 630} 631 632/// Exchange a refresh token for a fresh token set and update the session data in place. 
633#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all, fields(did = %session_data.account_did)))] 634pub async fn refresh<S, T>( 635 client: &T, 636 mut session_data: ClientSessionData, 637 metadata: &OAuthMetadata<S>, 638) -> Result<ClientSessionData> 639where 640 S: BosStr + FromStr, 641 T: OAuthResolver + DpopExt + Send + Sync + 'static, 642{ 643 let Some(refresh_token) = session_data.token_set.refresh_token.as_ref() else { 644 return Err(RequestError::no_refresh_token()); 645 }; 646 647 // /!\ IMPORTANT /!\ 648 // 649 // The "sub" MUST be a DID, whose issuer authority is indeed the server we 650 // are trying to obtain credentials from. Note that we are doing this 651 // *before* we actually try to refresh the token: 652 // 1) To avoid unnecessary refresh 653 // 2) So that the refresh is the last async operation, ensuring as few 654 // async operations happen before the result gets a chance to be stored. 655 let aud = client 656 .verify_issuer(&metadata.server_metadata, &session_data.token_set.sub) 657 .await?; 658 let iss = metadata.server_metadata.issuer.clone(); 659 660 let response = oauth_request::<OAuthTokenResponse, T, DpopClientData, _>( 661 client, 662 &mut session_data.dpop_data, 663 OAuthRequest::Refresh(RefreshRequestParameters { 664 grant_type: TokenGrantType::RefreshToken, 665 refresh_token: refresh_token.as_ref(), 666 scope: None, 667 }), 668 metadata, 669 ) 670 .await?; 671 672 let expires_at = response.expires_in.and_then(|expires_in| { 673 let now = Datetime::now(); 674 now.as_ref() 675 .checked_add_signed(TimeDelta::seconds(expires_in)) 676 .map(Datetime::new) 677 }); 678 679 session_data.update_with_tokens(&TokenSet { 680 iss, 681 sub: session_data.token_set.sub.clone(), 682 aud: aud.as_str().to_smolstr(), 683 scope: response.scope, 684 access_token: response.access_token, 685 refresh_token: response.refresh_token, 686 token_type: response.token_type, 687 expires_at, 688 }); 689 690 Ok(session_data) 691} 692 693/// 
Exchange an authorization code for a token set and return a fully-verified [`TokenSet`]. 694/// 695/// Per the AT Protocol OAuth spec, the `sub` claim in the token response **must** be verified 696/// against the expected authorization server issuer before the token can be trusted. This 697/// function performs that verification as part of the exchange, so callers receive a token 698/// set that is safe to persist. 699#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all))] 700pub async fn exchange_code<S, T, D>( 701 client: &T, 702 data_source: &mut D, 703 code: &str, 704 verifier: &str, 705 metadata: &OAuthMetadata<S>, 706) -> Result<TokenSet> 707where 708 S: BosStr + Send + Sync, 709 T: OAuthResolver + DpopExt + Send + Sync + 'static, 710 D: DpopDataSource, 711{ 712 let token_response = oauth_request::<OAuthTokenResponse, T, D, _>( 713 client, 714 data_source, 715 OAuthRequest::Token(TokenRequestParameters { 716 grant_type: TokenGrantType::AuthorizationCode, 717 code: code.into(), 718 redirect_uri: metadata.client_metadata.redirect_uris[0].as_ref(), 719 code_verifier: verifier.into(), 720 }), 721 metadata, 722 ) 723 .await?; 724 let Some(sub) = token_response.sub else { 725 return Err(RequestError::token_verification()); 726 }; 727 let sub = Did::new_owned(sub)?; 728 let iss = metadata.server_metadata.issuer.clone(); 729 // /!\ IMPORTANT /!\ 730 // 731 // The token_response MUST always be valid before the "sub" it contains 732 // can be trusted (see Atproto's OAuth spec for details). 
733 let aud = client 734 .verify_issuer(&metadata.server_metadata, &sub) 735 .await?; 736 737 let expires_at = token_response.expires_in.and_then(|expires_in| { 738 Datetime::now() 739 .as_ref() 740 .checked_add_signed(TimeDelta::seconds(expires_in)) 741 .map(Datetime::new) 742 }); 743 Ok(TokenSet { 744 iss, 745 sub, 746 aud: aud.as_str().to_smolstr(), 747 scope: token_response.scope, 748 access_token: token_response.access_token, 749 refresh_token: token_response.refresh_token, 750 token_type: token_response.token_type, 751 expires_at, 752 }) 753} 754 755/// Send a token revocation request (RFC 7009) to the authorization server. 756/// 757/// This function is called by [`crate::client::OAuthSession::logout`] when a revocation endpoint is advertised 758/// by the server. The caller is responsible for deleting the session from local storage regardless 759/// of whether revocation succeeds. 760#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all))] 761pub async fn revoke<S: BosStr + Send + Sync, T, D>( 762 client: &T, 763 data_source: &mut D, 764 token: &str, 765 metadata: &OAuthMetadata<S>, 766) -> Result<()> 767where 768 T: OAuthResolver + DpopExt + Send + Sync + 'static, 769 D: DpopDataSource, 770{ 771 oauth_request::<(), T, D, _>( 772 client, 773 data_source, 774 OAuthRequest::Revocation(RevocationRequestParameters { 775 token: token.into(), 776 }), 777 metadata, 778 ) 779 .await?; 780 Ok(()) 781} 782 783/// Low-level function for sending an OAuth token-endpoint request and deserializing the response. 784/// 785/// Selects the correct server endpoint for `request`, builds the form-encoded body with 786/// client authentication, performs the DPoP-wrapped HTTP POST, and deserializes the response 787/// body into `O`. The type parameter `O` is inferred from the call site; use `()` for requests 788/// where the response body is empty (e.g., revocation). 
789pub async fn oauth_request<'r, O, T, D, S: BosStr>( 790 client: &T, 791 data_source: &mut D, 792 request: OAuthRequest<'r>, 793 metadata: &OAuthMetadata<S>, 794) -> Result<O> 795where 796 T: OAuthResolver + DpopExt + Send + Sync + 'static, 797 O: serde::de::DeserializeOwned, 798 D: DpopDataSource, 799{ 800 let Some(url) = endpoint_for_req(&metadata.server_metadata, &request) else { 801 return Err(RequestError::no_endpoint(request.name())); 802 }; 803 let client_assertions = build_auth( 804 metadata.keyset.as_ref(), 805 &metadata.server_metadata, 806 &metadata.client_metadata, 807 )?; 808 let body = match &request { 809 OAuthRequest::Token(params) => build_oauth_req_body(client_assertions, params)?, 810 OAuthRequest::Refresh(params) => build_oauth_req_body(client_assertions, params)?, 811 OAuthRequest::Revocation(params) => build_oauth_req_body(client_assertions, params)?, 812 OAuthRequest::PushedAuthorizationRequest(params) => { 813 build_oauth_req_body(client_assertions, params)? 814 } 815 OAuthRequest::Introspection(never) => match *never {}, 816 }; 817 let req = Request::builder() 818 .uri(url) 819 .method(Method::POST) 820 .header("Content-Type", "application/x-www-form-urlencoded") 821 .body(body.into_bytes())?; 822 let res = client.dpop_server_call(data_source).send(req).await?; 823 if res.status() == request.expected_status() { 824 let body = res.body(); 825 if body.is_empty() { 826 // since an empty body cannot be deserialized, use “null” temporarily to allow deserialization to `()`. 827 Ok(serde_json::from_slice(b"null")?) 
// Remaining response handling: the `else` arm below deserializes the success body as JSON into
// `O`; 4xx responses carry a JSON body into the error; any other status becomes a bare
// HTTP-status error.
        } else {
            let output: O = serde_json::from_slice(body)?;
            Ok(output)
        }
    } else if res.status().is_client_error() {
        // NOTE(review): if a 4xx body is not valid JSON, the `?` below propagates the decode
        // error instead of the HTTP status — consider falling back to `http_status` there.
        Err(RequestError::http_status_with_body(
            res.status(),
            serde_json::from_slice(res.body())?,
        ))
    } else {
        Err(RequestError::http_status(res.status()))
    }
}

/// Look up the authorization-server endpoint URL that `request` must be sent to.
///
/// Token and refresh requests always use the token endpoint. Revocation and pushed
/// authorization requests use their optional endpoints and yield `None` when the server
/// metadata does not advertise them. The introspection variant carries an uninhabited value,
/// so that arm can never be reached.
#[inline]
fn endpoint_for_req<'r, S: BosStr>(
    server_metadata: &'r OAuthAuthorizationServerMetadata<S>,
    request: &'r OAuthRequest,
) -> Option<&'r str> {
    match request {
        OAuthRequest::Token(_) | OAuthRequest::Refresh(_) => {
            Some(server_metadata.token_endpoint.as_ref())
        }
        OAuthRequest::Revocation(_) => server_metadata
            .revocation_endpoint
            .as_ref()
            .map(AsRef::as_ref),
        OAuthRequest::Introspection(never) => match *never {},
        OAuthRequest::PushedAuthorizationRequest(_) => server_metadata
            .pushed_authorization_request_endpoint
            .as_ref()
            .map(AsRef::as_ref),
    }
}

/// Serialize an OAuth request body as an HTML form string.
///
/// The client identity from `client_assertions` (`client_id`, and the assertion type and
/// assertion when present) is combined with the request-specific `parameters` in a
/// `RequestPayload` and encoded with `serde_html_form`.
///
/// # Errors
///
/// Propagates any form-serialization error.
#[inline]
fn build_oauth_req_body<'a, S>(client_assertions: ClientAuth<'a>, parameters: S) -> Result<String>
where
    S: Serialize,
{
    Ok(serde_html_form::to_string(RequestPayload {
        client_id: client_assertions.client_id,
        client_assertion_type: client_assertions.assertion_type,
        client_assertion: client_assertions.assertion,
        parameters,
    })?)
}

/// Client identity fields included in OAuth request bodies (see `build_oauth_req_body`).
///
/// Holds the outcome of choosing a client authentication method (`none` vs.
/// `private_key_jwt`). The `build_auth` helper fills these fields based on server
/// capabilities and client configuration.
#[derive(Debug, Clone, Default)]
pub struct ClientAuth<'a> {
    /// The OAuth `client_id` for this client.
    client_id: CowStr<'a>,
    /// Either absent (for `none` auth) or `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`.
    assertion_type: Option<CowStr<'a>>,
    /// A signed JWT proving client identity; present only for `private_key_jwt` auth.
    assertion: Option<SmolStr>,
}

impl<'s> ClientAuth<'s> {
    /// Construct a `ClientAuth` with only a `client_id` and no assertion (the `none` method).
    pub fn new_id(client_id: CowStr<'s>) -> Self {
        Self {
            client_id,
            assertion_type: None,
            assertion: None,
        }
    }
}

/// Choose and build the client authentication for an OAuth request.
///
/// The method comes from the client's `token_endpoint_auth_method`. It is only used if it
/// also appears in the server's `token_endpoint_auth_methods_supported`:
///
/// * `private_key_jwt` — requires a `keyset`. Signs a client assertion JWT (RFC 7523 §3)
///   with `iss`/`sub` = client id, `aud` = server issuer, and `exp` = `iat` + 60 s. The
///   server's advertised signing algorithms are passed to `Keyset::create_jwt`, or
///   [`FALLBACK_ALG`] when the server advertises none. Names that cannot be parsed into a
///   [`Signing`] algorithm are skipped.
/// * `none` — returns only the `client_id`, with no assertion.
///
/// # Errors
///
/// Returns an "unsupported authentication method" error in these cases:
///
/// * the client has no method configured;
/// * the server does not advertise the client's method;
/// * the method is neither of the above;
/// * `private_key_jwt` is selected but no `keyset` was supplied.
///
/// Errors from `Keyset::create_jwt` are propagated.
fn build_auth<'a, S: BosStr>(
    keyset: Option<&Keyset>,
    server_metadata: &'a OAuthAuthorizationServerMetadata,
    client_metadata: &'a OAuthClientMetadata<S>,
) -> Result<ClientAuth<'a>> {
    let methods_supported = server_metadata
        .token_endpoint_auth_methods_supported
        .as_ref();
    // `Option<&Vec<_>>` is `Copy`, so this closure can be reused by several match guards.
    let server_supports =
        |method: &str| methods_supported.is_some_and(|v| v.iter().any(|s| s.as_str() == method));

    let client_id = CowStr::Borrowed(client_metadata.client_id.as_ref());
    let Some(method) = client_metadata.token_endpoint_auth_method.as_ref() else {
        return Err(RequestError::unsupported_auth_method());
    };

    match method.as_ref() {
        "private_key_jwt" if server_supports("private_key_jwt") => {
            // Without a keyset no assertion can be signed; fall through to the error below.
            if let Some(keyset) = keyset {
                let mut alg_strs: Vec<&str> = server_metadata
                    .token_endpoint_auth_signing_alg_values_supported
                    .as_ref()
                    .map(|v| v.iter().map(|s| s.as_ref()).collect())
                    .unwrap_or_default();
                if alg_strs.is_empty() {
                    alg_strs.push(FALLBACK_ALG);
                }
                let algs: Vec<Signing> = alg_strs
                    .iter()
                    .filter_map(|s| crate::keyset::parse_signing_alg(s))
                    .collect();
                let iat = Utc::now().timestamp();
                let client_id_str: &str = client_metadata.client_id.as_ref();
                let issuer_str: &str = server_metadata.issuer.as_ref();
                // https://datatracker.ietf.org/doc/html/rfc7523#section-3
                let claims = RegisteredClaims {
                    iss: Some(client_id_str),
                    sub: Some(client_id_str),
                    aud: Some(RegisteredClaimsAud::Single(issuer_str)),
                    // Short-lived assertion: valid for 60 seconds after issuance.
                    exp: Some(iat + 60),
                    iat: Some(iat),
                    // atproto oauth-provider requires "jti" to be present
                    jti: Some(generate_nonce()),
                    ..Default::default()
                };
                let assertion = keyset.create_jwt(&algs, claims.into())?;
                return Ok(ClientAuth {
                    client_id,
                    assertion_type: Some(CowStr::new_static(CLIENT_ASSERTION_TYPE_JWT_BEARER)),
                    assertion: Some(assertion),
                });
            }
        }
        "none" if server_supports("none") => return Ok(ClientAuth::new_id(client_id)),
        _ => {}
    }

    Err(RequestError::unsupported_auth_method())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{OAuthAuthorizationServerMetadata, OAuthClientMetadata};
    use bytes::Bytes;
    use http::{Response as HttpResponse, StatusCode};
    use jacquard_common::{
        bos::BosStr, deps::fluent_uri::Uri, http_client::HttpClient, types::string::Did,
    };
    use jacquard_identity::resolver::IdentityResolver;
    use smol_str::SmolStr;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// HTTP client that replays a single pre-queued response.
    #[derive(Clone, Default)]
    struct MockClient {
        resp: Arc<Mutex<Option<HttpResponse<Vec<u8>>>>>,
    }

    impl HttpClient for MockClient {
        type Error = std::convert::Infallible;
        fn send_http(
            &self,
            _request: http::Request<Vec<u8>>,
        ) -> impl core::future::Future<
            Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>,
        > + Send {
            let resp = self.resp.clone();
            async move {
                Ok(resp
                    .lock()
                    .await
                    .take()
                    .expect("MockClient: test sent a request but no response was queued"))
            }
        }
    }

    // IdentityResolver methods won't be called in these tests; provide stubs.
    impl IdentityResolver for MockClient {
        fn options(&self) -> &jacquard_identity::resolver::ResolverOptions {
            use std::sync::LazyLock;
            static OPTS: LazyLock<jacquard_identity::resolver::ResolverOptions> =
                LazyLock::new(jacquard_identity::resolver::ResolverOptions::default);
            &OPTS
        }
        async fn resolve_handle<S: BosStr + Sync>(
            &self,
            _handle: &jacquard_common::types::string::Handle<S>,
        ) -> std::result::Result<Did, jacquard_identity::resolver::IdentityError> {
            Ok(Did::new_static("did:plc:alice").unwrap())
        }
        async fn resolve_did_doc<S: BosStr + Sync>(
            &self,
            _did: &Did<S>,
        ) -> std::result::Result<
            jacquard_identity::resolver::DidDocResponse,
            jacquard_identity::resolver::IdentityError,
        > {
            let doc = serde_json::json!({
                "id": "did:plc:alice",
                "service": [{
                    "id": "#pds",
                    "type": "AtprotoPersonalDataServer",
                    "serviceEndpoint": "https://pds"
                }]
            });
            let buf = Bytes::from(serde_json::to_vec(&doc).unwrap());
            Ok(jacquard_identity::resolver::DidDocResponse {
                buffer: buf,
                status: StatusCode::OK,
                requested: None,
            })
        }
    }

    // Allow using DPoP helpers on MockClient
    impl crate::dpop::DpopExt for MockClient {}
    impl crate::resolver::OAuthResolver for MockClient {}

    /// Public-client metadata: the client uses `none` auth and the server advertises only `none`.
    fn base_metadata() -> OAuthMetadata {
        let mut server = OAuthAuthorizationServerMetadata::default();
        server.issuer = SmolStr::new_static("https://issuer");
        server.authorization_endpoint = SmolStr::new_static("https://issuer/authorize");
        server.token_endpoint = SmolStr::new_static("https://issuer/token");
        server.token_endpoint_auth_methods_supported = Some(vec![SmolStr::new_static("none")]);
        OAuthMetadata {
            server_metadata: server,
            client_metadata: OAuthClientMetadata {
                client_id: SmolStr::new_static("https://client"),
                client_uri: None,
                redirect_uris: vec![SmolStr::new_static("https://client/cb")],
                scope: Some(SmolStr::new_static("atproto")),
                grant_types: None,
                response_types: vec![SmolStr::new_static("code")],
                application_type: Some(SmolStr::new_static("web")),
                token_endpoint_auth_method: Some(SmolStr::new_static("none")),
                dpop_bound_access_tokens: None,
                jwks_uri: None,
                jwks: None,
                token_endpoint_auth_signing_alg: None,
                client_name: None,
                privacy_policy_uri: None,
                tos_uri: None,
                logo_uri: None,
            },
            keyset: None,
        }
    }

    #[test]
    fn build_auth_none_method_has_no_assertion() {
        let meta = base_metadata();
        let auth =
            super::build_auth(None, &meta.server_metadata, &meta.client_metadata).unwrap();
        assert!(auth.assertion_type.is_none());
        assert!(auth.assertion.is_none());
    }

    #[test]
    fn build_auth_rejects_method_not_advertised_by_server() {
        let mut meta = base_metadata();
        meta.server_metadata.token_endpoint_auth_methods_supported = None;
        let err = super::build_auth(None, &meta.server_metadata, &meta.client_metadata)
            .unwrap_err();
        assert!(matches!(err.kind(), RequestErrorKind::UnsupportedAuthMethod));
    }

    #[test]
    fn build_auth_private_key_jwt_without_keyset_is_unsupported() {
        let mut meta = base_metadata();
        meta.client_metadata.token_endpoint_auth_method =
            Some(SmolStr::new_static("private_key_jwt"));
        meta.server_metadata.token_endpoint_auth_methods_supported =
            Some(vec![SmolStr::new_static("private_key_jwt")]);
        let err = super::build_auth(None, &meta.server_metadata, &meta.client_metadata)
            .unwrap_err();
        assert!(matches!(err.kind(), RequestErrorKind::UnsupportedAuthMethod));
    }

    #[tokio::test]
    async fn par_missing_endpoint() {
        let mut meta = base_metadata();
        meta.server_metadata.require_pushed_authorization_requests = Some(true);
        meta.server_metadata.pushed_authorization_request_endpoint = None;
        // require_pushed_authorization_requests is true and no endpoint.
        let err = super::par(&MockClient::default(), None, None, &mut meta, None)
            .await
            .unwrap_err();
        assert!(
            matches!(err.kind(), RequestErrorKind::NoEndpoint(name) if name == "pushed_authorization_request")
        );
    }

    #[tokio::test]
    async fn par_not_advertised_returns_structured_error() {
        let mut meta = base_metadata();
        meta.server_metadata.require_pushed_authorization_requests = Some(false);
        meta.server_metadata.pushed_authorization_request_endpoint = None;
        let err = super::par(&MockClient::default(), None, None, &mut meta, None)
            .await
            .unwrap_err();

        assert!(matches!(err.kind(), RequestErrorKind::ParRequired));
        assert_eq!(
            err.context(),
            Some("atproto OAuth requires pushed authorization requests")
        );
        assert!(
            err.details()
                .is_some_and(|details| details.contains("pushed_authorization_request_endpoint"))
        );
    }

    #[tokio::test]
    async fn refresh_no_refresh_token() {
        let client = MockClient::default();
        let meta = base_metadata();
        let session = ClientSessionData {
            account_did: Did::new_static("did:plc:alice").unwrap(),
            session_id: SmolStr::new_static("state"),
            host_url: Uri::parse("https://pds").expect("valid").to_owned(),
            authserver_url: SmolStr::new_static("https://issuer"),
            authserver_token_endpoint: SmolStr::new_static("https://issuer/token"),
            authserver_revocation_endpoint: None,
            scopes: Scopes::empty(),
            dpop_data: DpopClientData {
                dpop_key: crate::utils::generate_key(&[SmolStr::new_static("ES256")]).unwrap(),
                dpop_authserver_nonce: SmolStr::default(),
                dpop_host_nonce: SmolStr::default(),
            },
            token_set: crate::types::TokenSet {
                iss: SmolStr::new_static("https://issuer"),
                sub: Did::new_static("did:plc:alice").unwrap(),
                aud: SmolStr::new_static("https://pds"),
                scope: None,
                refresh_token: None,
                access_token: SmolStr::new_static("abc"),
                token_type: crate::types::OAuthTokenType::DPoP,
                expires_at: None,
            },
            resolved_scopes: None,
        };
        let err = super::refresh(&client, session, &meta).await.unwrap_err();
        assert!(matches!(err.kind(), RequestErrorKind::NoRefreshToken));
    }

    #[tokio::test]
    async fn exchange_code_missing_sub() {
        let client = MockClient::default();
        // set mock HTTP response body: token response without `sub`
        *client.resp.lock().await = Some(
            HttpResponse::builder()
                .status(StatusCode::OK)
                .body(
                    serde_json::to_vec(&serde_json::json!({
                        "access_token":"tok",
                        "token_type":"DPoP",
                        "expires_in": 3600
                    }))
                    .unwrap(),
                )
                .unwrap(),
        );
        let meta = base_metadata();
        let mut dpop = DpopReqData {
            dpop_key: crate::utils::generate_key(&[SmolStr::new_static("ES256")]).unwrap(),
            dpop_authserver_nonce: None,
        };
        let err = super::exchange_code(&client, &mut dpop, "abc", "verifier", &meta)
            .await
            .unwrap_err();
        assert!(matches!(err.kind(), RequestErrorKind::TokenVerification));
    }
}