A better Rust ATProto crate
1//! Service authentication extractor and middleware.
2//!
//! Service auth verifies AT Protocol inter-service JWTs. Configurations built
//! with [`ServiceAuthConfig::new`] require an `lxm` (method binding) claim to be
//! present and, when the `service-auth-replay` feature is enabled, reject missing
//! or replayed `jti` values by default. The verifier only checks that `lxm` is
//! present; handlers should compare [`VerifiedServiceAuth::lxm`] with the method
//! they implement. Use [`ServiceAuthConfig::disable_replay_protection`] only for
//! legacy compatibility.
8//!
9//! Global service-id allow-lists constrain present `aud` fragments but do not
10//! require a fragment. Use [`require_service_id`] as a route layer for endpoints
11//! that require a specific `did:web:example.com#service_id` audience fragment.
12//!
13//! [`ExtractOptionalServiceAuth`] treats only an absent Authorization header as
14//! anonymous. Present malformed, invalid, or replayed credentials are rejected.
15//!
//! The default replay store is in-memory and per process. Horizontally scaled
//! deployments should provide a shared [`ReplayStore`] implementation via
//! [`ServiceAuthConfig::with_replay_store`].
//!
//! Legacy configs created with [`ServiceAuthConfig::new_legacy`] require neither
//! `lxm` nor `jti`.
20//!
21//! # Example
22//!
23//! ```no_run
24//! use axum::{Router, routing::get};
25//! use jacquard_axum::service_auth::{ServiceAuthConfig, ExtractServiceAuth};
26//! use jacquard_identity::JacquardResolver;
27//! use jacquard_identity::resolver::ResolverOptions;
28//! use jacquard_common::types::string::Did;
29//!
30//! async fn handler(
31//! ExtractServiceAuth(auth): ExtractServiceAuth,
32//! ) -> String {
33//! format!("Authenticated as {}", auth.did())
34//! }
35//!
36//! #[tokio::main]
37//! async fn main() {
38//! let resolver = JacquardResolver::new(
39//! reqwest::Client::new(),
40//! ResolverOptions::default(),
41//! );
42//! let config = ServiceAuthConfig::new(
43//! Did::new_static("did:web:feedgen.example.com").unwrap(),
44//! resolver,
45//! );
46//!
47//! let app = Router::new()
48//! .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler))
49//! .with_state(config);
50//!
51//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
52//! .await
53//! .unwrap();
54//! axum::serve(listener, app).await.unwrap();
55//! }
56//! ```
57
58use axum::{
59 Extension, Json,
60 extract::FromRequestParts,
61 http::{HeaderValue, StatusCode, header, request::Parts},
62 middleware::Next,
63 response::{IntoResponse, Response},
64};
65use jacquard_common::deps::smol_str::SmolStr;
66use jacquard_common::{
67 CowStr, IntoStatic,
68 service_auth::{self, PublicKey},
69 types::{
70 did_doc::VerificationMethod,
71 string::{Did, DidService, Nsid},
72 },
73};
74use jacquard_identity::resolver::IdentityResolver;
75use serde_json::json;
76use std::future::Future;
77use std::pin::Pin;
78use std::sync::{Arc, Mutex};
79use thiserror::Error;
80
81/// Replay key for service auth JWT `jti` replay protection.
82#[derive(Debug, Clone, PartialEq, Eq, Hash)]
83pub struct ReplayKey {
84 /// Issuer DID from the JWT.
85 iss: Did,
86 /// Full audience, including any service-id fragment.
87 aud: DidService,
88 /// JWT ID nonce.
89 jti: SmolStr,
90}
91
92impl ReplayKey {
93 /// Create a new replay key.
94 pub fn new(iss: Did, aud: DidService, jti: impl Into<SmolStr>) -> Self {
95 Self {
96 iss,
97 aud,
98 jti: jti.into(),
99 }
100 }
101}
102
103/// Errors returned by replay stores.
104#[derive(Debug, Error)]
105#[non_exhaustive]
106pub enum ReplayStoreError {
107 /// The replay key has already been presented and has not expired.
108 #[error("service auth JWT replay detected")]
109 Replayed,
110
111 /// The replay store failed.
112 #[error("replay store failed: {0}")]
113 Store(String),
114}
115
116/// Store used to reject replayed service auth JWT IDs.
117pub trait ReplayStore: Send + Sync + 'static {
118 /// Check whether `key` has been seen, and record it until `expires_at`.
119 fn check_and_insert(
120 &self,
121 key: ReplayKey,
122 expires_at: i64,
123 ) -> Pin<Box<dyn Future<Output = Result<(), ReplayStoreError>> + Send + '_>>;
124}
125
126/// Replay store that disables replay protection.
127#[derive(Debug, Clone, Copy, Default)]
128pub struct NoopReplayStore;
129
130impl ReplayStore for NoopReplayStore {
131 fn check_and_insert(
132 &self,
133 _key: ReplayKey,
134 _expires_at: i64,
135 ) -> Pin<Box<dyn Future<Output = Result<(), ReplayStoreError>> + Send + '_>> {
136 Box::pin(async { Ok(()) })
137 }
138}
139
140/// Default in-memory replay store.
141#[cfg(feature = "service-auth-replay")]
142#[derive(Debug, Clone)]
143pub struct InMemoryReplayStore {
144 cache: mini_moka::sync::Cache<ReplayKey, i64>,
145 lock: Arc<Mutex<()>>,
146}
147
148#[cfg(feature = "service-auth-replay")]
149impl Default for InMemoryReplayStore {
150 fn default() -> Self {
151 Self::new(100_000)
152 }
153}
154
155#[cfg(feature = "service-auth-replay")]
156impl InMemoryReplayStore {
157 /// Create an in-memory replay store with a maximum key capacity.
158 pub fn new(max_capacity: u64) -> Self {
159 Self {
160 cache: mini_moka::sync::Cache::new(max_capacity),
161 lock: Arc::new(Mutex::new(())),
162 }
163 }
164}
165
166#[cfg(feature = "service-auth-replay")]
167impl ReplayStore for InMemoryReplayStore {
168 fn check_and_insert(
169 &self,
170 key: ReplayKey,
171 expires_at: i64,
172 ) -> Pin<Box<dyn Future<Output = Result<(), ReplayStoreError>> + Send + '_>> {
173 Box::pin(async move {
174 let _guard = self
175 .lock
176 .lock()
177 .map_err(|_| ReplayStoreError::Store("replay store lock poisoned".to_string()))?;
178 let now = chrono::Utc::now().timestamp();
179 if let Some(existing_expires_at) = self.cache.get(&key) {
180 if existing_expires_at > now {
181 return Err(ReplayStoreError::Replayed);
182 }
183 self.cache.invalidate(&key);
184 }
185 self.cache.insert(key, expires_at);
186 Ok(())
187 })
188 }
189}
190
191/// Trait for providing service authentication configuration.
192///
193/// This trait allows custom state types to provide service auth configuration
194/// without requiring `ServiceAuthConfig<R>` directly.
195pub trait ServiceAuth {
196 /// The identity resolver type
197 type Resolver: IdentityResolver;
198
199 /// Get the service DID (expected audience)
200 fn service_did(&self) -> Did<&str>;
201
202 /// Get a reference to the identity resolver
203 fn resolver(&self) -> &Self::Resolver;
204
205 /// Whether to require the `lxm` (method binding) field.
206 fn require_lxm(&self) -> bool;
207
208 /// Service-id fragments allowed by global validation.
209 fn allowed_services(&self) -> &[SmolStr];
210
211 /// Whether replay protection is enabled.
212 fn replay_protection_enabled(&self) -> bool;
213
214 /// Replay store used by replay protection.
215 fn replay_store(&self) -> &dyn ReplayStore;
216}
217
218/// Configuration for service auth verification.
219///
220/// This should be stored in your Axum app state and will be extracted
221/// by the `ExtractServiceAuth` extractor.
222pub struct ServiceAuthConfig<R> {
223 /// The DID of your service (the expected audience).
224 service_did: Did,
225 /// Identity resolver for fetching DID documents.
226 resolver: Arc<R>,
227 /// Whether to require the `lxm` (method binding) field.
228 require_lxm: bool,
229 /// Globally allowed service-id fragments.
230 allowed_services: Vec<SmolStr>,
231 /// Replay store used when replay protection is enabled.
232 replay_store: Arc<dyn ReplayStore>,
233 /// Whether replay protection is enabled.
234 replay_protection_enabled: bool,
235}
236
237impl<R> Clone for ServiceAuthConfig<R> {
238 fn clone(&self) -> Self {
239 Self {
240 service_did: self.service_did.clone(),
241 resolver: Arc::clone(&self.resolver),
242 require_lxm: self.require_lxm,
243 allowed_services: self.allowed_services.clone(),
244 replay_store: Arc::clone(&self.replay_store),
245 replay_protection_enabled: self.replay_protection_enabled,
246 }
247 }
248}
249
250fn default_replay_store() -> (Arc<dyn ReplayStore>, bool) {
251 #[cfg(feature = "service-auth-replay")]
252 {
253 (Arc::new(InMemoryReplayStore::default()), true)
254 }
255 #[cfg(not(feature = "service-auth-replay"))]
256 {
257 (Arc::new(NoopReplayStore), false)
258 }
259}
260
261impl<R: IdentityResolver> ServiceAuthConfig<R> {
262 /// Create a new service auth config.
263 ///
264 /// This enables `lxm` (method binding). If you need backward compatibility,
265 /// use `ServiceAuthConfig::new_legacy()`
266 pub fn new(service_did: Did, resolver: R) -> Self {
267 let (replay_store, replay_protection_enabled) = default_replay_store();
268 Self {
269 service_did,
270 resolver: Arc::new(resolver),
271 require_lxm: true,
272 allowed_services: Vec::new(),
273 replay_store,
274 replay_protection_enabled,
275 }
276 }
277
278 /// Create a new service auth config.
279 ///
280 /// `lxm` (method binding) is disabled for backwards compatibility
281 pub fn new_legacy(service_did: Did, resolver: R) -> Self {
282 Self {
283 service_did,
284 resolver: Arc::new(resolver),
285 require_lxm: false,
286 allowed_services: Vec::new(),
287 replay_store: Arc::new(NoopReplayStore),
288 replay_protection_enabled: false,
289 }
290 }
291
292 /// Set whether to require the `lxm` field (method binding).
293 ///
294 /// When enabled, the JWT must contain an `lxm` field matching the requested endpoint.
295 /// This prevents token reuse across different methods.
296 pub fn require_lxm(mut self, require: bool) -> Self {
297 self.require_lxm = require;
298 self
299 }
300
301 /// Replace the global allowed service-id fragments.
302 pub fn with_allowed_services<I, Svc>(mut self, services: I) -> Self
303 where
304 I: IntoIterator<Item = Svc>,
305 Svc: Into<SmolStr>,
306 {
307 self.allowed_services = services.into_iter().map(Into::into).collect();
308 self
309 }
310
311 /// Add a single global allowed service-id fragment.
312 pub fn allow_service(mut self, service: impl Into<SmolStr>) -> Self {
313 self.allowed_services.push(service.into());
314 self
315 }
316
317 /// Replace the replay store and enable replay protection.
318 pub fn with_replay_store(mut self, store: impl ReplayStore) -> Self {
319 self.replay_store = Arc::new(store);
320 self.replay_protection_enabled = true;
321 self
322 }
323
324 /// Disable replay protection for legacy compatibility.
325 pub fn disable_replay_protection(mut self) -> Self {
326 self.replay_store = Arc::new(NoopReplayStore);
327 self.replay_protection_enabled = false;
328 self
329 }
330
331 /// Get the globally allowed service-id fragments.
332 pub fn allowed_services(&self) -> &[SmolStr] {
333 &self.allowed_services
334 }
335
336 /// Get the service DID.
337 pub fn service_did(&self) -> Did<&str> {
338 self.service_did.borrow()
339 }
340
341 /// Get a reference to the identity resolver.
342 pub fn resolver(&self) -> &R {
343 &self.resolver
344 }
345}
346
347impl<R: IdentityResolver> ServiceAuth for ServiceAuthConfig<R> {
348 type Resolver = R;
349
350 fn service_did(&self) -> Did<&str> {
351 self.service_did.borrow()
352 }
353
354 fn resolver(&self) -> &Self::Resolver {
355 &self.resolver
356 }
357
358 fn require_lxm(&self) -> bool {
359 self.require_lxm
360 }
361
362 fn allowed_services(&self) -> &[SmolStr] {
363 &self.allowed_services
364 }
365
366 fn replay_protection_enabled(&self) -> bool {
367 self.replay_protection_enabled
368 }
369
370 fn replay_store(&self) -> &dyn ReplayStore {
371 self.replay_store.as_ref()
372 }
373}
374
375/// Route-scoped service auth policy.
376#[derive(Debug, Clone, Default)]
377pub struct ServiceAuthRoutePolicy {
378 /// Required service-id fragment for this route.
379 required_service_id: Option<SmolStr>,
380}
381
382impl ServiceAuthRoutePolicy {
383 /// Require a specific service-id fragment for this route.
384 pub fn require_service_id(service_id: impl Into<SmolStr>) -> Self {
385 Self {
386 required_service_id: Some(service_id.into()),
387 }
388 }
389
390 /// Get the required service-id fragment.
391 pub fn required_service_id(&self) -> Option<&str> {
392 self.required_service_id.as_deref()
393 }
394}
395
396/// Create an axum route layer that requires a specific service-id fragment.
397pub fn require_service_id(service_id: impl Into<SmolStr>) -> Extension<ServiceAuthRoutePolicy> {
398 Extension(ServiceAuthRoutePolicy::require_service_id(service_id))
399}
400
401/// Verified service authentication information.
402///
403/// This is the result of successfully verifying a service auth JWT.
404/// This type is extracted by the `ExtractServiceAuth` extractor.
405#[derive(Debug, Clone, jacquard_derive::IntoStatic)]
406pub struct VerifiedServiceAuth<'a> {
407 /// The authenticated user's DID (from `iss` claim)
408 did: Did,
409 /// The audience (should match your service DID, with optional service fragment).
410 aud: DidService,
411 /// The lexicon method NSID, if present
412 lxm: Option<Nsid>,
413 /// JWT ID (nonce), if present
414 jti: Option<CowStr<'a>>,
415}
416
417impl<'a> VerifiedServiceAuth<'a> {
418 /// Get the authenticated user's DID.
419 pub fn did(&self) -> Did<&str> {
420 self.did.borrow()
421 }
422
423 /// Get the full audience, including any service-id fragment.
424 pub fn aud(&self) -> DidService<&str> {
425 self.aud.borrow()
426 }
427
428 /// Get the fragmentless service DID audience.
429 pub fn audience(&self) -> Did<&str> {
430 self.aud.audience()
431 }
432
433 /// Get the optional service-id fragment.
434 pub fn service(&self) -> Option<&str> {
435 self.aud.service()
436 }
437
438 /// Get the lexicon method NSID, if present.
439 pub fn lxm(&self) -> Option<Nsid<&str>> {
440 self.lxm.as_ref().map(|l| l.borrow())
441 }
442
443 /// Get the JWT ID (nonce), if present.
444 ///
445 /// You can use this for replay protection by tracking seen JTIs
446 /// until their expiration time.
447 pub fn jti(&self) -> Option<&str> {
448 self.jti.as_ref().map(|j| j.as_ref())
449 }
450}
451
452/// Axum extractor for service authentication.
453///
454/// This extracts and verifies a service auth JWT from the Authorization header,
455/// resolving the issuer's DID to verify the signature.
456///
457/// # Example
458///
459/// ```no_run
460/// use axum::{Router, routing::get};
461/// use jacquard_axum::service_auth::{ServiceAuthConfig, ExtractServiceAuth};
462/// use jacquard_identity::JacquardResolver;
463/// use jacquard_identity::resolver::ResolverOptions;
464/// use jacquard_common::types::string::Did;
465///
466/// async fn handler(
467/// ExtractServiceAuth(auth): ExtractServiceAuth,
468/// ) -> String {
469/// format!("Authenticated as {}", auth.did())
470/// }
471///
472/// #[tokio::main]
473/// async fn main() {
474/// let resolver = JacquardResolver::new(
475/// reqwest::Client::new(),
476/// ResolverOptions::default(),
477/// );
478/// let config = ServiceAuthConfig::new(
479/// Did::new_static("did:web:feedgen.example.com").unwrap(),
480/// resolver,
481/// );
482///
483/// let app = Router::new()
484/// .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler))
485/// .with_state(config);
486///
487/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
488/// .await
489/// .unwrap();
490/// axum::serve(listener, app).await.unwrap();
491/// }
492/// ```
493pub struct ExtractServiceAuth(pub VerifiedServiceAuth<'static>);
494
495/// Axum extractor for optional service authentication.
496///
497/// Like `ExtractServiceAuth`, but returns `None` if no Authorization header
498/// is present. If a header IS present but invalid, returns an error.
499///
500/// Use this for endpoints that work for both authenticated and anonymous users,
501/// but show different content based on auth status.
502///
503/// # Example
504///
505/// ```no_run
506/// use axum::{Router, routing::get};
507/// use jacquard_axum::service_auth::{ServiceAuthConfig, ExtractOptionalServiceAuth};
508/// use jacquard_identity::JacquardResolver;
509/// use jacquard_identity::resolver::ResolverOptions;
510/// use jacquard_common::types::string::Did;
511///
512/// async fn handler(
513/// ExtractOptionalServiceAuth(auth): ExtractOptionalServiceAuth,
514/// ) -> String {
515/// match auth {
516/// Some(a) => format!("Authenticated as {}", a.did()),
517/// None => "Anonymous request".to_string(),
518/// }
519/// }
520///
521/// #[tokio::main]
522/// async fn main() {
523/// let resolver = JacquardResolver::new(
524/// reqwest::Client::new(),
525/// ResolverOptions::default(),
526/// );
527/// let config = ServiceAuthConfig::new(
528/// Did::new_static("did:web:example.com").unwrap(),
529/// resolver,
530/// );
531///
532/// let app = Router::new()
533/// .route("/xrpc/com.example.getData", get(handler))
534/// .with_state(config);
535///
536/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
537/// .await
538/// .unwrap();
539/// axum::serve(listener, app).await.unwrap();
540/// }
541/// ```
542pub struct ExtractOptionalServiceAuth(pub Option<VerifiedServiceAuth<'static>>);
543
544/// Errors that can occur during service auth verification.
545#[derive(Debug, Error, miette::Diagnostic)]
546#[non_exhaustive]
547pub enum ServiceAuthError {
548 /// Authorization header is missing
549 #[error("missing Authorization header")]
550 MissingAuthHeader,
551
552 /// Authorization header is malformed (not "Bearer `token`")
553 #[error("invalid Authorization header format")]
554 InvalidAuthHeader,
555
556 /// JWT parsing or verification failed
557 #[error("JWT verification failed: {0}")]
558 JwtError(#[from] service_auth::ServiceAuthError),
559
560 /// DID resolution failed
561 #[error("failed to resolve DID {did}: {source}")]
562 DidResolutionFailed {
563 did: Did,
564 #[source]
565 source: Box<dyn std::error::Error + Send + Sync>,
566 },
567
568 /// No valid signing key found in DID document
569 #[error("no valid signing key found in DID document for {0}")]
570 NoSigningKey(Did),
571
572 /// Method binding required but missing
573 #[error("lxm (method binding) is required but missing from token")]
574 MethodBindingRequired,
575
576 /// Invalid key format
577 #[error("invalid key format: {0}")]
578 InvalidKey(String),
579
580 /// Service-id fragment is required for this route.
581 #[error("service id {required} is required but missing from token audience")]
582 ServiceIdRequired {
583 /// Required service-id fragment.
584 required: SmolStr,
585 },
586
587 /// Service-id fragment does not match this route.
588 #[error("service id mismatch: required {required}, got {actual}")]
589 RouteServiceIdMismatch {
590 /// Required service-id fragment.
591 required: SmolStr,
592 /// Actual service-id fragment.
593 actual: SmolStr,
594 },
595
596 /// Replay protection is enabled but the token has no `jti`.
597 #[error("service auth JWT is missing required jti")]
598 MissingJti,
599
600 /// Replay protection rejected the token.
601 #[error("replay protection failed: {0}")]
602 Replay(#[from] ReplayStoreError),
603}
604
605impl IntoResponse for ServiceAuthError {
606 fn into_response(self) -> Response {
607 let (status, error_code, message) = match &self {
608 ServiceAuthError::MissingAuthHeader => {
609 (StatusCode::UNAUTHORIZED, "AuthMissing", self.to_string())
610 }
611 ServiceAuthError::InvalidAuthHeader => {
612 (StatusCode::UNAUTHORIZED, "AuthMissing", self.to_string())
613 }
614 ServiceAuthError::JwtError(_) => (
615 StatusCode::UNAUTHORIZED,
616 "AuthenticationRequired",
617 self.to_string(),
618 ),
619 ServiceAuthError::DidResolutionFailed { .. } => (
620 StatusCode::UNAUTHORIZED,
621 "AuthenticationRequired",
622 self.to_string(),
623 ),
624 ServiceAuthError::NoSigningKey(_) => (
625 StatusCode::UNAUTHORIZED,
626 "AuthenticationRequired",
627 self.to_string(),
628 ),
629 ServiceAuthError::MethodBindingRequired => (
630 StatusCode::UNAUTHORIZED,
631 "AuthenticationRequired",
632 self.to_string(),
633 ),
634 ServiceAuthError::InvalidKey(_)
635 | ServiceAuthError::ServiceIdRequired { .. }
636 | ServiceAuthError::RouteServiceIdMismatch { .. }
637 | ServiceAuthError::MissingJti
638 | ServiceAuthError::Replay(_) => (
639 StatusCode::UNAUTHORIZED,
640 "AuthenticationRequired",
641 self.to_string(),
642 ),
643 };
644
645 tracing::warn!("Service auth failed: {}", message);
646
647 (
648 status,
649 [(
650 header::CONTENT_TYPE,
651 HeaderValue::from_static("application/json"),
652 )],
653 Json(json!({
654 "error": error_code,
655 "message": message,
656 })),
657 )
658 .into_response()
659 }
660}
661
662fn owned_did<S: jacquard_common::BosStr>(did: &Did<S>) -> Did {
663 Did::new_owned(did.as_str()).unwrap()
664}
665
666fn bearer_token_from_parts(parts: &Parts) -> Result<Option<&str>, ServiceAuthError> {
667 let Some(auth_header) = parts.headers.get(header::AUTHORIZATION) else {
668 return Ok(None);
669 };
670
671 let auth_str = auth_header
672 .to_str()
673 .map_err(|_| ServiceAuthError::InvalidAuthHeader)?;
674 let token = auth_str
675 .strip_prefix("Bearer ")
676 .ok_or(ServiceAuthError::InvalidAuthHeader)?;
677 Ok(Some(token))
678}
679
680async fn verify_service_auth<S>(
681 parts: &Parts,
682 state: &S,
683 token: &str,
684) -> Result<VerifiedServiceAuth<'static>, ServiceAuthError>
685where
686 S: ServiceAuth + Send + Sync,
687 S::Resolver: Send + Sync,
688{
689 let parsed = service_auth::parse_jwt(token)?;
690 let claims = parsed.claims();
691
692 let did_doc = state
693 .resolver()
694 .resolve_did_doc(&claims.iss)
695 .await
696 .map_err(|e| ServiceAuthError::DidResolutionFailed {
697 did: owned_did(&claims.iss),
698 source: Box::new(e),
699 })?;
700
701 let doc = did_doc
702 .parse()
703 .map_err(|e| ServiceAuthError::DidResolutionFailed {
704 did: owned_did(&claims.iss),
705 source: Box::new(e),
706 })?;
707
708 let verification_methods = doc
709 .verification_method
710 .as_deref()
711 .ok_or_else(|| ServiceAuthError::NoSigningKey(owned_did(&claims.iss)))?;
712
713 let signing_key = extract_signing_key(verification_methods)
714 .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?;
715
716 service_auth::verify_signature(&parsed, &signing_key)?;
717 claims.validate(&state.service_did(), state.allowed_services())?;
718
719 if state.require_lxm() && claims.lxm.is_none() {
720 return Err(ServiceAuthError::MethodBindingRequired);
721 }
722
723 if let Some(policy) = parts.extensions.get::<ServiceAuthRoutePolicy>() {
724 if let Some(required) = policy.required_service_id() {
725 match claims.aud.service() {
726 Some(actual) if actual == required => {}
727 Some(actual) => {
728 return Err(ServiceAuthError::RouteServiceIdMismatch {
729 required: SmolStr::new(required),
730 actual: SmolStr::new(actual),
731 });
732 }
733 None => {
734 return Err(ServiceAuthError::ServiceIdRequired {
735 required: SmolStr::new(required),
736 });
737 }
738 }
739 }
740 }
741
742 if state.replay_protection_enabled() {
743 let jti = claims.jti.as_ref().ok_or(ServiceAuthError::MissingJti)?;
744 let key = ReplayKey::new(
745 claims.iss.clone().into_static(),
746 claims.aud.clone().into_static(),
747 jti.clone(),
748 );
749 state
750 .replay_store()
751 .check_and_insert(key, claims.exp)
752 .await?;
753 }
754
755 Ok(VerifiedServiceAuth {
756 did: claims.iss.clone().into_static(),
757 aud: claims.aud.clone().into_static(),
758 lxm: claims.lxm.as_ref().map(|l| l.clone().into_static()),
759 jti: claims.jti.as_ref().map(|j| CowStr::from(j.clone())),
760 })
761}
762
763impl<S> FromRequestParts<S> for ExtractServiceAuth
764where
765 S: ServiceAuth + Send + Sync,
766 S::Resolver: Send + Sync,
767{
768 type Rejection = ServiceAuthError;
769
770 fn from_request_parts(
771 parts: &mut Parts,
772 state: &S,
773 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
774 async move {
775 let token =
776 bearer_token_from_parts(parts)?.ok_or(ServiceAuthError::MissingAuthHeader)?;
777 verify_service_auth(parts, state, token).await.map(Self)
778 }
779 }
780}
781
782impl<S> FromRequestParts<S> for ExtractOptionalServiceAuth
783where
784 S: ServiceAuth + Send + Sync,
785 S::Resolver: Send + Sync,
786{
787 type Rejection = ServiceAuthError;
788
789 fn from_request_parts(
790 parts: &mut Parts,
791 state: &S,
792 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
793 async move {
794 let Some(token) = bearer_token_from_parts(parts)? else {
795 return Ok(Self(None));
796 };
797 verify_service_auth(parts, state, token)
798 .await
799 .map(|auth| Self(Some(auth)))
800 }
801 }
802}
803
804/// Extract the signing key from a DID document's verification methods.
805///
806/// This looks for a key with type "atproto" or the first available key
807/// if no atproto-specific key is found.
808fn extract_signing_key(methods: &[VerificationMethod<CowStr<'_>>]) -> Option<PublicKey> {
809 // First try to find an atproto-specific key
810 let atproto_method = methods
811 .iter()
812 .find(|m| m.r#type.as_ref() == "Multikey" || m.r#type.as_ref() == "atproto");
813
814 let method = atproto_method.or_else(|| methods.first())?;
815
816 // Parse the multikey
817 let public_key_multibase = method.public_key_multibase.as_ref()?;
818
819 // Decode multibase
820 let (_, key_bytes) = multibase::decode(public_key_multibase.as_ref()).ok()?;
821
822 // First two bytes are the multicodec prefix
823 if key_bytes.len() < 2 {
824 return None;
825 }
826
827 let codec = &key_bytes[..2];
828 let key_material = &key_bytes[2..];
829
830 match codec {
831 // p256-pub (0x1200)
832 [0x80, 0x24] => PublicKey::from_p256_bytes(key_material)
833 .inspect_err(|_e| {
834 #[cfg(feature = "tracing")]
835 tracing::error!("Failed to parse p256 public key: {}", _e);
836 })
837 .ok(),
838 // secp256k1-pub (0xe7)
839 [0xe7, 0x01] => PublicKey::from_k256_bytes(key_material)
840 .inspect_err(|_e| {
841 #[cfg(feature = "tracing")]
842 tracing::error!("Failed to parse secp256k1 public key: {}", _e);
843 })
844 .ok(),
845 _ => {
846 #[cfg(feature = "tracing")]
847 tracing::error!("Unsupported public key multicodec: {:?}", codec);
848 None
849 }
850 }
851}
852
853/// Middleware for verifying service authentication on all requests.
854///
855/// This middleware extracts and verifies the service auth JWT, then adds the
856/// `VerifiedServiceAuth` to request extensions for downstream handlers to access.
857///
858/// # Example
859///
860/// ```no_run
861/// use axum::{Router, routing::get, middleware, Extension};
862/// use jacquard_axum::service_auth::{ServiceAuthConfig, service_auth_middleware};
863/// use jacquard_identity::{PublicResolver, JacquardResolver};
864/// use jacquard_identity::resolver::ResolverOptions;
865/// use jacquard_common::types::string::Did;
866///
867/// async fn handler(
868/// Extension(auth): Extension<jacquard_axum::service_auth::VerifiedServiceAuth<'static>>,
869/// ) -> String {
870/// format!("Authenticated as {}", auth.did())
871/// }
872///
873/// #[tokio::main]
874/// async fn main() {
875/// let resolver = JacquardResolver::new(
876/// reqwest::Client::new(),
877/// ResolverOptions::default(),
878/// );
879/// let config = ServiceAuthConfig::new(
880/// Did::new_static("did:web:feedgen.example.com").unwrap(),
881/// resolver,
882/// );
883///
884/// let app = Router::new()
885/// .route("/xrpc/app.bsky.feed.getFeedSkeleton", get(handler))
886/// .layer(middleware::from_fn_with_state(
887/// config.clone(),
888/// service_auth_middleware::<ServiceAuthConfig<PublicResolver>>,
889/// ))
890/// .with_state(config);
891///
892/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
893/// .await
894/// .unwrap();
895/// axum::serve(listener, app).await.unwrap();
896/// }
897/// ```
898pub async fn service_auth_middleware<S>(
899 state: axum::extract::State<S>,
900 mut req: axum::extract::Request,
901 next: Next,
902) -> Result<Response, ServiceAuthError>
903where
904 S: ServiceAuth + Send + Sync + Clone,
905 S::Resolver: Send + Sync,
906{
907 // Extract auth from request parts
908 let (mut parts, body) = req.into_parts();
909 let ExtractServiceAuth(auth) =
910 ExtractServiceAuth::from_request_parts(&mut parts, &state.0).await?;
911
912 // Add auth to extensions
913 parts.extensions.insert(auth);
914
915 // Reconstruct request and continue
916 req = axum::extract::Request::from_parts(parts, body);
917 Ok(next.run(req).await)
918}