A better Rust ATProto crate
1

Configure Feed

Select the types of activity you want to include in your feed.

fixed service auth, incorporating a version of bailey's PR, adding replay protection.

author nonbinary.computer date (Jun 7, 2026, 8:21 PM -0400) commit 9a664d6f parent 05add82e change-id spkxlwsy
+1296 -321
+7 -62
Cargo.lock
··· 39 39 ] 40 40 41 41 [[package]] 42 - name = "aliasable" 43 - version = "0.1.3" 44 - source = "registry+https://github.com/rust-lang/crates.io-index" 45 - checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" 46 - 47 - [[package]] 48 42 name = "aligned" 49 43 version = "0.4.3" 50 44 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 753 747 source = "registry+https://github.com/rust-lang/crates.io-index" 754 748 checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a" 755 749 dependencies = [ 756 - "heck 0.5.0", 750 + "heck", 757 751 "proc-macro2", 758 752 "quote", 759 753 "syn", ··· 1325 1319 source = "registry+https://github.com/rust-lang/crates.io-index" 1326 1320 checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" 1327 1321 dependencies = [ 1328 - "heck 0.5.0", 1322 + "heck", 1329 1323 "proc-macro2", 1330 1324 "quote", 1331 1325 "syn", ··· 1889 1883 1890 1884 [[package]] 1891 1885 name = "heck" 1892 - version = "0.4.1" 1893 - source = "registry+https://github.com/rust-lang/crates.io-index" 1894 - checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" 1895 - 1896 - [[package]] 1897 - name = "heck" 1898 1886 version = "0.5.0" 1899 1887 source = "registry+https://github.com/rust-lang/crates.io-index" 1900 1888 checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" ··· 2453 2441 "jacquard-identity", 2454 2442 "k256", 2455 2443 "miette", 2444 + "mini-moka-wasm", 2456 2445 "multibase", 2457 2446 "rand 0.8.5", 2458 2447 "reqwest", ··· 2509 2498 "multibase", 2510 2499 "multihash", 2511 2500 "n0-future", 2512 - "ouroboros", 2513 2501 "oxilangtag", 2514 2502 "p256", 2515 2503 "phf", ··· 2542 2530 name = "jacquard-derive" 2543 2531 version = "0.12.0-beta.2" 2544 2532 dependencies = [ 2545 - "heck 0.5.0", 2533 + "heck", 2546 2534 "inventory", 2547 2535 "jacquard-common", 2548 2536 "jacquard-lexicon", ··· 2610 2598 "bytes", 2611 2599 "cid", 2612 2600 "dashmap", 2613 - "heck 0.5.0", 2601 + "heck", 2614 2602 "inventory", 2615 2603 "jacquard-common", 2616 2604 "miette", ··· 3519 3507 checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" 3520 3508 3521 3509 [[package]] 3522 - name = "ouroboros" 3523 - version = "0.18.5" 3524 - source = "registry+https://github.com/rust-lang/crates.io-index" 3525 - checksum = "1e0f050db9c44b97a94723127e6be766ac5c340c48f2c4bb3ffa11713744be59" 3526 - dependencies = [ 3527 - "aliasable", 3528 - "ouroboros_macro", 3529 - "static_assertions", 3530 - ] 3531 - 3532 - [[package]] 3533 - name = "ouroboros_macro" 3534 - version = "0.18.5" 3535 - source = "registry+https://github.com/rust-lang/crates.io-index" 3536 - checksum = "3c7028bdd3d43083f6d8d4d5187680d0d3560d54df4cc9d752005268b41e64d0" 3537 - dependencies = [ 3538 - "heck 0.4.1", 3539 - "proc-macro2", 3540 - "proc-macro2-diagnostics", 3541 - "quote", 3542 - "syn", 3543 - ] 3544 - 3545 - [[package]] 3546 3510 name = "owo-colors" 3547 3511 version = "4.3.0" 3548 3512 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 3852 3816 checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" 3853 3817 dependencies = [ 3854 3818 "unicode-ident", 3855 - ] 3856 - 3857 - [[package]] 3858 - name = "proc-macro2-diagnostics" 3859 - version = "0.10.1" 3860 - source = "registry+https://github.com/rust-lang/crates.io-index" 3861 - checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" 3862 - dependencies = [ 3863 - "proc-macro2", 3864 - "quote", 3865 - "syn", 3866 - "version_check", 3867 - "yansi", 3868 3819 ] 3869 3820 3870 3821 [[package]] ··· 4909 4860 version = "1.2.1" 4910 4861 source = "registry+https://github.com/rust-lang/crates.io-index" 4911 4862 checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" 4912 - 4913 - [[package]] 4914 - name = "static_assertions" 4915 - version = "1.1.0" 4916 - source = "registry+https://github.com/rust-lang/crates.io-index" 4917 - checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" 4918 4863 4919 4864 [[package]] 4920 4865 name = "string_cache" ··· 6401 6346 checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" 6402 6347 dependencies = [ 6403 6348 "anyhow", 6404 - "heck 0.5.0", 6349 + "heck", 6405 6350 "wit-parser", 6406 6351 ] 6407 6352 ··· 6412 6357 checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" 6413 6358 dependencies = [ 6414 6359 "anyhow", 6415 - "heck 0.5.0", 6360 + "heck", 6416 6361 "indexmap", 6417 6362 "prettyplease", 6418 6363 "syn",
+4 -1
crates/jacquard-axum/Cargo.toml
··· 22 22 [dependencies] 23 23 axum = "0.8.6" 24 24 bytes.workspace = true 25 + chrono.workspace = true 25 26 jacquard = { version = "0.12.0-beta.1", path = "../jacquard", default-features = false, features = ["api"] } 26 27 jacquard-common = { version = "0.12.0-beta.1", path = "../jacquard-common", features = ["reqwest-client"] } 27 28 jacquard-derive = { version = "0.12.0-beta.1", path = "../jacquard-derive" } 28 29 jacquard-identity = { version = "0.12.0-beta.1", path = "../jacquard-identity", optional = true } 29 30 miette.workspace = true 30 31 multibase = { version = "0.9.1", optional = true } 32 + mini-moka = { package = "mini-moka-wasm", version = "0.10", path = "../mini-moka-wasm", optional = true } 31 33 serde.workspace = true 32 34 serde_html_form.workspace = true 33 35 serde_json.workspace = true ··· 37 39 tracing = "0.1.41" 38 40 39 41 [features] 40 - default = ["service-auth"] 42 + default = ["service-auth", "service-auth-replay"] 41 43 service-auth = ["jacquard-common/service-auth", "dep:jacquard-identity", "dep:multibase"] 44 + service-auth-replay = ["service-auth", "dep:mini-moka"] 42 45 tracing = [] 43 46 44 47 [dev-dependencies]
+396 -140
crates/jacquard-axum/src/service_auth.rs
··· 1 - //! Service authentication extractor and middleware 1 + //! Service authentication extractor and middleware. 2 + //! 3 + //! Service auth verifies AT Protocol inter-service JWTs. Normal 4 + //! [`ServiceAuthConfig::new`] configurations require `lxm` method binding and, 5 + //! when the `service-auth-replay` feature is enabled, reject missing or replayed 6 + //! `jti` values by default. Use [`ServiceAuthConfig::disable_replay_protection`] 7 + //! only for legacy compatibility. 8 + //! 9 + //! Global service-id allow-lists constrain present `aud` fragments but do not 10 + //! require a fragment. Use [`require_service_id`] as a route layer for endpoints 11 + //! that require a specific `did:web:example.com#service_id` audience fragment. 12 + //! 13 + //! [`ExtractOptionalServiceAuth`] treats only an absent Authorization header as 14 + //! anonymous. Present malformed, invalid, or replayed credentials are rejected. 15 + //! 16 + //! The default replay store is in-memory and per process. Horizontally scaled 17 + //! deployments should provide a shared [`ReplayStore`] implementation. 18 + //! Legacy configs created with [`ServiceAuthConfig::new_legacy`] disable `lxm` 19 + //! and replay requirements. 2 20 //! 3 21 //! # Example 4 22 //! ··· 38 56 //! ``` 39 57 40 58 use axum::{ 41 - Json, 59 + Extension, Json, 42 60 extract::FromRequestParts, 43 61 http::{HeaderValue, StatusCode, header, request::Parts}, 44 62 middleware::Next, 45 63 response::{IntoResponse, Response}, 46 64 }; 65 + use jacquard_common::deps::smol_str::SmolStr; 47 66 use jacquard_common::{ 48 67 CowStr, IntoStatic, 49 68 service_auth::{self, PublicKey}, 50 69 types::{ 51 70 did_doc::VerificationMethod, 52 - string::{Did, Nsid}, 71 + string::{Did, DidService, Nsid}, 53 72 }, 54 73 }; 55 74 use jacquard_identity::resolver::IdentityResolver; 56 75 use serde_json::json; 57 - use std::sync::Arc; 76 + use std::future::Future; 77 + use std::pin::Pin; 78 + use std::sync::{Arc, Mutex}; 58 79 use thiserror::Error; 59 80 81 + /// Replay key for service auth JWT `jti` replay protection. 82 + #[derive(Debug, Clone, PartialEq, Eq, Hash)] 83 + pub struct ReplayKey { 84 + /// Issuer DID from the JWT. 85 + iss: Did, 86 + /// Full audience, including any service-id fragment. 87 + aud: DidService, 88 + /// JWT ID nonce. 89 + jti: SmolStr, 90 + } 91 + 92 + impl ReplayKey { 93 + /// Create a new replay key. 94 + pub fn new(iss: Did, aud: DidService, jti: impl Into<SmolStr>) -> Self { 95 + Self { 96 + iss, 97 + aud, 98 + jti: jti.into(), 99 + } 100 + } 101 + } 102 + 103 + /// Errors returned by replay stores. 104 + #[derive(Debug, Error)] 105 + #[non_exhaustive] 106 + pub enum ReplayStoreError { 107 + /// The replay key has already been presented and has not expired. 108 + #[error("service auth JWT replay detected")] 109 + Replayed, 110 + 111 + /// The replay store failed. 112 + #[error("replay store failed: {0}")] 113 + Store(String), 114 + } 115 + 116 + /// Store used to reject replayed service auth JWT IDs. 117 + pub trait ReplayStore: Send + Sync + 'static { 118 + /// Check whether `key` has been seen, and record it until `expires_at`. 119 + fn check_and_insert( 120 + &self, 121 + key: ReplayKey, 122 + expires_at: i64, 123 + ) -> Pin<Box<dyn Future<Output = Result<(), ReplayStoreError>> + Send + '_>>; 124 + } 125 + 126 + /// Replay store that disables replay protection. 127 + #[derive(Debug, Clone, Copy, Default)] 128 + pub struct NoopReplayStore; 129 + 130 + impl ReplayStore for NoopReplayStore { 131 + fn check_and_insert( 132 + &self, 133 + _key: ReplayKey, 134 + _expires_at: i64, 135 + ) -> Pin<Box<dyn Future<Output = Result<(), ReplayStoreError>> + Send + '_>> { 136 + Box::pin(async { Ok(()) }) 137 + } 138 + } 139 + 140 + /// Default in-memory replay store. 141 + #[cfg(feature = "service-auth-replay")] 142 + #[derive(Debug, Clone)] 143 + pub struct InMemoryReplayStore { 144 + cache: mini_moka::sync::Cache<ReplayKey, i64>, 145 + lock: Arc<Mutex<()>>, 146 + } 147 + 148 + #[cfg(feature = "service-auth-replay")] 149 + impl Default for InMemoryReplayStore { 150 + fn default() -> Self { 151 + Self::new(100_000) 152 + } 153 + } 154 + 155 + #[cfg(feature = "service-auth-replay")] 156 + impl InMemoryReplayStore { 157 + /// Create an in-memory replay store with a maximum key capacity. 158 + pub fn new(max_capacity: u64) -> Self { 159 + Self { 160 + cache: mini_moka::sync::Cache::new(max_capacity), 161 + lock: Arc::new(Mutex::new(())), 162 + } 163 + } 164 + } 165 + 166 + #[cfg(feature = "service-auth-replay")] 167 + impl ReplayStore for InMemoryReplayStore { 168 + fn check_and_insert( 169 + &self, 170 + key: ReplayKey, 171 + expires_at: i64, 172 + ) -> Pin<Box<dyn Future<Output = Result<(), ReplayStoreError>> + Send + '_>> { 173 + Box::pin(async move { 174 + let _guard = self 175 + .lock 176 + .lock() 177 + .map_err(|_| ReplayStoreError::Store("replay store lock poisoned".to_string()))?; 178 + let now = chrono::Utc::now().timestamp(); 179 + if let Some(existing_expires_at) = self.cache.get(&key) { 180 + if existing_expires_at > now { 181 + return Err(ReplayStoreError::Replayed); 182 + } 183 + self.cache.invalidate(&key); 184 + } 185 + self.cache.insert(key, expires_at); 186 + Ok(()) 187 + }) 188 + } 189 + } 190 + 60 191 /// Trait for providing service authentication configuration. 61 192 /// 62 193 /// This trait allows custom state types to provide service auth configuration ··· 71 202 /// Get a reference to the identity resolver 72 203 fn resolver(&self) -> &Self::Resolver; 73 204 74 - /// Whether to require the `lxm` (method binding) field 205 + /// Whether to require the `lxm` (method binding) field. 75 206 fn require_lxm(&self) -> bool; 207 + 208 + /// Service-id fragments allowed by global validation. 209 + fn allowed_services(&self) -> &[SmolStr]; 210 + 211 + /// Whether replay protection is enabled. 212 + fn replay_protection_enabled(&self) -> bool; 213 + 214 + /// Replay store used by replay protection. 215 + fn replay_store(&self) -> &dyn ReplayStore; 76 216 } 77 217 78 218 /// Configuration for service auth verification. ··· 80 220 /// This should be stored in your Axum app state and will be extracted 81 221 /// by the `ExtractServiceAuth` extractor. 82 222 pub struct ServiceAuthConfig<R> { 83 - /// The DID of your service (the expected audience) 223 + /// The DID of your service (the expected audience). 84 224 service_did: Did, 85 - /// Identity resolver for fetching DID documents 225 + /// Identity resolver for fetching DID documents. 86 226 resolver: Arc<R>, 87 - /// Whether to require the `lxm` (method binding) field 227 + /// Whether to require the `lxm` (method binding) field. 88 228 require_lxm: bool, 229 + /// Globally allowed service-id fragments. 230 + allowed_services: Vec<SmolStr>, 231 + /// Replay store used when replay protection is enabled. 232 + replay_store: Arc<dyn ReplayStore>, 233 + /// Whether replay protection is enabled. 234 + replay_protection_enabled: bool, 89 235 } 90 236 91 237 impl<R> Clone for ServiceAuthConfig<R> { ··· 94 240 service_did: self.service_did.clone(), 95 241 resolver: Arc::clone(&self.resolver), 96 242 require_lxm: self.require_lxm, 243 + allowed_services: self.allowed_services.clone(), 244 + replay_store: Arc::clone(&self.replay_store), 245 + replay_protection_enabled: self.replay_protection_enabled, 97 246 } 98 247 } 99 248 } 100 249 250 + fn default_replay_store() -> (Arc<dyn ReplayStore>, bool) { 251 + #[cfg(feature = "service-auth-replay")] 252 + { 253 + (Arc::new(InMemoryReplayStore::default()), true) 254 + } 255 + #[cfg(not(feature = "service-auth-replay"))] 256 + { 257 + (Arc::new(NoopReplayStore), false) 258 + } 259 + } 260 + 101 261 impl<R: IdentityResolver> ServiceAuthConfig<R> { 102 262 /// Create a new service auth config. 103 263 /// 104 264 /// This enables `lxm` (method binding). If you need backward compatibility, 105 265 /// use `ServiceAuthConfig::new_legacy()` 106 266 pub fn new(service_did: Did, resolver: R) -> Self { 267 + let (replay_store, replay_protection_enabled) = default_replay_store(); 107 268 Self { 108 269 service_did, 109 270 resolver: Arc::new(resolver), 110 271 require_lxm: true, 272 + allowed_services: Vec::new(), 273 + replay_store, 274 + replay_protection_enabled, 111 275 } 112 276 } 113 277 ··· 119 283 service_did, 120 284 resolver: Arc::new(resolver), 121 285 require_lxm: false, 286 + allowed_services: Vec::new(), 287 + replay_store: Arc::new(NoopReplayStore), 288 + replay_protection_enabled: false, 122 289 } 123 290 } 124 291 ··· 131 298 self 132 299 } 133 300 301 + /// Replace the global allowed service-id fragments. 302 + pub fn with_allowed_services<I, Svc>(mut self, services: I) -> Self 303 + where 304 + I: IntoIterator<Item = Svc>, 305 + Svc: Into<SmolStr>, 306 + { 307 + self.allowed_services = services.into_iter().map(Into::into).collect(); 308 + self 309 + } 310 + 311 + /// Add a single global allowed service-id fragment. 312 + pub fn allow_service(mut self, service: impl Into<SmolStr>) -> Self { 313 + self.allowed_services.push(service.into()); 314 + self 315 + } 316 + 317 + /// Replace the replay store and enable replay protection. 318 + pub fn with_replay_store(mut self, store: impl ReplayStore) -> Self { 319 + self.replay_store = Arc::new(store); 320 + self.replay_protection_enabled = true; 321 + self 322 + } 323 + 324 + /// Disable replay protection for legacy compatibility. 325 + pub fn disable_replay_protection(mut self) -> Self { 326 + self.replay_store = Arc::new(NoopReplayStore); 327 + self.replay_protection_enabled = false; 328 + self 329 + } 330 + 331 + /// Get the globally allowed service-id fragments. 332 + pub fn allowed_services(&self) -> &[SmolStr] { 333 + &self.allowed_services 334 + } 335 + 134 336 /// Get the service DID. 135 337 pub fn service_did(&self) -> Did<&str> { 136 338 self.service_did.borrow() ··· 156 358 fn require_lxm(&self) -> bool { 157 359 self.require_lxm 158 360 } 361 + 362 + fn allowed_services(&self) -> &[SmolStr] { 363 + &self.allowed_services 364 + } 365 + 366 + fn replay_protection_enabled(&self) -> bool { 367 + self.replay_protection_enabled 368 + } 369 + 370 + fn replay_store(&self) -> &dyn ReplayStore { 371 + self.replay_store.as_ref() 372 + } 373 + } 374 + 375 + /// Route-scoped service auth policy. 376 + #[derive(Debug, Clone, Default)] 377 + pub struct ServiceAuthRoutePolicy { 378 + /// Required service-id fragment for this route. 379 + required_service_id: Option<SmolStr>, 380 + } 381 + 382 + impl ServiceAuthRoutePolicy { 383 + /// Require a specific service-id fragment for this route. 384 + pub fn require_service_id(service_id: impl Into<SmolStr>) -> Self { 385 + Self { 386 + required_service_id: Some(service_id.into()), 387 + } 388 + } 389 + 390 + /// Get the required service-id fragment. 391 + pub fn required_service_id(&self) -> Option<&str> { 392 + self.required_service_id.as_deref() 393 + } 394 + } 395 + 396 + /// Create an axum route layer that requires a specific service-id fragment. 397 + pub fn require_service_id(service_id: impl Into<SmolStr>) -> Extension<ServiceAuthRoutePolicy> { 398 + Extension(ServiceAuthRoutePolicy::require_service_id(service_id)) 159 399 } 160 400 161 401 /// Verified service authentication information. ··· 166 406 pub struct VerifiedServiceAuth<'a> { 167 407 /// The authenticated user's DID (from `iss` claim) 168 408 did: Did, 169 - /// The audience (should match your service DID) 170 - aud: Did, 409 + /// The audience (should match your service DID, with optional service fragment). 410 + aud: DidService, 171 411 /// The lexicon method NSID, if present 172 412 lxm: Option<Nsid>, 173 413 /// JWT ID (nonce), if present ··· 180 420 self.did.borrow() 181 421 } 182 422 183 - /// Get the audience (your service DID). 184 - pub fn aud(&self) -> Did<&str> { 423 + /// Get the full audience, including any service-id fragment. 424 + pub fn aud(&self) -> DidService<&str> { 185 425 self.aud.borrow() 186 426 } 187 427 428 + /// Get the fragmentless service DID audience. 429 + pub fn audience(&self) -> Did<&str> { 430 + self.aud.audience() 431 + } 432 + 433 + /// Get the optional service-id fragment. 434 + pub fn service(&self) -> Option<&str> { 435 + self.aud.service() 436 + } 437 + 188 438 /// Get the lexicon method NSID, if present. 189 439 pub fn lxm(&self) -> Option<Nsid<&str>> { 190 440 self.lxm.as_ref().map(|l| l.borrow()) ··· 326 576 /// Invalid key format 327 577 #[error("invalid key format: {0}")] 328 578 InvalidKey(String), 579 + 580 + /// Service-id fragment is required for this route. 581 + #[error("service id {required} is required but missing from token audience")] 582 + ServiceIdRequired { 583 + /// Required service-id fragment. 584 + required: SmolStr, 585 + }, 586 + 587 + /// Service-id fragment does not match this route. 588 + #[error("service id mismatch: required {required}, got {actual}")] 589 + RouteServiceIdMismatch { 590 + /// Required service-id fragment. 591 + required: SmolStr, 592 + /// Actual service-id fragment. 593 + actual: SmolStr, 594 + }, 595 + 596 + /// Replay protection is enabled but the token has no `jti`. 597 + #[error("service auth JWT is missing required jti")] 598 + MissingJti, 599 + 600 + /// Replay protection rejected the token. 601 + #[error("replay protection failed: {0}")] 602 + Replay(#[from] ReplayStoreError), 329 603 } 330 604 331 605 impl IntoResponse for ServiceAuthError { ··· 357 631 "AuthenticationRequired", 358 632 self.to_string(), 359 633 ), 360 - ServiceAuthError::InvalidKey(_) => ( 634 + ServiceAuthError::InvalidKey(_) 635 + | ServiceAuthError::ServiceIdRequired { .. } 636 + | ServiceAuthError::RouteServiceIdMismatch { .. } 637 + | ServiceAuthError::MissingJti 638 + | ServiceAuthError::Replay(_) => ( 361 639 StatusCode::UNAUTHORIZED, 362 640 "AuthenticationRequired", 363 641 self.to_string(), ··· 381 659 } 382 660 } 383 661 384 - impl<S> FromRequestParts<S> for ExtractServiceAuth 662 + fn owned_did<S: jacquard_common::BosStr>(did: &Did<S>) -> Did { 663 + Did::new_owned(did.as_str()).unwrap() 664 + } 665 + 666 + fn bearer_token_from_parts(parts: &Parts) -> Result<Option<&str>, ServiceAuthError> { 667 + let Some(auth_header) = parts.headers.get(header::AUTHORIZATION) else { 668 + return Ok(None); 669 + }; 670 + 671 + let auth_str = auth_header 672 + .to_str() 673 + .map_err(|_| ServiceAuthError::InvalidAuthHeader)?; 674 + let token = auth_str 675 + .strip_prefix("Bearer ") 676 + .ok_or(ServiceAuthError::InvalidAuthHeader)?; 677 + Ok(Some(token)) 678 + } 679 + 680 + async fn verify_service_auth<S>( 681 + parts: &Parts, 682 + state: &S, 683 + token: &str, 684 + ) -> Result<VerifiedServiceAuth<'static>, ServiceAuthError> 385 685 where 386 686 S: ServiceAuth + Send + Sync, 387 687 S::Resolver: Send + Sync, 388 688 { 389 - type Rejection = ServiceAuthError; 390 - 391 - fn from_request_parts( 392 - parts: &mut Parts, 393 - state: &S, 394 - ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send { 395 - async move { 396 - // Extract Authorization header 397 - let auth_header = parts 398 - .headers 399 - .get(header::AUTHORIZATION) 400 - .ok_or(ServiceAuthError::MissingAuthHeader)?; 689 + let parsed = service_auth::parse_jwt(token)?; 690 + let claims = parsed.claims(); 401 691 402 - // Parse Bearer token 403 - let auth_str = auth_header 404 - .to_str() 405 - .map_err(|_| ServiceAuthError::InvalidAuthHeader)?; 692 + let did_doc = state 693 + .resolver() 694 + .resolve_did_doc(&claims.iss) 695 + .await 696 + .map_err(|e| ServiceAuthError::DidResolutionFailed { 697 + did: owned_did(&claims.iss), 698 + source: Box::new(e), 699 + })?; 406 700 407 - let token = auth_str 408 - .strip_prefix("Bearer ") 409 - .ok_or(ServiceAuthError::InvalidAuthHeader)?; 701 + let doc = did_doc 702 + .parse() 703 + .map_err(|e| ServiceAuthError::DidResolutionFailed { 704 + did: owned_did(&claims.iss), 705 + source: Box::new(e), 706 + })?; 410 707 411 - // Parse JWT 412 - let parsed = service_auth::parse_jwt(token)?; 708 + let verification_methods = doc 709 + .verification_method 710 + .as_deref() 711 + .ok_or_else(|| ServiceAuthError::NoSigningKey(owned_did(&claims.iss)))?; 413 712 414 - // Get claims for DID resolution 415 - let claims = parsed.claims(); 713 + let signing_key = extract_signing_key(verification_methods) 714 + .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?; 416 715 417 - // Resolve DID to get signing key (do this before checking claims) 418 - let did_doc = state 419 - .resolver() 420 - .resolve_did_doc(&claims.iss) 421 - .await 422 - .map_err(|e| ServiceAuthError::DidResolutionFailed { 423 - did: claims.iss.clone().into_static(), 424 - source: Box::new(e), 425 - })?; 716 + service_auth::verify_signature(&parsed, &signing_key)?; 717 + claims.validate(&state.service_did(), state.allowed_services())?; 426 718 427 - // Parse the DID document response to get verification methods 428 - let doc = did_doc 429 - .parse() 430 - .map_err(|e| ServiceAuthError::DidResolutionFailed { 431 - did: claims.iss.clone().into_static(), 432 - source: Box::new(e), 433 - })?; 434 - 435 - // Extract signing key from DID document 436 - let verification_methods = doc 437 - .verification_method 438 - .as_deref() 439 - .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?; 719 + if state.require_lxm() && claims.lxm.is_none() { 720 + return Err(ServiceAuthError::MethodBindingRequired); 721 + } 440 722 441 - let signing_key = extract_signing_key(verification_methods) 442 - .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?; 723 + if let Some(policy) = parts.extensions.get::<ServiceAuthRoutePolicy>() { 724 + if let Some(required) = policy.required_service_id() { 725 + match claims.aud.service() { 726 + Some(actual) if actual == required => {} 727 + Some(actual) => { 728 + return Err(ServiceAuthError::RouteServiceIdMismatch { 729 + required: SmolStr::new(required), 730 + actual: SmolStr::new(actual), 731 + }); 732 + } 733 + None => { 734 + return Err(ServiceAuthError::ServiceIdRequired { 735 + required: SmolStr::new(required), 736 + }); 737 + } 738 + } 739 + } 740 + } 443 741 444 - // Verify signature FIRST - if this fails, nothing else matters 445 - service_auth::verify_signature(&parsed, &signing_key)?; 742 + if state.replay_protection_enabled() { 743 + let jti = claims.jti.as_ref().ok_or(ServiceAuthError::MissingJti)?; 744 + let key = ReplayKey::new( 745 + claims.iss.clone().into_static(), 746 + claims.aud.clone().into_static(), 747 + jti.clone(), 748 + ); 749 + state 750 + .replay_store() 751 + .check_and_insert(key, claims.exp) 752 + .await?; 753 + } 446 754 447 - // Now validate claims (audience, expiration, etc.) 448 - claims.validate(&state.service_did())?; 755 + Ok(VerifiedServiceAuth { 756 + did: claims.iss.clone().into_static(), 757 + aud: claims.aud.clone().into_static(), 758 + lxm: claims.lxm.as_ref().map(|l| l.clone().into_static()), 759 + jti: claims.jti.as_ref().map(|j| CowStr::from(j.clone())), 760 + }) 761 + } 449 762 450 - // Check method binding if required 451 - if state.require_lxm() && claims.lxm.is_none() { 452 - return Err(ServiceAuthError::MethodBindingRequired); 453 - } 763 + impl<S> FromRequestParts<S> for ExtractServiceAuth 764 + where 765 + S: ServiceAuth + Send + Sync, 766 + S::Resolver: Send + Sync, 767 + { 768 + type Rejection = ServiceAuthError; 454 769 455 - // All checks passed - return verified auth 456 - Ok(ExtractServiceAuth(VerifiedServiceAuth { 457 - did: claims.iss.clone().into_static(), 458 - aud: claims.aud.clone().into_static(), 459 - lxm: claims.lxm.as_ref().map(|l| l.clone().into_static()), 460 - jti: claims.jti.as_ref().map(|j| j.clone().into_static()), 461 - })) 770 + fn from_request_parts( 771 + parts: &mut Parts, 772 + state: &S, 773 + ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send { 774 + async move { 775 + let token = 776 + bearer_token_from_parts(parts)?.ok_or(ServiceAuthError::MissingAuthHeader)?; 777 + verify_service_auth(parts, state, token).await.map(Self) 462 778 } 463 779 } 464 780 } ··· 475 791 state: &S, 476 792 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send { 477 793 async move { 478 - // Check for Authorization header - if missing, return None (not an error) 479 - let auth_header = match parts.headers.get(header::AUTHORIZATION) { 480 - Some(h) => h, 481 - None => return Ok(ExtractOptionalServiceAuth(None)), 794 + let Some(token) = bearer_token_from_parts(parts)? else { 795 + return Ok(Self(None)); 482 796 }; 483 - 484 - // Header is present - now we MUST validate it (bad auth = error) 485 - let auth_str = auth_header 486 - .to_str() 487 - .map_err(|_| ServiceAuthError::InvalidAuthHeader)?; 488 - 489 - let token = auth_str 490 - .strip_prefix("Bearer ") 491 - .ok_or(ServiceAuthError::InvalidAuthHeader)?; 492 - 493 - // Parse JWT 494 - let parsed = service_auth::parse_jwt(token)?; 495 - 496 - // Get claims for DID resolution 497 - let claims = parsed.claims(); 498 - 499 - // Resolve DID to get signing key 500 - let did_doc = state 501 - .resolver() 502 - .resolve_did_doc(&claims.iss) 797 + verify_service_auth(parts, state, token) 503 798 .await 504 - .map_err(|e| ServiceAuthError::DidResolutionFailed { 505 - did: claims.iss.clone().into_static(), 506 - source: Box::new(e), 507 - })?; 508 - 509 - // Parse the DID document response to get verification methods 510 - let doc = did_doc 511 - .parse() 512 - .map_err(|e| ServiceAuthError::DidResolutionFailed { 513 - did: claims.iss.clone().into_static(), 514 - source: Box::new(e), 515 - })?; 516 - 517 - // Extract signing key from DID document 518 - let verification_methods = doc 519 - .verification_method 520 - .as_deref() 521 - .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?; 522 - 523 - let signing_key = extract_signing_key(verification_methods) 524 - .ok_or_else(|| ServiceAuthError::NoSigningKey(claims.iss.clone().into_static()))?; 525 - 526 - // Verify signature FIRST - if this fails, nothing else matters 527 - service_auth::verify_signature(&parsed, &signing_key)?; 528 - 529 - // Now validate claims (audience, expiration, etc.) 530 - claims.validate(&state.service_did())?; 531 - 532 - // Check method binding if required 533 - if state.require_lxm() && claims.lxm.is_none() { 534 - return Err(ServiceAuthError::MethodBindingRequired); 535 - } 536 - 537 - // All checks passed - return verified auth 538 - Ok(ExtractOptionalServiceAuth(Some(VerifiedServiceAuth { 539 - did: claims.iss.clone().into_static(), 540 - aud: claims.aud.clone().into_static(), 541 - lxm: claims.lxm.as_ref().map(|l| l.clone().into_static()), 542 - jti: claims.jti.as_ref().map(|j| j.clone().into_static()), 543 - }))) 799 + .map(|auth| Self(Some(auth))) 544 800 } 545 801 } 546 802 }
+404 -5
crates/jacquard-axum/tests/service_auth_tests.rs
··· 10 10 use base64::engine::general_purpose::URL_SAFE_NO_PAD; 11 11 use bytes::Bytes; 12 12 use jacquard_axum::service_auth::{ 13 - ExtractServiceAuth, ServiceAuthConfig, VerifiedServiceAuth, service_auth_middleware, 13 + ExtractOptionalServiceAuth, ExtractServiceAuth, ReplayKey, ReplayStore, ReplayStoreError, 14 + ServiceAuthConfig, VerifiedServiceAuth, require_service_id, service_auth_middleware, 14 15 }; 15 16 use jacquard_common::{ 16 - CowStr, 17 17 bos::BosStr, 18 18 deps::smol_str::SmolStr, 19 19 service_auth::JwtHeader, ··· 27 27 }; 28 28 use reqwest::StatusCode as ReqwestStatusCode; 29 29 use serde_json::json; 30 - use std::future::Future; 30 + use std::{ 31 + collections::HashMap, 32 + future::Future, 33 + pin::Pin, 34 + sync::{Arc, Mutex}, 35 + }; 31 36 use tower::ServiceExt; 32 37 33 38 // Test helper: create a signed JWT ··· 38 43 lxm: Option<&str>, 39 44 signing_key: &k256::ecdsa::SigningKey, 40 45 ) -> String { 46 + create_test_jwt_with_jti( 47 + iss, 48 + aud, 49 + exp, 50 + lxm, 51 + Some(&format!("test-jti-{}-{}-{}", iss, aud, exp)), 52 + signing_key, 53 + ) 54 + } 55 + 56 + fn create_test_jwt_with_jti( 57 + iss: &str, 58 + aud: &str, 59 + exp: i64, 60 + lxm: Option<&str>, 61 + jti: Option<&str>, 62 + signing_key: &k256::ecdsa::SigningKey, 63 + ) -> String { 41 64 use k256::ecdsa::signature::Signer; 42 65 43 66 let header = JwtHeader { 44 - alg: CowStr::new_static("ES256K"), 45 - typ: CowStr::new_static("JWT"), 67 + alg: SmolStr::new_static("ES256K"), 68 + typ: SmolStr::new_static("JWT"), 46 69 }; 47 70 48 71 let mut claims_json = json!({ ··· 52 75 "iat": chrono::Utc::now().timestamp(), 53 76 }); 54 77 78 + if let Some(jti_val) = jti { 79 + claims_json["jti"] = json!(jti_val); 80 + } 81 + 55 82 if let Some(lxm_val) = lxm { 56 83 claims_json["lxm"] = json!(lxm_val); 57 84 } ··· 111 138 did_doc, 112 139 options: ResolverOptions::default(), 113 140 } 141 + } 142 + } 143 + 144 + #[derive(Clone, Default)] 145 + struct DeterministicReplayStore { 146 + entries: Arc<Mutex<HashMap<ReplayKey, i64>>>, 147 + now: Arc<Mutex<i64>>, 148 + } 149 + 150 + impl DeterministicReplayStore { 151 + fn set_now(&self, now: i64) { 152 + *self.now.lock().unwrap() = now; 153 + } 154 + } 155 + 156 + impl ReplayStore for DeterministicReplayStore { 157 + fn check_and_insert( 158 + &self, 159 + key: ReplayKey, 160 + expires_at: i64, 161 + ) -> Pin<Box<dyn Future<Output = Result<(), ReplayStoreError>> + Send + '_>> { 162 + Box::pin(async move { 163 + let now = *self.now.lock().unwrap(); 164 + let mut entries = self.entries.lock().unwrap(); 165 + if let Some(existing_expires_at) = entries.get(&key) { 166 + if *existing_expires_at > now { 167 + return Err(ReplayStoreError::Replayed); 168 + } 169 + } 170 + entries.insert(key, expires_at); 171 + Ok(()) 172 + }) 114 173 } 115 174 } 116 175 ··· 493 552 let body = String::from_utf8(body_bytes.to_vec()).unwrap(); 494 553 495 554 assert_eq!(body, format!("Authenticated as {}", user_did)); 555 + } 556 + 557 + #[tokio::test] 558 + async fn test_optional_extractor_valid_and_missing_and_invalid() { 559 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 560 + let verifying_key = signing_key.verifying_key(); 561 + let user_did = "did:plc:test123"; 562 + let service_did = "did:web:feedgen.example.com"; 563 + let exp = chrono::Utc::now().timestamp() + 300; 564 + let jwt = create_test_jwt( 565 + user_did, 566 + service_did, 567 + exp, 568 + Some("app.bsky.feed.getFeedSkeleton"), 569 + &signing_key, 570 + ); 571 + let did_doc = create_test_did_doc(user_did, verifying_key); 572 + let resolver = MockResolver::new(did_doc); 573 + let config = ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver); 574 + 575 + async fn handler(ExtractOptionalServiceAuth(auth): ExtractOptionalServiceAuth) -> String { 576 + auth.map(|auth| auth.did().to_string()) 577 + .unwrap_or_else(|| "anonymous".to_string()) 578 + } 579 + 580 + let app = Router::new() 581 + .route("/test", get(handler)) 582 + .with_state(config); 583 + 584 + let request = Request::builder().uri("/test").body(Body::empty()).unwrap(); 585 + let response = app.clone().oneshot(request).await.unwrap(); 586 + assert_eq!(response.status(), StatusCode::OK); 587 + let body = axum::body::to_bytes(response.into_body(), usize::MAX) 588 + .await 589 + .unwrap(); 590 + assert_eq!(String::from_utf8(body.to_vec()).unwrap(), "anonymous"); 591 + 592 + let request = Request::builder() 593 + .uri("/test") 594 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 595 + .body(Body::empty()) 596 + .unwrap(); 597 + let response = app.clone().oneshot(request).await.unwrap(); 598 + assert_eq!(response.status(), StatusCode::OK); 599 + 600 + let request = Request::builder() 601 + .uri("/test") 602 + .header(header::AUTHORIZATION, "Basic bad") 603 + .body(Body::empty()) 604 + .unwrap(); 605 + let response = app.clone().oneshot(request).await.unwrap(); 606 + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 607 + 608 + let request = Request::builder() 609 + .uri("/test") 610 + .header(header::AUTHORIZATION, "Bearer not-a-jwt") 611 + .body(Body::empty()) 612 + .unwrap(); 613 + let response = app.oneshot(request).await.unwrap(); 614 + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 615 + } 616 + 617 + #[tokio::test] 618 + async fn test_replay_rejects_second_presentation() { 619 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 620 + let verifying_key = signing_key.verifying_key(); 621 + let user_did = "did:plc:test123"; 622 + let service_did = "did:web:feedgen.example.com"; 623 + let exp = chrono::Utc::now().timestamp() + 300; 624 + let jwt = create_test_jwt_with_jti( 625 + user_did, 626 + service_did, 627 + exp, 628 + Some("app.bsky.feed.getFeedSkeleton"), 629 + Some("fixed-jti"), 630 + &signing_key, 631 + ); 632 + let did_doc = create_test_did_doc(user_did, verifying_key); 633 + let resolver = MockResolver::new(did_doc); 634 + let config = ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver); 635 + 636 + async fn handler(ExtractServiceAuth(_auth): ExtractServiceAuth) -> &'static str { 637 + "ok" 638 + } 639 + 640 + let app = Router::new() 641 + .route("/test", get(handler)) 642 + .with_state(config); 643 + 644 + for expected in [StatusCode::OK, StatusCode::UNAUTHORIZED] { 645 + let request = Request::builder() 646 + .uri("/test") 647 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 648 + .body(Body::empty()) 649 + .unwrap(); 650 + let response = app.clone().oneshot(request).await.unwrap(); 651 + assert_eq!(response.status(), expected); 652 + } 653 + } 654 + 655 + #[tokio::test] 656 + async fn test_replay_accepts_again_after_expiration() { 657 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 658 + let verifying_key = signing_key.verifying_key(); 659 + let user_did = "did:plc:test123"; 660 + let service_did = "did:web:feedgen.example.com"; 661 + let now = chrono::Utc::now().timestamp(); 662 + let exp = now + 300; 663 + let jwt = create_test_jwt_with_jti( 664 + user_did, 665 + service_did, 666 + exp, 667 + Some("app.bsky.feed.getFeedSkeleton"), 668 + Some("expiring-jti"), 669 + &signing_key, 670 + ); 671 + let did_doc = create_test_did_doc(user_did, verifying_key); 672 + let resolver = MockResolver::new(did_doc); 673 + let replay_store = DeterministicReplayStore::default(); 674 + replay_store.set_now(now); 675 + let config = ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver) 676 + .with_replay_store(replay_store.clone()); 677 + 678 + async fn handler(ExtractServiceAuth(_auth): ExtractServiceAuth) -> &'static str { 679 + "ok" 680 + } 681 + 682 + let app = Router::new() 683 + .route("/test", get(handler)) 684 + .with_state(config); 685 + 686 + let request = Request::builder() 687 + .uri("/test") 688 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 689 + .body(Body::empty()) 690 + .unwrap(); 691 + let response = app.clone().oneshot(request).await.unwrap(); 692 + assert_eq!(response.status(), StatusCode::OK); 693 + 694 + replay_store.set_now(exp + 1); 695 + let request = Request::builder() 696 + .uri("/test") 697 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 698 + .body(Body::empty()) 699 + .unwrap(); 700 + let response = app.oneshot(request).await.unwrap(); 701 + assert_eq!(response.status(), StatusCode::OK); 702 + } 703 + 704 + #[tokio::test] 705 + async fn test_replay_disabled_allows_missing_and_repeated_jti() { 706 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 707 + let verifying_key = signing_key.verifying_key(); 708 + let user_did = "did:plc:test123"; 709 + let service_did = "did:web:feedgen.example.com"; 710 + let exp = chrono::Utc::now().timestamp() + 300; 711 + let jwt = create_test_jwt_with_jti( 712 + user_did, 713 + service_did, 714 + exp, 715 + Some("app.bsky.feed.getFeedSkeleton"), 716 + None, 717 + &signing_key, 718 + ); 719 + let did_doc = create_test_did_doc(user_did, verifying_key); 720 + let resolver = MockResolver::new(did_doc); 721 + let config = ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver) 722 + .disable_replay_protection(); 723 + 724 + async fn handler(ExtractServiceAuth(_auth): ExtractServiceAuth) -> &'static str { 725 + "ok" 726 + } 727 + 728 + let app = Router::new() 729 + .route("/test", get(handler)) 730 + .with_state(config); 731 + 732 + for _ in 0..2 { 733 + let request = Request::builder() 734 + .uri("/test") 735 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 736 + .body(Body::empty()) 737 + .unwrap(); 738 + let response = app.clone().oneshot(request).await.unwrap(); 739 + assert_eq!(response.status(), StatusCode::OK); 740 + } 741 + } 742 + 743 + #[tokio::test] 744 + async fn test_missing_jti_rejected_when_replay_enabled() { 745 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 746 + let verifying_key = signing_key.verifying_key(); 747 + let user_did = "did:plc:test123"; 748 + let service_did = "did:web:feedgen.example.com"; 749 + let exp = chrono::Utc::now().timestamp() + 300; 750 + let jwt = create_test_jwt_with_jti( 751 + user_did, 752 + service_did, 753 + exp, 754 + Some("app.bsky.feed.getFeedSkeleton"), 755 + None, 756 + &signing_key, 757 + ); 758 + let did_doc = create_test_did_doc(user_did, verifying_key); 759 + let resolver = MockResolver::new(did_doc); 760 + let config = ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver); 761 + 762 + async fn handler(ExtractServiceAuth(_auth): ExtractServiceAuth) -> &'static str { 763 + "ok" 764 + } 765 + 766 + let app = Router::new() 767 + .route("/test", get(handler)) 768 + .with_state(config); 769 + let request = Request::builder() 770 + .uri("/test") 771 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 772 + .body(Body::empty()) 773 + .unwrap(); 774 + let response = app.oneshot(request).await.unwrap(); 775 + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 776 + } 777 + 778 + #[tokio::test] 779 + async fn test_global_service_id_allow_list() { 780 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 781 + let verifying_key = signing_key.verifying_key(); 782 + let user_did = "did:plc:test123"; 783 + let service_did = "did:web:feedgen.example.com"; 784 + let exp = chrono::Utc::now().timestamp() + 300; 785 + let did_doc = create_test_did_doc(user_did, verifying_key); 786 + 787 + async fn handler(ExtractServiceAuth(auth): ExtractServiceAuth) -> String { 788 + auth.aud().to_string() 789 + } 790 + 791 + for (aud, allowed, expected) in [ 792 + (service_did, false, StatusCode::OK), 793 + ( 794 + "did:web:feedgen.example.com#bsky_appview", 795 + false, 796 + StatusCode::OK, 797 + ), 798 + (service_did, true, StatusCode::OK), 799 + ( 800 + "did:web:feedgen.example.com#bsky_appview", 801 + true, 802 + StatusCode::OK, 803 + ), 804 + ( 805 + "did:web:feedgen.example.com#other", 806 + true, 807 + StatusCode::UNAUTHORIZED, 808 + ), 809 + ( 810 + "did:web:other.example.com#bsky_appview", 811 + true, 812 + StatusCode::UNAUTHORIZED, 813 + ), 814 + ] { 815 + let resolver = MockResolver::new(did_doc.clone()); 816 + let mut config = ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver) 817 + .disable_replay_protection(); 818 + if allowed { 819 + config = config.allow_service("bsky_appview"); 820 + } 821 + let app = Router::new() 822 + .route("/test", get(handler)) 823 + .with_state(config); 824 + let jwt = create_test_jwt( 825 + user_did, 826 + aud, 827 + exp, 828 + Some("app.bsky.feed.getFeedSkeleton"), 829 + &signing_key, 830 + ); 831 + let request = Request::builder() 832 + .uri("/test") 833 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 834 + .body(Body::empty()) 835 + .unwrap(); 836 + let response = app.oneshot(request).await.unwrap(); 837 + assert_eq!(response.status(), expected, "aud={aud} allowed={allowed}"); 838 + } 839 + } 840 + 841 + #[tokio::test] 842 + async fn test_route_scoped_service_id_policy() { 843 + let signing_key = k256::ecdsa::SigningKey::random(&mut rand::thread_rng()); 844 + let verifying_key = signing_key.verifying_key(); 845 + let user_did = "did:plc:test123"; 846 + let service_did = "did:web:feedgen.example.com"; 847 + let exp = chrono::Utc::now().timestamp() + 300; 848 + let did_doc = create_test_did_doc(user_did, verifying_key); 849 + 850 + async fn handler(ExtractServiceAuth(auth): ExtractServiceAuth) -> String { 851 + auth.service().unwrap_or("bare").to_string() 852 + } 853 + 854 + for (aud, strict, expected) in [ 855 + (service_did, false, StatusCode::OK), 856 + (service_did, true, StatusCode::UNAUTHORIZED), 857 + ( 858 + "did:web:feedgen.example.com#bsky_appview", 859 + true, 860 + StatusCode::OK, 861 + ), 862 + ( 863 + "did:web:feedgen.example.com#other", 864 + true, 865 + StatusCode::UNAUTHORIZED, 866 + ), 867 + ] { 868 + let resolver = MockResolver::new(did_doc.clone()); 869 + let config = ServiceAuthConfig::new(Did::new_static(service_did).unwrap(), resolver) 870 + .disable_replay_protection(); 871 + let route = get(handler); 872 + let app = if strict { 873 + Router::new() 874 + .route("/test", route) 875 + .route_layer(require_service_id("bsky_appview")) 876 + .with_state(config) 877 + } else { 878 + Router::new().route("/test", route).with_state(config) 879 + }; 880 + let jwt = create_test_jwt( 881 + user_did, 882 + aud, 883 + exp, 884 + Some("app.bsky.feed.getFeedSkeleton"), 885 + &signing_key, 886 + ); 887 + let request = Request::builder() 888 + .uri("/test") 889 + .header(header::AUTHORIZATION, format!("Bearer {}", jwt)) 890 + .body(Body::empty()) 891 + .unwrap(); 892 + let response = app.oneshot(request).await.unwrap(); 893 + assert_eq!(response.status(), expected, "aud={aud} strict={strict}"); 894 + } 496 895 } 497 896 498 897 #[tokio::test]
-1
crates/jacquard-common/Cargo.toml
··· 66 66 miette = { workspace = true, optional = true } # also need to gate this to std only 67 67 multibase = { version = "0.9.1", default-features = false } 68 68 multihash = { version = "0.19.3", default-features = false, features = ["alloc"] } 69 - ouroboros = "0.18.5" 70 69 rand = { version = "0.9.2", default-features = false, features = ["alloc"] } 71 70 serde.workspace = true 72 71 serde_html_form.workspace = true # need to check these at workspace level
+150 -112
crates/jacquard-common/src/service_auth.rs
··· 18 18 19 19 use crate::CowStr; 20 20 use crate::IntoStatic; 21 - use crate::types::string::{Did, Nsid}; 21 + use crate::bos::{BosStr, DefaultStr}; 22 + use crate::types::string::{Did, DidService, Nsid}; 22 23 use alloc::string::String; 23 - use alloc::string::ToString; 24 24 use alloc::vec::Vec; 25 25 use base64::Engine; 26 26 use base64::engine::general_purpose::URL_SAFE_NO_PAD; 27 - use ouroboros::self_referencing; 28 27 use serde::{Deserialize, Serialize}; 29 28 use signature::Verifier; 30 29 use smol_str::SmolStr; ··· 73 72 now: i64, 74 73 }, 75 74 76 - /// Audience mismatch 75 + /// Audience mismatch. 77 76 #[error("audience mismatch: expected {expected}, got {actual}")] 78 77 AudienceMismatch { 79 - /// Expected audience DID 78 + /// Expected audience DID. 80 79 expected: Did, 81 - /// Actual audience DID in token 82 - actual: Did, 80 + /// Actual audience DID service in token. 81 + actual: DidService, 82 + }, 83 + 84 + /// Service id mismatch. 85 + #[error("service id mismatch: allowed {allowed:?}, got {actual:?}")] 86 + ServiceIdMismatch { 87 + /// Allowed service ids. 88 + allowed: Vec<SmolStr>, 89 + /// Actual service id in token. 90 + actual: Option<SmolStr>, 83 91 }, 84 92 85 93 /// Method mismatch (lxm field) ··· 102 110 103 111 /// JWT header for service auth tokens. 104 112 #[derive(Debug, Clone, Serialize, Deserialize)] 105 - pub struct JwtHeader<'a> { 106 - /// Algorithm used for signing 107 - #[serde(borrow)] 108 - pub alg: CowStr<'a>, 109 - /// Type (always "JWT") 110 - #[serde(borrow)] 111 - pub typ: CowStr<'a>, 113 + pub struct JwtHeader<S: BosStr = DefaultStr> { 114 + /// Algorithm used for signing. 115 + pub alg: S, 116 + /// Type (always "JWT"). 117 + pub typ: S, 112 118 } 113 119 114 - impl IntoStatic for JwtHeader<'_> { 115 - type Output = JwtHeader<'static>; 120 + impl<S> IntoStatic for JwtHeader<S> 121 + where 122 + S: BosStr + IntoStatic, 123 + S::Output: BosStr, 124 + { 125 + type Output = JwtHeader<S::Output>; 116 126 117 127 fn into_static(self) -> Self::Output { 118 128 JwtHeader { ··· 126 136 /// 127 137 /// These are the payload fields in a service auth JWT. 128 138 #[derive(Debug, Clone, Serialize, Deserialize)] 129 - pub struct ServiceAuthClaims<'a> { 130 - /// Issuer (user's DID) 131 - pub iss: Did, 139 + pub struct ServiceAuthClaims<S: BosStr = DefaultStr> { 140 + /// Issuer (user's DID). 141 + pub iss: Did<S>, 132 142 133 - /// Audience (target service DID) 134 - pub aud: Did, 143 + /// Audience (target service DID with optional service-id fragment). 144 + pub aud: DidService<S>, 135 145 136 - /// Expiration time (unix timestamp) 146 + /// Expiration time (unix timestamp). 137 147 pub exp: i64, 138 148 139 - /// Issued at (unix timestamp) 149 + /// Issued at (unix timestamp). 140 150 pub iat: i64, 141 151 142 - /// JWT ID (nonce for replay protection) 143 - #[serde(borrow, skip_serializing_if = "Option::is_none")] 144 - pub jti: Option<CowStr<'a>>, 152 + /// JWT ID (nonce for replay protection). 153 + #[serde(skip_serializing_if = "Option::is_none")] 154 + pub jti: Option<S>, 145 155 146 - /// Lexicon method NSID (method binding) 156 + /// Lexicon method NSID (method binding). 147 157 #[serde(skip_serializing_if = "Option::is_none")] 148 - pub lxm: Option<Nsid>, 158 + pub lxm: Option<Nsid<S>>, 149 159 } 150 160 151 - impl<'a> IntoStatic for ServiceAuthClaims<'a> { 152 - type Output = ServiceAuthClaims<'static>; 161 + impl<S> IntoStatic for ServiceAuthClaims<S> 162 + where 163 + S: BosStr + IntoStatic, 164 + S::Output: BosStr, 165 + { 166 + type Output = ServiceAuthClaims<S::Output>; 153 167 154 168 fn into_static(self) -> Self::Output { 155 169 ServiceAuthClaims { ··· 163 177 } 164 178 } 165 179 166 - impl<'a> ServiceAuthClaims<'a> { 180 + impl<S: BosStr> ServiceAuthClaims<S> { 167 181 /// Validate the claims against expected values. 168 182 /// 169 183 /// Checks: 170 - /// - Audience matches expected DID 171 - /// - Token is not expired 172 - pub fn validate(&self, expected_aud: &Did<&str>) -> Result<(), ServiceAuthError> { 173 - // Check audience 174 - if self.aud.as_str() != expected_aud.as_str() { 184 + /// - The fragmentless audience matches the expected DID. 185 + /// - A present service-id fragment is in the allowed service list when configured. 186 + /// - The token is not expired. 187 + pub fn validate<B, Svc>( 188 + &self, 189 + expected_aud: &Did<B>, 190 + allowed_services: &[Svc], 191 + ) -> Result<(), ServiceAuthError> 192 + where 193 + B: BosStr, 194 + Svc: AsRef<str>, 195 + { 196 + if self.aud.audience().as_str() != expected_aud.as_str() { 175 197 return Err(ServiceAuthError::AudienceMismatch { 176 - expected: expected_aud.clone().into_static(), 177 - actual: self.aud.clone().into_static(), 198 + expected: expected_aud.borrow().into_static(), 199 + actual: DidService::new_owned(self.aud.as_str()).unwrap(), 178 200 }); 179 201 } 180 202 181 - // Check expiration 203 + if !allowed_services.is_empty() { 204 + if let Some(service) = self.aud.service() { 205 + if !allowed_services 206 + .iter() 207 + .any(|allowed| allowed.as_ref() == service) 208 + { 209 + return Err(ServiceAuthError::ServiceIdMismatch { 210 + allowed: allowed_services 211 + .iter() 212 + .map(|allowed| SmolStr::new(allowed.as_ref())) 213 + .collect(), 214 + actual: Some(SmolStr::new(service)), 215 + }); 216 + } 217 + } 218 + } 219 + 182 220 if self.is_expired() { 183 221 let now = chrono::Utc::now().timestamp(); 184 222 return Err(ServiceAuthError::Expired { exp: self.exp, now }); ··· 219 257 220 258 /// Parsed JWT components. 221 259 /// 222 - /// This struct owns the decoded buffers and parsed components using ouroboros 223 - /// self-referencing. The header and claims borrow from their respective buffers. 224 - #[self_referencing] 225 - pub struct ParsedJwt { 226 - /// Decoded header buffer (owned) 227 - header_buf: Vec<u8>, 228 - /// Decoded payload buffer (owned) 229 - payload_buf: Vec<u8>, 230 - /// Original token string for signing_input 231 - token: String, 232 - /// Signature bytes 260 + /// This struct owns decoded and parsed JWT data. `signing_input` stores the 261 + /// original `header.payload` bytes used for signature verification. 262 + pub struct ParsedJwt<S: BosStr = DefaultStr> { 263 + /// Parsed JWT header. 264 + header: JwtHeader, 265 + /// Parsed service-auth claims. 266 + claims: ServiceAuthClaims<S>, 267 + /// Original `header.payload` signing input. 268 + signing_input: String, 269 + /// Decoded signature bytes. 233 270 signature: Vec<u8>, 234 - /// Parsed header borrowing from header_buf 235 - #[borrows(header_buf)] 236 - #[covariant] 237 - header: JwtHeader<'this>, 238 - /// Parsed claims borrowing from payload_buf 239 - #[borrows(payload_buf)] 240 - #[covariant] 241 - claims: ServiceAuthClaims<'this>, 242 271 } 243 272 244 - impl ParsedJwt { 273 + impl<S: BosStr> ParsedJwt<S> { 245 274 /// Get the signing input (header.payload) for signature verification. 246 275 pub fn signing_input(&self) -> &[u8] { 247 - self.with_token(|token| { 248 - let dot_pos = token.find('.').unwrap(); 249 - let second_dot_pos = token[dot_pos + 1..].find('.').unwrap() + dot_pos + 1; 250 - token[..second_dot_pos].as_bytes() 251 - }) 276 + self.signing_input.as_bytes() 252 277 } 253 278 254 279 /// Get a reference to the header. 255 - pub fn header(&self) -> &JwtHeader<'_> { 256 - self.borrow_header() 280 + pub fn header(&self) -> &JwtHeader { 281 + &self.header 257 282 } 258 283 259 284 /// Get a reference to the claims. 260 - pub fn claims(&self) -> &ServiceAuthClaims<'_> { 261 - self.borrow_claims() 285 + pub fn claims(&self) -> &ServiceAuthClaims<S> { 286 + &self.claims 262 287 } 263 288 264 289 /// Get a reference to the signature. 265 290 pub fn signature(&self) -> &[u8] { 266 - self.borrow_signature() 291 + &self.signature 267 292 } 268 293 269 294 /// Get owned header with 'static lifetime. 270 - pub fn into_header(self) -> JwtHeader<'static> { 271 - self.with_header(|header| header.clone().into_static()) 295 + pub fn into_header(self) -> JwtHeader { 296 + self.header 272 297 } 273 298 274 - /// Get owned claims with 'static lifetime. 275 - pub fn into_claims(self) -> ServiceAuthClaims<'static> { 276 - self.with_claims(|claims| claims.clone().into_static()) 299 + /// Get owned claims. 300 + pub fn into_claims(self) -> ServiceAuthClaims<S> { 301 + self.claims 277 302 } 278 303 } 279 304 280 305 /// Parse a JWT token into its components without verifying the signature. 281 306 /// 282 307 /// This extracts and decodes all JWT components. The header and claims are parsed 283 - /// and borrow from their respective owned buffers using ouroboros self-referencing. 308 + /// into their default owned backing types. 284 309 pub fn parse_jwt(token: &str) -> Result<ParsedJwt, ServiceAuthError> { 285 310 let parts: Vec<&str> = token.split('.').collect(); 286 311 if parts.len() != 3 { ··· 293 318 let payload_b64 = parts[1]; 294 319 let signature_b64 = parts[2]; 295 320 296 - // Decode all components 297 321 let header_buf = URL_SAFE_NO_PAD.decode(header_b64)?; 298 322 let payload_buf = URL_SAFE_NO_PAD.decode(payload_b64)?; 299 323 let signature = URL_SAFE_NO_PAD.decode(signature_b64)?; 324 + let header: JwtHeader = serde_json::from_slice(&header_buf)?; 325 + let claims: ServiceAuthClaims = serde_json::from_slice(&payload_buf)?; 326 + let signing_input = format!("{}.{}", header_b64, payload_b64); 300 327 301 - // Validate that buffers contain valid JSON for their types 302 - // We parse once here to validate, then again in the builder (unavoidable with ouroboros) 303 - let _header: JwtHeader = serde_json::from_slice(&header_buf)?; 304 - let _claims: ServiceAuthClaims = serde_json::from_slice(&payload_buf)?; 305 - 306 - Ok(ParsedJwtBuilder { 307 - header_buf, 308 - payload_buf, 309 - token: token.to_string(), 328 + Ok(ParsedJwt { 329 + header, 330 + claims, 331 + signing_input, 310 332 signature, 311 - header_builder: |buf| { 312 - // Safe: we validated this succeeds above 313 - serde_json::from_slice(buf).expect("header was validated") 314 - }, 315 - claims_builder: |buf| { 316 - // Safe: we validated this succeeds above 317 - serde_json::from_slice(buf).expect("claims were validated") 318 - }, 319 - } 320 - .build()) 333 + }) 321 334 } 322 335 323 336 /// Public key types for signature verification. ··· 402 415 pub fn verify_service_jwt( 403 416 token: &str, 404 417 public_key: &PublicKey, 405 - ) -> Result<ServiceAuthClaims<'static>, ServiceAuthError> { 418 + ) -> Result<ServiceAuthClaims, ServiceAuthError> { 406 419 let parsed = parse_jwt(token)?; 407 420 verify_signature(&parsed, public_key)?; 408 421 Ok(parsed.into_claims()) ··· 421 434 #[test] 422 435 fn test_claims_expiration() { 423 436 let now = chrono::Utc::now().timestamp(); 424 - let expired_claims = ServiceAuthClaims { 437 + let expired_claims: ServiceAuthClaims = ServiceAuthClaims { 425 438 iss: Did::new_static("did:plc:test").unwrap(), 426 - aud: Did::new_static("did:web:example.com").unwrap(), 439 + aud: DidService::new_static("did:web:example.com").unwrap(), 427 440 exp: now - 100, 428 441 iat: now - 200, 429 442 jti: None, ··· 432 445 433 446 assert!(expired_claims.is_expired()); 434 447 435 - let valid_claims = ServiceAuthClaims { 448 + let valid_claims: ServiceAuthClaims = ServiceAuthClaims { 436 449 iss: Did::new_static("did:plc:test").unwrap(), 437 - aud: Did::new_static("did:web:example.com").unwrap(), 450 + aud: DidService::new_static("did:web:example.com").unwrap(), 438 451 exp: now + 100, 439 452 iat: now, 440 453 jti: None, ··· 444 457 assert!(!valid_claims.is_expired()); 445 458 } 446 459 447 - #[test] 448 - fn test_audience_validation() { 449 - let now = chrono::Utc::now().timestamp(); 450 - let claims = ServiceAuthClaims { 460 + fn claims_with_aud(aud: &str) -> ServiceAuthClaims { 461 + ServiceAuthClaims { 451 462 iss: Did::new_static("did:plc:test").unwrap(), 452 - aud: Did::new_static("did:web:example.com").unwrap(), 453 - exp: now + 100, 454 - iat: now, 463 + aud: DidService::new_owned(aud).unwrap(), 464 + exp: chrono::Utc::now().timestamp() + 100, 465 + iat: chrono::Utc::now().timestamp(), 455 466 jti: None, 456 467 lxm: None, 457 - }; 468 + } 469 + } 458 470 471 + #[test] 472 + fn test_audience_validation() { 459 473 let expected_aud = Did::new("did:web:example.com").unwrap(); 460 - assert!(claims.validate(&expected_aud).is_ok()); 474 + assert!( 475 + claims_with_aud("did:web:example.com") 476 + .validate(&expected_aud, &[] as &[&str]) 477 + .is_ok() 478 + ); 479 + assert!( 480 + claims_with_aud("did:web:example.com#bsky_appview") 481 + .validate(&expected_aud, &[] as &[&str]) 482 + .is_ok() 483 + ); 484 + assert!( 485 + claims_with_aud("did:web:example.com") 486 + .validate(&expected_aud, &["bsky_appview"]) 487 + .is_ok() 488 + ); 489 + assert!( 490 + claims_with_aud("did:web:example.com#bsky_appview") 491 + .validate(&expected_aud, &["bsky_appview"]) 492 + .is_ok() 493 + ); 494 + assert!(matches!( 495 + claims_with_aud("did:web:example.com#other").validate(&expected_aud, &["bsky_appview"]), 496 + Err(ServiceAuthError::ServiceIdMismatch { .. }) 497 + )); 461 498 462 499 let wrong_aud = Did::new("did:web:wrong.com").unwrap(); 463 500 assert!(matches!( 464 - claims.validate(&wrong_aud), 501 + claims_with_aud("did:web:example.com#bsky_appview") 502 + .validate(&wrong_aud, &["bsky_appview"]), 465 503 Err(ServiceAuthError::AudienceMismatch { .. }) 466 504 )); 467 505 } 468 506 469 507 #[test] 470 508 fn test_method_check() { 471 - let claims = ServiceAuthClaims { 509 + let claims: ServiceAuthClaims = ServiceAuthClaims { 472 510 iss: Did::new_static("did:plc:test").unwrap(), 473 - aud: Did::new_static("did:web:example.com").unwrap(), 511 + aud: DidService::new_static("did:web:example.com").unwrap(), 474 512 exp: chrono::Utc::now().timestamp() + 100, 475 513 iat: chrono::Utc::now().timestamp(), 476 514 jti: None,
+2
crates/jacquard-common/src/types.rs
··· 18 18 pub mod did; 19 19 /// DID Document types and helpers 20 20 pub mod did_doc; 21 + /// DID service audience types and validation. 22 + pub mod did_service; 21 23 /// AT Protocol handle types and validation 22 24 pub mod handle; 23 25 /// AT Protocol identifier types (handle or DID)
+332
crates/jacquard-common/src/types/did_service.rs
··· 1 + use crate::bos::{Bos, BosStr, DefaultStr}; 2 + use crate::types::did::{Did, validate_did}; 3 + use crate::types::string::{AtStrError, StrParseKind}; 4 + use crate::{CowStr, IntoStatic}; 5 + use alloc::string::{String, ToString}; 6 + use core::fmt; 7 + use core::ops::Deref; 8 + use core::str::FromStr; 9 + use serde::{Deserialize, Deserializer, Serialize}; 10 + use smol_str::SmolStr; 11 + 12 + /// A DID audience with an optional service-id fragment. 13 + /// 14 + /// Service auth JWTs may target either a bare service DID such as 15 + /// `did:web:example.com` or a DID plus service fragment such as 16 + /// `did:web:example.com#bsky_appview`. 17 + #[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize)] 18 + #[serde(transparent)] 19 + #[repr(transparent)] 20 + pub struct DidService<S: Bos<str> = DefaultStr>(pub(crate) S); 21 + 22 + /// Validate a DID service audience string without constructing a [`DidService`]. 23 + pub fn validate_did_service(value: &str) -> Result<(), AtStrError> { 24 + if value.len() > 2048 { 25 + return Err(AtStrError::too_long( 26 + "did service audience", 27 + value, 28 + 2048, 29 + value.len(), 30 + )); 31 + } 32 + 33 + let mut parts = value.split('#'); 34 + let did = parts.next().unwrap_or_default(); 35 + let service = parts.next(); 36 + if parts.next().is_some() { 37 + return Err(AtStrError::regex( 38 + "did service audience", 39 + value, 40 + SmolStr::new_static("multiple fragments"), 41 + )); 42 + } 43 + 44 + validate_did(did)?; 45 + 46 + if let Some(service) = service { 47 + validate_service_id(service, value)?; 48 + } 49 + 50 + Ok(()) 51 + } 52 + 53 + fn validate_service_id(service: &str, whole: &str) -> Result<(), AtStrError> { 54 + let mut bytes = service.bytes(); 55 + match bytes.next() { 56 + Some(first) if first.is_ascii_alphabetic() => {} 57 + _ => { 58 + return Err(AtStrError::regex( 59 + "did service audience", 60 + whole, 61 + SmolStr::new_static("invalid service id"), 62 + )); 63 + } 64 + } 65 + 66 + if bytes.all(|byte| byte.is_ascii_alphanumeric() || byte == b'_' || byte == b'-') { 67 + Ok(()) 68 + } else { 69 + Err(AtStrError::regex( 70 + "did service audience", 71 + whole, 72 + SmolStr::new_static("invalid service id"), 73 + )) 74 + } 75 + } 76 + 77 + impl<S: BosStr> DidService<S> { 78 + /// Get the full DID service audience as a string slice. 79 + pub fn as_str(&self) -> &str { 80 + self.0.as_ref() 81 + } 82 + 83 + /// Get the fragmentless DID audience. 84 + pub fn audience(&self) -> Did<&str> { 85 + let did = self 86 + .as_str() 87 + .split_once('#') 88 + .map_or(self.as_str(), |(did, _)| did); 89 + // SAFETY: self is already validated, and validation validates the DID portion. 90 + unsafe { Did::unchecked(did) } 91 + } 92 + 93 + /// Get the optional service-id fragment without the leading `#`. 94 + pub fn service(&self) -> Option<&str> { 95 + self.as_str().split_once('#').map(|(_, service)| service) 96 + } 97 + } 98 + 99 + impl<S: Bos<str>> DidService<S> { 100 + /// Infallible unchecked constructor. 101 + /// 102 + /// # Safety 103 + /// 104 + /// The caller must ensure the DID service audience is valid. 105 + pub unsafe fn unchecked(value: S) -> Self { 106 + Self(value) 107 + } 108 + 109 + /// Convert to a `DidService` with a different backing type. 110 + pub fn convert<B: Bos<str> + From<S>>(self) -> DidService<B> { 111 + DidService(B::from(self.0)) 112 + } 113 + 114 + /// Borrow as a `DidService<&str>`. 115 + pub fn borrow(&self) -> DidService<&str> 116 + where 117 + S: AsRef<str>, 118 + { 119 + // SAFETY: self is already validated. 120 + unsafe { DidService::unchecked(self.0.as_ref()) } 121 + } 122 + } 123 + 124 + impl<S: BosStr> DidService<S> { 125 + /// Fallible constructor that validates and wraps the input directly. 126 + pub fn new(s: S) -> Result<Self, AtStrError> { 127 + validate_did_service(s.as_ref())?; 128 + Ok(Self(s)) 129 + } 130 + 131 + /// Infallible constructor. Panics on invalid DID service audiences. 132 + pub fn raw(s: S) -> Self { 133 + Self::new(s).expect("invalid DID service audience") 134 + } 135 + } 136 + 137 + impl<S: BosStr + FromStr> DidService<S> { 138 + /// Fallible constructor that validates and takes ownership. 139 + pub fn new_owned(value: impl AsRef<str>) -> Result<Self, AtStrError> { 140 + let value = value.as_ref(); 141 + validate_did_service(value)?; 142 + let s = S::from_str(value).map_err(|_| { 143 + AtStrError::new( 144 + "did service audience", 145 + value.to_string(), 146 + StrParseKind::Conversion, 147 + ) 148 + })?; 149 + Ok(Self(s)) 150 + } 151 + 152 + /// Fallible constructor for static strings. 153 + pub fn new_static(value: &'static str) -> Result<Self, AtStrError> { 154 + validate_did_service(value)?; 155 + let s = S::from_str(value).map_err(|_| { 156 + AtStrError::new( 157 + "did service audience", 158 + value.to_string(), 159 + StrParseKind::Conversion, 160 + ) 161 + })?; 162 + Ok(Self(s)) 163 + } 164 + } 165 + 166 + impl<'de, S> Deserialize<'de> for DidService<S> 167 + where 168 + S: BosStr + Deserialize<'de>, 169 + { 170 + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> 171 + where 172 + D: Deserializer<'de>, 173 + { 174 + let s = S::deserialize(deserializer)?; 175 + validate_did_service(s.as_ref()).map_err(serde::de::Error::custom)?; 176 + Ok(Self(s)) 177 + } 178 + } 179 + 180 + impl<S> IntoStatic for DidService<S> 181 + where 182 + S: Bos<str> + IntoStatic, 183 + S::Output: Bos<str>, 184 + { 185 + type Output = DidService<S::Output>; 186 + 187 + fn into_static(self) -> Self::Output { 188 + DidService(self.0.into_static()) 189 + } 190 + } 191 + 192 + impl<S: BosStr + FromStr> FromStr for DidService<S> { 193 + type Err = AtStrError; 194 + 195 + fn from_str(s: &str) -> Result<Self, Self::Err> { 196 + Self::new_owned(s) 197 + } 198 + } 199 + 200 + impl<S: BosStr> fmt::Display for DidService<S> { 201 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 202 + f.write_str(self.as_str()) 203 + } 204 + } 205 + 206 + impl<S: BosStr> fmt::Debug for DidService<S> { 207 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 208 + f.write_str(self.as_str()) 209 + } 210 + } 211 + 212 + impl<S: BosStr> From<DidService<S>> for String { 213 + fn from(value: DidService<S>) -> Self { 214 + value.as_str().to_string() 215 + } 216 + } 217 + 218 + impl<S: BosStr> From<DidService<S>> for CowStr<'static> { 219 + fn from(value: DidService<S>) -> Self { 220 + CowStr::copy_from_str(value.as_str()) 221 + } 222 + } 223 + 224 + impl From<String> for DidService { 225 + fn from(value: String) -> Self { 226 + Self::new_owned(value).unwrap() 227 + } 228 + } 229 + 230 + impl<'d> From<CowStr<'d>> for DidService<CowStr<'d>> { 231 + fn from(value: CowStr<'d>) -> Self { 232 + Self::new(value).unwrap() 233 + } 234 + } 235 + 236 + impl<S: BosStr> AsRef<str> for DidService<S> { 237 + fn as_ref(&self) -> &str { 238 + self.as_str() 239 + } 240 + } 241 + 242 + impl<S: BosStr> Deref for DidService<S> { 243 + type Target = str; 244 + 245 + fn deref(&self) -> &Self::Target { 246 + self.as_str() 247 + } 248 + } 249 + 250 + #[cfg(test)] 251 + mod tests { 252 + use super::*; 253 + use serde_json::json; 254 + 255 + #[test] 256 + fn valid_with_service_id() { 257 + assert!(DidService::<&str>::new("did:web:example.com#bsky_appview").is_ok()); 258 + assert!(DidService::<&str>::new("did:plc:abc123#atproto_labeler").is_ok()); 259 + } 260 + 261 + #[test] 262 + fn valid_bare_did() { 263 + assert!(DidService::<&str>::new("did:web:example.com").is_ok()); 264 + } 265 + 266 + #[test] 267 + fn splits_audience_and_service() { 268 + let value = DidService::<&str>::new("did:web:example.com#bsky_appview").unwrap(); 269 + assert_eq!(value.audience().as_str(), "did:web:example.com"); 270 + assert_eq!(value.service(), Some("bsky_appview")); 271 + 272 + let bare = DidService::<&str>::new("did:web:example.com").unwrap(); 273 + assert_eq!(bare.audience().as_str(), "did:web:example.com"); 274 + assert_eq!(bare.service(), None); 275 + } 276 + 277 + #[test] 278 + fn rejects_empty_fragment() { 279 + assert!(DidService::<&str>::new("did:web:example.com#").is_err()); 280 + } 281 + 282 + #[test] 283 + fn rejects_invalid_service_chars() { 284 + for value in [ 285 + "did:web:example.com#1bad", 286 + "did:web:example.com#-bad", 287 + "did:web:example.com#bad.service", 288 + "did:web:example.com#bad:service", 289 + "did:web:example.com#bad service", 290 + "did:web:example.com#bad#service", 291 + ] { 292 + assert!(DidService::<&str>::new(value).is_err(), "{value}"); 293 + } 294 + } 295 + 296 + #[test] 297 + fn rejects_invalid_did_body() { 298 + assert!(DidService::<&str>::new("not-a-did#service").is_err()); 299 + } 300 + 301 + #[test] 302 + fn enforces_max_length() { 303 + let service = "a".repeat(2049 - "did:web:example.com#".len()); 304 + let value = format!("did:web:example.com#{service}"); 305 + assert!(DidService::<&str>::new(&value).is_err()); 306 + } 307 + 308 + #[test] 309 + fn serde_roundtrip() { 310 + let value = DidService::new_static("did:web:example.com#bsky_appview").unwrap(); 311 + let json = serde_json::to_value(&value).unwrap(); 312 + assert_eq!(json, json!("did:web:example.com#bsky_appview")); 313 + let decoded: DidService = serde_json::from_value(json).unwrap(); 314 + assert_eq!(decoded, value); 315 + } 316 + 317 + #[test] 318 + fn into_static_preserves_value() { 319 + let value = DidService::<CowStr<'_>>::new(CowStr::copy_from_str( 320 + "did:web:example.com#bsky_appview", 321 + )) 322 + .unwrap(); 323 + let static_value: DidService<CowStr<'static>> = value.into_static(); 324 + assert_eq!(static_value.as_str(), "did:web:example.com#bsky_appview"); 325 + } 326 + 327 + #[test] 328 + fn from_str_owns_value() { 329 + let value: DidService = "did:web:example.com#bsky_appview".parse().unwrap(); 330 + assert_eq!(value.as_str(), "did:web:example.com#bsky_appview"); 331 + } 332 + }
+1
crates/jacquard-common/src/types/string.rs
··· 34 34 cid::{Cid, CidLink}, 35 35 datetime::Datetime, 36 36 did::Did, 37 + did_service::DidService, 37 38 handle::Handle, 38 39 ident::AtIdentifier, 39 40 language::Language,