A better Rust ATProto crate
1

Configure Feed

Select the types of activity you want to include in your feed.

at main 29 kB View raw
1//! Service authentication extractor and middleware. 2//! 3//! Service auth verifies AT Protocol inter-service JWTs. Normal 4//! [`ServiceAuthConfig::new`] configurations require `lxm` method binding and, 5//! when the `service-auth-replay` feature is enabled, reject missing or replayed 6//! `jti` values by default. Use [`ServiceAuthConfig::disable_replay_protection`] 7//! only for legacy compatibility. 8//! 9//! Global service-id allow-lists constrain present `aud` fragments but do not 10//! require a fragment. Use [`require_service_id`] as a route layer for endpoints 11//! that require a specific `did:web:example.com#service_id` audience fragment. 12//! 13//! [`ExtractOptionalServiceAuth`] treats only an absent Authorization header as 14//! anonymous. Present malformed, invalid, or replayed credentials are rejected. 15//! 16//! The default replay store is in-memory and per process. Horizontally scaled 17//! deployments should provide a shared [`ReplayStore`] implementation. 18//! Legacy configs created with [`ServiceAuthConfig::new_legacy`] disable `lxm` 19//! and replay requirements. 20//! 21//! # Example 22//! 23//! ```no_run 24//! use axum::{Router, routing::get}; 25//! use jacquard_axum::service_auth::{ServiceAuthConfig, ExtractServiceAuth}; 26//! use jacquard_identity::JacquardResolver; 27//! use jacquard_identity::resolver::ResolverOptions; 28//! use jacquard_common::types::string::Did; 29//! 30//! async fn handler( 31//! ExtractServiceAuth(auth): ExtractServiceAuth, 32//! ) -> String { 33//! format!("Authenticated as {}", auth.did()) 34//! } 35//! 36//! #[tokio::main] 37//! async fn main() { 38//! let resolver = JacquardResolver::new( 39//! reqwest::Client::new(), 40//! ResolverOptions::default(), 41//! ); 42//! let config = ServiceAuthConfig::new( 43//! Did::new_static("did:web:feedgen.example.com").unwrap(), 44//! resolver, 45//! ); 46//! 47//! let app = Router::new() 48//! .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler)) 49//! .with_state(config); 50//! 51//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") 52//! .await 53//! .unwrap(); 54//! axum::serve(listener, app).await.unwrap(); 55//! } 56//! ``` 57 58use axum::{ 59 Extension, Json, 60 extract::FromRequestParts, 61 http::{HeaderValue, StatusCode, header, request::Parts}, 62 middleware::Next, 63 response::{IntoResponse, Response}, 64}; 65use jacquard_common::deps::smol_str::SmolStr; 66use jacquard_common::{ 67 CowStr, IntoStatic, 68 service_auth::{self, PublicKey}, 69 types::{ 70 did_doc::VerificationMethod, 71 string::{Did, DidService, Nsid}, 72 }, 73}; 74use jacquard_identity::resolver::IdentityResolver; 75use serde_json::json; 76use std::future::Future; 77use std::pin::Pin; 78use std::sync::{Arc, Mutex}; 79use thiserror::Error; 80 81/// Replay key for service auth JWT `jti` replay protection. 82#[derive(Debug, Clone, PartialEq, Eq, Hash)] 83pub struct ReplayKey { 84 /// Issuer DID from the JWT. 85 iss: Did, 86 /// Full audience, including any service-id fragment. 87 aud: DidService, 88 /// JWT ID nonce. 89 jti: SmolStr, 90} 91 92impl ReplayKey { 93 /// Create a new replay key. 94 pub fn new(iss: Did, aud: DidService, jti: impl Into<SmolStr>) -> Self { 95 Self { 96 iss, 97 aud, 98 jti: jti.into(), 99 } 100 } 101} 102 103/// Errors returned by replay stores. 104#[derive(Debug, Error)] 105#[non_exhaustive] 106pub enum ReplayStoreError { 107 /// The replay key has already been presented and has not expired. 108 #[error("service auth JWT replay detected")] 109 Replayed, 110 111 /// The replay store failed. 112 #[error("replay store failed: {0}")] 113 Store(String), 114} 115 116/// Store used to reject replayed service auth JWT IDs. 117pub trait ReplayStore: Send + Sync + 'static { 118 /// Check whether `key` has been seen, and record it until `expires_at`. 119 fn check_and_insert( 120 &self, 121 key: ReplayKey, 122 expires_at: i64, 123 ) -> Pin<Box<dyn Future<Output = Result<(), ReplayStoreError>> + Send + '_>>; 124} 125 126/// Replay store that disables replay protection. 127#[derive(Debug, Clone, Copy, Default)] 128pub struct NoopReplayStore; 129 130impl ReplayStore for NoopReplayStore { 131 fn check_and_insert( 132 &self, 133 _key: ReplayKey, 134 _expires_at: i64, 135 ) -> Pin<Box<dyn Future<Output = Result<(), ReplayStoreError>> + Send + '_>> { 136 Box::pin(async { Ok(()) }) 137 } 138} 139 140/// Default in-memory replay store. 141#[cfg(feature = "service-auth-replay")] 142#[derive(Debug, Clone)] 143pub struct InMemoryReplayStore { 144 cache: mini_moka::sync::Cache<ReplayKey, i64>, 145 lock: Arc<Mutex<()>>, 146} 147 148#[cfg(feature = "service-auth-replay")] 149impl Default for InMemoryReplayStore { 150 fn default() -> Self { 151 Self::new(100_000) 152 } 153} 154 155#[cfg(feature = "service-auth-replay")] 156impl InMemoryReplayStore { 157 /// Create an in-memory replay store with a maximum key capacity. 158 pub fn new(max_capacity: u64) -> Self { 159 Self { 160 cache: mini_moka::sync::Cache::new(max_capacity), 161 lock: Arc::new(Mutex::new(())), 162 } 163 } 164} 165 166#[cfg(feature = "service-auth-replay")] 167impl ReplayStore for InMemoryReplayStore { 168 fn check_and_insert( 169 &self, 170 key: ReplayKey, 171 expires_at: i64, 172 ) -> Pin<Box<dyn Future<Output = Result<(), ReplayStoreError>> + Send + '_>> { 173 Box::pin(async move { 174 let _guard = self 175 .lock 176 .lock() 177 .map_err(|_| ReplayStoreError::Store("replay store lock poisoned".to_string()))?; 178 let now = chrono::Utc::now().timestamp(); 179 if let Some(existing_expires_at) = self.cache.get(&key) { 180 if existing_expires_at > now { 181 return Err(ReplayStoreError::Replayed); 182 } 183 self.cache.invalidate(&key); 184 } 185 self.cache.insert(key, expires_at); 186 Ok(()) 187 }) 188 } 189} 190 191/// Trait for providing service authentication configuration. 192/// 193/// This trait allows custom state types to provide service auth configuration 194/// without requiring `ServiceAuthConfig<R>` directly. 195pub trait ServiceAuth { 196 /// The identity resolver type 197 type Resolver: IdentityResolver; 198 199 /// Get the service DID (expected audience) 200 fn service_did(&self) -> Did<&str>; 201 202 /// Get a reference to the identity resolver 203 fn resolver(&self) -> &Self::Resolver; 204 205 /// Whether to require the `lxm` (method binding) field. 206 fn require_lxm(&self) -> bool; 207 208 /// Service-id fragments allowed by global validation. 209 fn allowed_services(&self) -> &[SmolStr]; 210 211 /// Whether replay protection is enabled. 212 fn replay_protection_enabled(&self) -> bool; 213 214 /// Replay store used by replay protection. 215 fn replay_store(&self) -> &dyn ReplayStore; 216} 217 218/// Configuration for service auth verification. 219/// 220/// This should be stored in your Axum app state and will be extracted 221/// by the `ExtractServiceAuth` extractor. 222pub struct ServiceAuthConfig<R> { 223 /// The DID of your service (the expected audience). 224 service_did: Did, 225 /// Identity resolver for fetching DID documents. 226 resolver: Arc<R>, 227 /// Whether to require the `lxm` (method binding) field. 228 require_lxm: bool, 229 /// Globally allowed service-id fragments. 230 allowed_services: Vec<SmolStr>, 231 /// Replay store used when replay protection is enabled. 232 replay_store: Arc<dyn ReplayStore>, 233 /// Whether replay protection is enabled. 234 replay_protection_enabled: bool, 235} 236 237impl<R> Clone for ServiceAuthConfig<R> { 238 fn clone(&self) -> Self { 239 Self { 240 service_did: self.service_did.clone(), 241 resolver: Arc::clone(&self.resolver), 242 require_lxm: self.require_lxm, 243 allowed_services: self.allowed_services.clone(), 244 replay_store: Arc::clone(&self.replay_store), 245 replay_protection_enabled: self.replay_protection_enabled, 246 } 247 } 248} 249 250fn default_replay_store() -> (Arc<dyn ReplayStore>, bool) { 251 #[cfg(feature = "service-auth-replay")] 252 { 253 (Arc::new(InMemoryReplayStore::default()), true) 254 } 255 #[cfg(not(feature = "service-auth-replay"))] 256 { 257 (Arc::new(NoopReplayStore), false) 258 } 259} 260 261impl<R: IdentityResolver> ServiceAuthConfig<R> { 262 /// Create a new service auth config. 263 /// 264 /// This enables `lxm` (method binding). If you need backward compatibility, 265 /// use `ServiceAuthConfig::new_legacy()` 266 pub fn new(service_did: Did, resolver: R) -> Self { 267 let (replay_store, replay_protection_enabled) = default_replay_store(); 268 Self { 269 service_did, 270 resolver: Arc::new(resolver), 271 require_lxm: true, 272 allowed_services: Vec::new(), 273 replay_store, 274 replay_protection_enabled, 275 } 276 } 277 278 /// Create a new service auth config. 279 /// 280 /// `lxm` (method binding) is disabled for backwards compatibility 281 pub fn new_legacy(service_did: Did, resolver: R) -> Self { 282 Self { 283 service_did, 284 resolver: Arc::new(resolver), 285 require_lxm: false, 286 allowed_services: Vec::new(), 287 replay_store: Arc::new(NoopReplayStore), 288 replay_protection_enabled: false, 289 } 290 } 291 292 /// Set whether to require the `lxm` field (method binding). 293 /// 294 /// When enabled, the JWT must contain an `lxm` field matching the requested endpoint. 295 /// This prevents token reuse across different methods. 296 pub fn require_lxm(mut self, require: bool) -> Self { 297 self.require_lxm = require; 298 self 299 } 300 301 /// Replace the global allowed service-id fragments. 302 pub fn with_allowed_services<I, Svc>(mut self, services: I) -> Self 303 where 304 I: IntoIterator<Item = Svc>, 305 Svc: Into<SmolStr>, 306 { 307 self.allowed_services = services.into_iter().map(Into::into).collect(); 308 self 309 } 310 311 /// Add a single global allowed service-id fragment. 312 pub fn allow_service(mut self, service: impl Into<SmolStr>) -> Self { 313 self.allowed_services.push(service.into()); 314 self 315 } 316 317 /// Replace the replay store and enable replay protection. 318 pub fn with_replay_store(mut self, store: impl ReplayStore) -> Self { 319 self.replay_store = Arc::new(store); 320 self.replay_protection_enabled = true; 321 self 322 } 323 324 /// Disable replay protection for legacy compatibility. 325 pub fn disable_replay_protection(mut self) -> Self { 326 self.replay_store = Arc::new(NoopReplayStore); 327 self.replay_protection_enabled = false; 328 self 329 } 330 331 /// Get the globally allowed service-id fragments. 332 pub fn allowed_services(&self) -> &[SmolStr] { 333 &self.allowed_services 334 } 335 336 /// Get the service DID. 337 pub fn service_did(&self) -> Did<&str> { 338 self.service_did.borrow() 339 } 340 341 /// Get a reference to the identity resolver. 342 pub fn resolver(&self) -> &R { 343 &self.resolver 344 } 345} 346 347impl<R: IdentityResolver> ServiceAuth for ServiceAuthConfig<R> { 348 type Resolver = R; 349 350 fn service_did(&self) -> Did<&str> { 351 self.service_did.borrow() 352 } 353 354 fn resolver(&self) -> &Self::Resolver { 355 &self.resolver 356 } 357 358 fn require_lxm(&self) -> bool { 359 self.require_lxm 360 } 361 362 fn allowed_services(&self) -> &[SmolStr] { 363 &self.allowed_services 364 } 365 366 fn replay_protection_enabled(&self) -> bool { 367 self.replay_protection_enabled 368 } 369 370 fn replay_store(&self) -> &dyn ReplayStore { 371 self.replay_store.as_ref() 372 } 373} 374 375/// Route-scoped service auth policy. 376#[derive(Debug, Clone, Default)] 377pub struct ServiceAuthRoutePolicy { 378 /// Required service-id fragment for this route. 379 required_service_id: Option<SmolStr>, 380} 381 382impl ServiceAuthRoutePolicy { 383 /// Require a specific service-id fragment for this route. 384 pub fn require_service_id(service_id: impl Into<SmolStr>) -> Self { 385 Self { 386 required_service_id: Some(service_id.into()), 387 } 388 } 389 390 /// Get the required service-id fragment. 391 pub fn required_service_id(&self) -> Option<&str> { 392 self.required_service_id.as_deref() 393 } 394} 395 396/// Create an axum route layer that requires a specific service-id fragment. 397pub fn require_service_id(service_id: impl Into<SmolStr>) -> Extension<ServiceAuthRoutePolicy> { 398 Extension(ServiceAuthRoutePolicy::require_service_id(service_id)) 399} 400 401/// Verified service authentication information. 402/// 403/// This is the result of successfully verifying a service auth JWT. 404/// This type is extracted by the `ExtractServiceAuth` extractor. 405#[derive(Debug, Clone, jacquard_derive::IntoStatic)] 406pub struct VerifiedServiceAuth<'a> { 407 /// The authenticated user's DID (from `iss` claim) 408 did: Did, 409 /// The audience (should match your service DID, with optional service fragment). 410 aud: DidService, 411 /// The lexicon method NSID, if present 412 lxm: Option<Nsid>, 413 /// JWT ID (nonce), if present 414 jti: Option<CowStr<'a>>, 415} 416 417impl<'a> VerifiedServiceAuth<'a> { 418 /// Get the authenticated user's DID. 419 pub fn did(&self) -> Did<&str> { 420 self.did.borrow() 421 } 422 423 /// Get the full audience, including any service-id fragment. 424 pub fn aud(&self) -> DidService<&str> { 425 self.aud.borrow() 426 } 427 428 /// Get the fragmentless service DID audience. 429 pub fn audience(&self) -> Did<&str> { 430 self.aud.audience() 431 } 432 433 /// Get the optional service-id fragment. 434 pub fn service(&self) -> Option<&str> { 435 self.aud.service() 436 } 437 438 /// Get the lexicon method NSID, if present. 439 pub fn lxm(&self) -> Option<Nsid<&str>> { 440 self.lxm.as_ref().map(|l| l.borrow()) 441 } 442 443 /// Get the JWT ID (nonce), if present. 444 /// 445 /// You can use this for replay protection by tracking seen JTIs 446 /// until their expiration time. 447 pub fn jti(&self) -> Option<&str> { 448 self.jti.as_ref().map(|j| j.as_ref()) 449 } 450} 451 452/// Axum extractor for service authentication. 453/// 454/// This extracts and verifies a service auth JWT from the Authorization header, 455/// resolving the issuer's DID to verify the signature. 456/// 457/// # Example 458/// 459/// ```no_run 460/// use axum::{Router, routing::get}; 461/// use jacquard_axum::service_auth::{ServiceAuthConfig, ExtractServiceAuth}; 462/// use jacquard_identity::JacquardResolver; 463/// use jacquard_identity::resolver::ResolverOptions; 464/// use jacquard_common::types::string::Did; 465/// 466/// async fn handler( 467/// ExtractServiceAuth(auth): ExtractServiceAuth, 468/// ) -> String { 469/// format!("Authenticated as {}", auth.did()) 470/// } 471/// 472/// #[tokio::main] 473/// async fn main() { 474/// let resolver = JacquardResolver::new( 475/// reqwest::Client::new(), 476/// ResolverOptions::default(), 477/// ); 478/// let config = ServiceAuthConfig::new( 479/// Did::new_static("did:web:feedgen.example.com").unwrap(), 480/// resolver, 481/// ); 482/// 483/// let app = Router::new() 484/// .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler)) 485/// .with_state(config); 486/// 487/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") 488/// .await 489/// .unwrap(); 490/// axum::serve(listener, app).await.unwrap(); 491/// } 492/// ``` 493pub struct ExtractServiceAuth(pub VerifiedServiceAuth<'static>); 494 495/// Axum extractor for optional service authentication. 496/// 497/// Like `ExtractServiceAuth`, but returns `None` if no Authorization header 498/// is present. If a header IS present but invalid, returns an error. 499/// 500/// Use this for endpoints that work for both authenticated and anonymous users, 501/// but show different content based on auth status. 502/// 503/// # Example 504/// 505/// ```no_run 506/// use axum::{Router, routing::get}; 507/// use jacquard_axum::service_auth::{ServiceAuthConfig, ExtractOptionalServiceAuth}; 508/// use jacquard_identity::JacquardResolver; 509/// use jacquard_identity::resolver::ResolverOptions; 510/// use jacquard_common::types::string::Did; 511/// 512/// async fn handler( 513/// ExtractOptionalServiceAuth(auth): ExtractOptionalServiceAuth, 514/// ) -> String { 515/// match auth { 516/// Some(a) => format!("Authenticated as {}", a.did()), 517/// None => "Anonymous request".to_string(), 518/// } 519/// } 520/// 521/// #[tokio::main] 522/// async fn main() { 523/// let resolver = JacquardResolver::new( 524/// reqwest::Client::new(), 525/// ResolverOptions::default(), 526/// ); 527/// let config = ServiceAuthConfig::new( 528/// Did::new_static("did:web:example.com").unwrap(), 529/// resolver, 530/// ); 531/// 532/// let app = Router::new() 533/// .route("/xrpc/com.example.getData", get(handler)) 534/// .with_state(config); 535/// 536/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") 537/// .await 538/// .unwrap(); 539/// axum::serve(listener, app).await.unwrap(); 540/// } 541/// ``` 542pub struct ExtractOptionalServiceAuth(pub Option<VerifiedServiceAuth<'static>>); 543 544/// Errors that can occur during service auth verification. 545#[derive(Debug, Error, miette::Diagnostic)] 546#[non_exhaustive] 547pub enum ServiceAuthError { 548 /// Authorization header is missing 549 #[error("missing Authorization header")] 550 MissingAuthHeader, 551 552 /// Authorization header is malformed (not "Bearer `token`") 553 #[error("invalid Authorization header format")] 554 InvalidAuthHeader, 555 556 /// JWT parsing or verification failed 557 #[error("JWT verification failed: {0}")] 558 JwtError(#[from] service_auth::ServiceAuthError), 559 560 /// DID resolution failed 561 #[error("failed to resolve DID {did}: {source}")] 562 DidResolutionFailed { 563 did: Did, 564 #[source] 565 source: Box<dyn std::error::Error + Send + Sync>, 566 }, 567 568 /// No valid signing key found in DID document 569 #[error("no valid signing key found in DID document for {0}")] 570 NoSigningKey(Did), 571 572 /// Method binding required but missing 573 #[error("lxm (method binding) is required but missing from token")] 574 MethodBindingRequired, 575 576 /// Invalid key format 577 #[error("invalid key format: {0}")] 578 InvalidKey(String), 579 580 /// Service-id fragment is required for this route. 581 #[error("service id {required} is required but missing from token audience")] 582 ServiceIdRequired { 583 /// Required service-id fragment. 584 required: SmolStr, 585 }, 586 587 /// Service-id fragment does not match this route. 588 #[error("service id mismatch: required {required}, got {actual}")] 589 RouteServiceIdMismatch { 590 /// Required service-id fragment. 591 required: SmolStr, 592 /// Actual service-id fragment. 593 actual: SmolStr, 594 }, 595 596 /// Replay protection is enabled but the token has no `jti`. 597 #[error("service auth JWT is missing required jti")] 598 MissingJti, 599 600 /// Replay protection rejected the token. 601 #[error("replay protection failed: {0}")] 602 Replay(#[from] ReplayStoreError), 603} 604 605impl IntoResponse for ServiceAuthError { 606 fn into_response(self) -> Response { 607 let (status, error_code, message) = match &self { 608 ServiceAuthError::MissingAuthHeader => { 609 (StatusCode::UNAUTHORIZED, "AuthMissing", self.to_string()) 610 } 611 ServiceAuthError::InvalidAuthHeader => { 612 (StatusCode::UNAUTHORIZED, "AuthMissing", self.to_string()) 613 } 614 ServiceAuthError::JwtError(_) => ( 615 StatusCode::UNAUTHORIZED, 616 "AuthenticationRequired", 617 self.to_string(), 618 ), 619 ServiceAuthError::DidResolutionFailed { .. } => ( 620 StatusCode::UNAUTHORIZED, 621 "AuthenticationRequired", 622 self.to_string(), 623 ), 624 ServiceAuthError::NoSigningKey(_) => ( 625 StatusCode::UNAUTHORIZED, 626 "AuthenticationRequired", 627 self.to_string(), 628 ), 629 ServiceAuthError::MethodBindingRequired => ( 630 StatusCode::UNAUTHORIZED, 631 "AuthenticationRequired", 632 self.to_string(), 633 ), 634 ServiceAuthError::InvalidKey(_) 635 | ServiceAuthError::ServiceIdRequired { .. } 636 | ServiceAuthError::RouteServiceIdMismatch { .. } 637 | ServiceAuthError::MissingJti 638 | ServiceAuthError::Replay(_) => ( 639 StatusCode::UNAUTHORIZED, 640 "AuthenticationRequired", 641 self.to_string(), 642 ), 643 }; 644 645 tracing::warn!("Service auth failed: {}", message); 646 647 ( 648 status, 649 [( 650 header::CONTENT_TYPE, 651 HeaderValue::from_static("application/json"), 652 )], 653 Json(json!({ 654 "error": error_code, 655 "message": message, 656 })), 657 ) 658 .into_response() 659 } 660} 661 662fn owned_did<S: jacquard_common::BosStr>(did: &Did<S>) -> Did { 663 Did::new_owned(did.as_str()).unwrap() 664} 665 666fn bearer_token_from_parts(parts: &Parts) -> Result<Option<&str>, ServiceAuthError> { 667 let Some(auth_header) = parts.headers.get(header::AUTHORIZATION) else { 668 return Ok(None); 669 }; 670 671 let auth_str = auth_header 672 .to_str() 673 .map_err(|_| ServiceAuthError::InvalidAuthHeader)?; 674 let token = auth_str 675 .strip_prefix("Bearer ") 676 .ok_or(ServiceAuthError::InvalidAuthHeader)?; 677 Ok(Some(token)) 678} 679 680async fn verify_service_auth<S>( 681 parts: &Parts, 682 state: &S, 683 token: &str, 684) -> Result<VerifiedServiceAuth<'static>, ServiceAuthError> 685where 686 S: ServiceAuth + Send + Sync, 687 S::Resolver: Send + Sync, 688{ 689 let parsed = service_auth::parse_jwt(token)?; 690 let claims = parsed.claims(); 691 692 let did_doc = state 693 .resolver() 694 .resolve_did_doc(&claims.iss) 695 .await 696 .map_err(|e| ServiceAuthError::DidResolutionFailed { 697 did: owned_did(&claims.iss), 698 source: Box::new(e), 699 })?; 700 701 let doc = did_doc 702 .parse() 703 .map_err(|e| ServiceAuthError::DidResolutionFailed { 704 did: owned_did(&claims.iss), 705 source: Box::new(e), 706 })?; 707 708 let verification_methods = doc 709 .verification_method 710 .as_deref() 711 .ok_or_else(|| ServiceAuthError::NoSigningKey(owned_did(&claims.iss)))?; 712 713 let signing_key = extract_signing_key(verification_methods) 714 .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?; 715 716 service_auth::verify_signature(&parsed, &signing_key)?; 717 claims.validate(&state.service_did(), state.allowed_services())?; 718 719 if state.require_lxm() && claims.lxm.is_none() { 720 return Err(ServiceAuthError::MethodBindingRequired); 721 } 722 723 if let Some(policy) = parts.extensions.get::<ServiceAuthRoutePolicy>() { 724 if let Some(required) = policy.required_service_id() { 725 match claims.aud.service() { 726 Some(actual) if actual == required => {} 727 Some(actual) => { 728 return Err(ServiceAuthError::RouteServiceIdMismatch { 729 required: SmolStr::new(required), 730 actual: SmolStr::new(actual), 731 }); 732 } 733 None => { 734 return Err(ServiceAuthError::ServiceIdRequired { 735 required: SmolStr::new(required), 736 }); 737 } 738 } 739 } 740 } 741 742 if state.replay_protection_enabled() { 743 let jti = claims.jti.as_ref().ok_or(ServiceAuthError::MissingJti)?; 744 let key = ReplayKey::new( 745 claims.iss.clone().into_static(), 746 claims.aud.clone().into_static(), 747 jti.clone(), 748 ); 749 state 750 .replay_store() 751 .check_and_insert(key, claims.exp) 752 .await?; 753 } 754 755 Ok(VerifiedServiceAuth { 756 did: claims.iss.clone().into_static(), 757 aud: claims.aud.clone().into_static(), 758 lxm: claims.lxm.as_ref().map(|l| l.clone().into_static()), 759 jti: claims.jti.as_ref().map(|j| CowStr::from(j.clone())), 760 }) 761} 762 763impl<S> FromRequestParts<S> for ExtractServiceAuth 764where 765 S: ServiceAuth + Send + Sync, 766 S::Resolver: Send + Sync, 767{ 768 type Rejection = ServiceAuthError; 769 770 fn from_request_parts( 771 parts: &mut Parts, 772 state: &S, 773 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send { 774 async move { 775 let token = 776 bearer_token_from_parts(parts)?.ok_or(ServiceAuthError::MissingAuthHeader)?; 777 verify_service_auth(parts, state, token).await.map(Self) 778 } 779 } 780} 781 782impl<S> FromRequestParts<S> for ExtractOptionalServiceAuth 783where 784 S: ServiceAuth + Send + Sync, 785 S::Resolver: Send + Sync, 786{ 787 type Rejection = ServiceAuthError; 788 789 fn from_request_parts( 790 parts: &mut Parts, 791 state: &S, 792 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send { 793 async move { 794 let Some(token) = bearer_token_from_parts(parts)? else { 795 return Ok(Self(None)); 796 }; 797 verify_service_auth(parts, state, token) 798 .await 799 .map(|auth| Self(Some(auth))) 800 } 801 } 802} 803 804/// Extract the signing key from a DID document's verification methods. 805/// 806/// This looks for a key with type "atproto" or the first available key 807/// if no atproto-specific key is found. 808fn extract_signing_key(methods: &[VerificationMethod<CowStr<'_>>]) -> Option<PublicKey> { 809 // First try to find an atproto-specific key 810 let atproto_method = methods 811 .iter() 812 .find(|m| m.r#type.as_ref() == "Multikey" || m.r#type.as_ref() == "atproto"); 813 814 let method = atproto_method.or_else(|| methods.first())?; 815 816 // Parse the multikey 817 let public_key_multibase = method.public_key_multibase.as_ref()?; 818 819 // Decode multibase 820 let (_, key_bytes) = multibase::decode(public_key_multibase.as_ref()).ok()?; 821 822 // First two bytes are the multicodec prefix 823 if key_bytes.len() < 2 { 824 return None; 825 } 826 827 let codec = &key_bytes[..2]; 828 let key_material = &key_bytes[2..]; 829 830 match codec { 831 // p256-pub (0x1200) 832 [0x80, 0x24] => PublicKey::from_p256_bytes(key_material) 833 .inspect_err(|_e| { 834 #[cfg(feature = "tracing")] 835 tracing::error!("Failed to parse p256 public key: {}", _e); 836 }) 837 .ok(), 838 // secp256k1-pub (0xe7) 839 [0xe7, 0x01] => PublicKey::from_k256_bytes(key_material) 840 .inspect_err(|_e| { 841 #[cfg(feature = "tracing")] 842 tracing::error!("Failed to parse secp256k1 public key: {}", _e); 843 }) 844 .ok(), 845 _ => { 846 #[cfg(feature = "tracing")] 847 tracing::error!("Unsupported public key multicodec: {:?}", codec); 848 None 849 } 850 } 851} 852 853/// Middleware for verifying service authentication on all requests. 854/// 855/// This middleware extracts and verifies the service auth JWT, then adds the 856/// `VerifiedServiceAuth` to request extensions for downstream handlers to access. 857/// 858/// # Example 859/// 860/// ```no_run 861/// use axum::{Router, routing::get, middleware, Extension}; 862/// use jacquard_axum::service_auth::{ServiceAuthConfig, service_auth_middleware}; 863/// use jacquard_identity::{PublicResolver, JacquardResolver}; 864/// use jacquard_identity::resolver::ResolverOptions; 865/// use jacquard_common::types::string::Did; 866/// 867/// async fn handler( 868/// Extension(auth): Extension<jacquard_axum::service_auth::VerifiedServiceAuth<'static>>, 869/// ) -> String { 870/// format!("Authenticated as {}", auth.did()) 871/// } 872/// 873/// #[tokio::main] 874/// async fn main() { 875/// let resolver = JacquardResolver::new( 876/// reqwest::Client::new(), 877/// ResolverOptions::default(), 878/// ); 879/// let config = ServiceAuthConfig::new( 880/// Did::new_static("did:web:feedgen.example.com").unwrap(), 881/// resolver, 882/// ); 883/// 884/// let app = Router::new() 885/// .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler)) 886/// .layer(middleware::from_fn_with_state( 887/// config.clone(), 888/// service_auth_middleware::<ServiceAuthConfig<PublicResolver>>, 889/// )) 890/// .with_state(config); 891/// 892/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") 893/// .await 894/// .unwrap(); 895/// axum::serve(listener, app).await.unwrap(); 896/// } 897/// ``` 898pub async fn service_auth_middleware<S>( 899 state: axum::extract::State<S>, 900 mut req: axum::extract::Request, 901 next: Next, 902) -> Result<Response, ServiceAuthError> 903where 904 S: ServiceAuth + Send + Sync + Clone, 905 S::Resolver: Send + Sync, 906{ 907 // Extract auth from request parts 908 let (mut parts, body) = req.into_parts(); 909 let ExtractServiceAuth(auth) = 910 ExtractServiceAuth::from_request_parts(&mut parts, &state.0).await?; 911 912 // Add auth to extensions 913 parts.extensions.insert(auth); 914 915 // Reconstruct request and continue 916 req = axum::extract::Request::from_parts(parts, body); 917 Ok(next.run(req).await) 918}