Now let's take a silly one
0

Configure Feed

Select the types of activity you want to include in your feed.

at main 25 kB View raw
1use std::io::BufReader; 2use std::path::{Path, PathBuf}; 3use std::sync::Arc; 4 5use arc_swap::ArcSwap; 6use base64::Engine; 7use quinn::crypto::rustls::QuicServerConfig; 8use rustls::RootCertStore; 9use rustls::ServerConfig; 10use rustls::crypto::aws_lc_rs; 11use rustls::pki_types::{CertificateDer, PrivateKeyDer, UnixTime}; 12use rustls::server::danger::{ClientCertVerified, ClientCertVerifier}; 13use rustls::server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier}; 14use rustls::sign::CertifiedKey; 15use sha2::{Digest, Sha256}; 16use tokio_util::sync::CancellationToken; 17 18use crate::limits::ListenLimits; 19use crate::zerortt::EarlyDataPolicy; 20 21pub const ACME_TLS_ALPN: &[u8] = rustls_acme::acme::ACME_TLS_ALPN_NAME; 22 23const MAX_CONCURRENT_STREAMS: u32 = 256; 24const STREAM_RECEIVE_WINDOW: u32 = 8 * 1024 * 1024; 25const CONNECTION_RECEIVE_WINDOW: u32 = 32 * 1024 * 1024; 26 27#[derive(Debug, thiserror::Error)] 28pub enum TlsError { 29 #[error("reading {path}: {source}")] 30 Read { 31 path: String, 32 source: std::io::Error, 33 }, 34 #[error("parsing {path}: {message}")] 35 Parse { path: String, message: String }, 36 #[error("no certificates found in {0}")] 37 NoCertificates(String), 38 #[error("no private key found in {0}")] 39 NoPrivateKey(String), 40 #[error("unusable private key: {0}")] 41 SigningKey(String), 42 #[error("building server config: {0}")] 43 Config(String), 44 #[error("certificate and private key do not match: {0}")] 45 KeyMismatch(String), 46 #[error("session ticketer: {0}")] 47 Ticketer(String), 48 #[error("client certificate verifier for {path}: {message}")] 49 ClientVerifier { path: String, message: String }, 50 #[error("admin SPKI pin: {0}")] 51 SpkiPin(String), 52} 53 54#[derive(Clone)] 55pub struct SpkiPin([u8; 32]); 56 57impl PartialEq for SpkiPin { 58 fn eq(&self, other: &Self) -> bool { 59 self.0 60 .iter() 61 .zip(other.0.iter()) 62 .fold(0u8, |acc, (left, right)| acc | (left ^ right)) 63 == 0 64 } 65} 66 67impl Eq for SpkiPin {} 68 69impl SpkiPin { 70 pub fn from_base64(encoded: &str) -> Result<Self, TlsError> { 71 let bytes = base64::engine::general_purpose::STANDARD 72 .decode(encoded.trim()) 73 .map_err(|error| TlsError::SpkiPin(error.to_string()))?; 74 let array: [u8; 32] = bytes.try_into().map_err(|bytes: Vec<u8>| { 75 TlsError::SpkiPin(format!("expected 32 bytes, got {}", bytes.len())) 76 })?; 77 Ok(Self(array)) 78 } 79 80 fn of_certificate(cert: &CertificateDer<'_>) -> Result<Self, rustls::Error> { 81 let (_, parsed) = x509_parser::parse_x509_certificate(cert.as_ref()).map_err(|error| { 82 rustls::Error::General(format!("parse client certificate: {error}")) 83 })?; 84 Ok(Self( 85 Sha256::digest(parsed.tbs_certificate.subject_pki.raw).into(), 86 )) 87 } 88} 89 90impl std::fmt::Debug for SpkiPin { 91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 92 f.debug_tuple("SpkiPin").finish_non_exhaustive() 93 } 94} 95 96pub(crate) fn fuzz_of_certificate(data: &[u8]) { 97 let cert = CertificateDer::from(data.to_vec()); 98 let _ = SpkiPin::of_certificate(&cert); 99} 100 101pub struct ReloadableCertResolver { 102 current: ArcSwap<CertifiedKey>, 103} 104 105impl std::fmt::Debug for ReloadableCertResolver { 106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 107 f.debug_struct("ReloadableCertResolver") 108 .finish_non_exhaustive() 109 } 110} 111 112impl ReloadableCertResolver { 113 pub fn new(initial: CertifiedKey) -> Self { 114 Self { 115 current: ArcSwap::from_pointee(initial), 116 } 117 } 118 119 pub fn store(&self, key: CertifiedKey) { 120 self.current.store(Arc::new(key)); 121 } 122} 123 124impl ResolvesServerCert for ReloadableCertResolver { 125 fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> { 126 Some(self.current.load_full()) 127 } 128} 129 130pub fn spawn_cert_reload( 131 resolver: Arc<ReloadableCertResolver>, 132 cert_path: PathBuf, 133 key_path: PathBuf, 134 shutdown: CancellationToken, 135) { 136 #[cfg(unix)] 137 tokio::spawn(async move { 138 use tokio::signal::unix::{SignalKind, signal}; 139 let mut hangup = match signal(SignalKind::hangup()) { 140 Ok(stream) => stream, 141 Err(error) => { 142 tracing::error!("install SIGHUP handler: {error}"); 143 return; 144 } 145 }; 146 loop { 147 tokio::select! { 148 () = shutdown.cancelled() => break, 149 received = hangup.recv() => { 150 if received.is_none() { 151 break; 152 } 153 match load_certified_key(&cert_path, &key_path) { 154 Ok(key) => { 155 resolver.store(key); 156 tracing::info!("reloaded TLS certificate on SIGHUP"); 157 } 158 Err(error) => { 159 tracing::warn!("SIGHUP reload kept the existing certificate: {error}"); 160 } 161 } 162 } 163 } 164 } 165 }); 166 167 #[cfg(not(unix))] 168 let _ = (resolver, cert_path, key_path, shutdown); 169} 170 171pub fn load_certified_key(cert_path: &Path, key_path: &Path) -> Result<CertifiedKey, TlsError> { 172 let certs = load_certs(cert_path)?; 173 let key = load_private_key(key_path)?; 174 let signing_key = aws_lc_rs::sign::any_supported_type(&key) 175 .map_err(|error| TlsError::SigningKey(error.to_string()))?; 176 let certified = CertifiedKey::new(certs, signing_key); 177 certified 178 .keys_match() 179 .map_err(|error| TlsError::KeyMismatch(error.to_string()))?; 180 Ok(certified) 181} 182 183fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> { 184 let bytes = std::fs::read(path).map_err(|source| TlsError::Read { 185 path: path.display().to_string(), 186 source, 187 })?; 188 let mut reader = BufReader::new(bytes.as_slice()); 189 let certs = rustls_pemfile::certs(&mut reader) 190 .collect::<Result<Vec<_>, _>>() 191 .map_err(|error| TlsError::Parse { 192 path: path.display().to_string(), 193 message: error.to_string(), 194 })?; 195 match certs.is_empty() { 196 true => Err(TlsError::NoCertificates(path.display().to_string())), 197 false => Ok(certs), 198 } 199} 200 201fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> { 202 let bytes = std::fs::read(path).map_err(|source| TlsError::Read { 203 path: path.display().to_string(), 204 source, 205 })?; 206 let mut reader = BufReader::new(bytes.as_slice()); 207 rustls_pemfile::private_key(&mut reader) 208 .map_err(|error| TlsError::Parse { 209 path: path.display().to_string(), 210 message: error.to_string(), 211 })? 212 .ok_or_else(|| TlsError::NoPrivateKey(path.display().to_string())) 213} 214 215fn ticketer() -> Result<Arc<dyn rustls::server::ProducesTickets>, TlsError> { 216 aws_lc_rs::Ticketer::new().map_err(|error| TlsError::Ticketer(error.to_string())) 217} 218 219fn tcp_alpn(extra: &[&[u8]]) -> Vec<Vec<u8>> { 220 [b"h2".as_slice(), b"http/1.1".as_slice()] 221 .into_iter() 222 .chain(extra.iter().copied()) 223 .map(<[u8]>::to_vec) 224 .collect() 225} 226 227pub fn build_tls_server_config( 228 resolver: Arc<dyn ResolvesServerCert>, 229 extra_alpn: &[&[u8]], 230) -> Result<ServerConfig, TlsError> { 231 let provider = Arc::new(aws_lc_rs::default_provider()); 232 let mut config = ServerConfig::builder_with_provider(provider) 233 .with_safe_default_protocol_versions() 234 .map_err(|error| TlsError::Config(error.to_string()))? 235 .with_no_client_auth() 236 .with_cert_resolver(resolver); 237 config.alpn_protocols = tcp_alpn(extra_alpn); 238 config.ticketer = ticketer()?; 239 Ok(config) 240} 241 242pub fn build_mtls_server_config( 243 resolver: Arc<dyn ResolvesServerCert>, 244 client_ca_path: &Path, 245 pin: SpkiPin, 246) -> Result<ServerConfig, TlsError> { 247 let provider = Arc::new(aws_lc_rs::default_provider()); 248 let roots = load_client_ca(client_ca_path)?; 249 let webpki = 250 WebPkiClientVerifier::builder_with_provider(Arc::new(roots), Arc::clone(&provider)) 251 .build() 252 .map_err(|error| TlsError::ClientVerifier { 253 path: client_ca_path.display().to_string(), 254 message: error.to_string(), 255 })?; 256 let verifier = Arc::new(PinnedClientVerifier { inner: webpki, pin }); 257 let mut config = ServerConfig::builder_with_provider(provider) 258 .with_safe_default_protocol_versions() 259 .map_err(|error| TlsError::Config(error.to_string()))? 260 .with_client_cert_verifier(verifier) 261 .with_cert_resolver(resolver); 262 config.alpn_protocols = tcp_alpn(&[]); 263 config.ticketer = ticketer()?; 264 Ok(config) 265} 266 267fn load_client_ca(path: &Path) -> Result<RootCertStore, TlsError> { 268 let certs = load_certs(path)?; 269 let mut roots = RootCertStore::empty(); 270 let (added, _) = roots.add_parsable_certificates(certs); 271 match added { 272 0 => Err(TlsError::NoCertificates(path.display().to_string())), 273 _ => Ok(roots), 274 } 275} 276 277#[derive(Debug)] 278struct PinnedClientVerifier { 279 inner: Arc<dyn ClientCertVerifier>, 280 pin: SpkiPin, 281} 282 283impl ClientCertVerifier for PinnedClientVerifier { 284 fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] { 285 self.inner.root_hint_subjects() 286 } 287 288 fn offer_client_auth(&self) -> bool { 289 self.inner.offer_client_auth() 290 } 291 292 fn client_auth_mandatory(&self) -> bool { 293 self.inner.client_auth_mandatory() 294 } 295 296 fn verify_client_cert( 297 &self, 298 end_entity: &CertificateDer<'_>, 299 intermediates: &[CertificateDer<'_>], 300 now: UnixTime, 301 ) -> Result<ClientCertVerified, rustls::Error> { 302 let verified = self 303 .inner 304 .verify_client_cert(end_entity, intermediates, now)?; 305 match SpkiPin::of_certificate(end_entity)? == self.pin { 306 true => Ok(verified), 307 false => Err(rustls::Error::General( 308 "client certificate SPKI does not match the pinned admin identity".to_string(), 309 )), 310 } 311 } 312 313 fn verify_tls12_signature( 314 &self, 315 message: &[u8], 316 cert: &CertificateDer<'_>, 317 dss: &rustls::DigitallySignedStruct, 318 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> { 319 self.inner.verify_tls12_signature(message, cert, dss) 320 } 321 322 fn verify_tls13_signature( 323 &self, 324 message: &[u8], 325 cert: &CertificateDer<'_>, 326 dss: &rustls::DigitallySignedStruct, 327 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> { 328 self.inner.verify_tls13_signature(message, cert, dss) 329 } 330 331 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> { 332 self.inner.supported_verify_schemes() 333 } 334} 335 336pub fn build_quic_server_config( 337 resolver: Arc<dyn ResolvesServerCert>, 338 limits: ListenLimits, 339 early_data: EarlyDataPolicy, 340) -> Result<quinn::ServerConfig, TlsError> { 341 let provider = Arc::new(aws_lc_rs::default_provider()); 342 let mut crypto = ServerConfig::builder_with_provider(provider) 343 .with_protocol_versions(&[&rustls::version::TLS13]) 344 .map_err(|error| TlsError::Config(error.to_string()))? 345 .with_no_client_auth() 346 .with_cert_resolver(resolver); 347 crypto.alpn_protocols = vec![b"h3".to_vec()]; 348 crypto.max_early_data_size = early_data.max_early_data_size(); 349 350 let quic_crypto = 351 QuicServerConfig::try_from(crypto).map_err(|error| TlsError::Config(error.to_string()))?; 352 let mut config = quinn::ServerConfig::with_crypto(Arc::new(quic_crypto)); 353 354 let mut transport = quinn::TransportConfig::default(); 355 transport.max_concurrent_bidi_streams(quinn::VarInt::from_u32(MAX_CONCURRENT_STREAMS)); 356 transport.stream_receive_window(quinn::VarInt::from_u32(STREAM_RECEIVE_WINDOW)); 357 transport.receive_window(quinn::VarInt::from_u32(CONNECTION_RECEIVE_WINDOW)); 358 let idle = limits.idle_timeout(); 359 transport.max_idle_timeout(Some( 360 quinn::IdleTimeout::try_from(idle).map_err(|error| TlsError::Config(error.to_string()))?, 361 )); 362 transport.keep_alive_interval(Some(idle / 2)); 363 config.transport_config(Arc::new(transport)); 364 Ok(config) 365} 366 367pub const fn max_concurrent_streams() -> u32 { 368 MAX_CONCURRENT_STREAMS 369} 370 371#[cfg(test)] 372pub(crate) mod test_support { 373 use std::num::{NonZeroU32, NonZeroU64}; 374 375 use rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer}; 376 use rustls::sign::CertifiedKey; 377 378 use super::*; 379 use crate::limits::ListenLimits; 380 381 pub(crate) fn self_signed() -> CertifiedKey { 382 let generated = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap(); 383 let cert_der = generated.cert.der().clone(); 384 let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from( 385 generated.signing_key.serialize_der(), 386 )); 387 let signing_key = aws_lc_rs::sign::any_supported_type(&key_der).unwrap(); 388 CertifiedKey::new(vec![cert_der], signing_key) 389 } 390 391 pub(crate) fn resolver() -> Arc<ReloadableCertResolver> { 392 Arc::new(ReloadableCertResolver::new(self_signed())) 393 } 394 395 pub(crate) fn limits() -> ListenLimits { 396 ListenLimits::new( 397 NonZeroU64::new(5_000).unwrap(), 398 NonZeroU64::new(30_000).unwrap(), 399 NonZeroU32::new(64).unwrap(), 400 ) 401 } 402 403 #[derive(Debug)] 404 pub(crate) struct AcceptAnyServerCert; 405 406 impl rustls::client::danger::ServerCertVerifier for AcceptAnyServerCert { 407 fn verify_server_cert( 408 &self, 409 _end_entity: &rustls::pki_types::CertificateDer<'_>, 410 _intermediates: &[rustls::pki_types::CertificateDer<'_>], 411 _server_name: &rustls::pki_types::ServerName<'_>, 412 _ocsp_response: &[u8], 413 _now: rustls::pki_types::UnixTime, 414 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> { 415 Ok(rustls::client::danger::ServerCertVerified::assertion()) 416 } 417 418 fn verify_tls12_signature( 419 &self, 420 message: &[u8], 421 cert: &rustls::pki_types::CertificateDer<'_>, 422 dss: &rustls::DigitallySignedStruct, 423 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> { 424 rustls::crypto::verify_tls12_signature( 425 message, 426 cert, 427 dss, 428 &aws_lc_rs::default_provider().signature_verification_algorithms, 429 ) 430 } 431 432 fn verify_tls13_signature( 433 &self, 434 message: &[u8], 435 cert: &rustls::pki_types::CertificateDer<'_>, 436 dss: &rustls::DigitallySignedStruct, 437 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> { 438 rustls::crypto::verify_tls13_signature( 439 message, 440 cert, 441 dss, 442 &aws_lc_rs::default_provider().signature_verification_algorithms, 443 ) 444 } 445 446 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> { 447 aws_lc_rs::default_provider() 448 .signature_verification_algorithms 449 .supported_schemes() 450 } 451 } 452} 453 454#[cfg(test)] 455mod tests { 456 use super::*; 457 458 #[test] 459 fn the_tcp_alpn_set_offers_h2_and_http1_but_not_h3() { 460 let config = build_tls_server_config(test_support::resolver(), &[]).unwrap(); 461 assert_eq!( 462 config.alpn_protocols, 463 vec![b"h2".to_vec(), b"http/1.1".to_vec()], 464 "h3 is QUIC-only and must never appear in the TCP ALPN set" 465 ); 466 } 467 468 #[test] 469 fn the_acme_challenge_alpn_joins_only_when_requested() { 470 let config = build_tls_server_config(test_support::resolver(), &[ACME_TLS_ALPN]).unwrap(); 471 assert_eq!( 472 config.alpn_protocols, 473 vec![b"h2".to_vec(), b"http/1.1".to_vec(), ACME_TLS_ALPN.to_vec()], 474 "acme-tls/1 must trail h2 and http/1.1 so normal clients never select it" 475 ); 476 } 477 478 #[test] 479 fn the_tcp_config_installs_an_enabled_session_ticketer() { 480 let config = build_tls_server_config(test_support::resolver(), &[]).unwrap(); 481 assert!( 482 config.ticketer.enabled(), 483 "session resumption requires an enabled ticketer" 484 ); 485 } 486 487 #[test] 488 fn an_spki_pin_round_trips_through_base64() { 489 let encoded = base64::engine::general_purpose::STANDARD.encode([9u8; 32]); 490 assert_eq!(SpkiPin::from_base64(&encoded).unwrap(), SpkiPin([9u8; 32])); 491 } 492 493 #[test] 494 fn the_spki_pin_matches_the_standard_openssl_recipe() { 495 const CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\n\ 496MIIBhDCCASugAwIBAgIUSkWE4CZvV8B8z9phedKo1PbahDUwCgYIKoZIzj0EAwIw\n\ 497GDEWMBQGA1UEAwwNYW5lbW9uZS5hZG1pbjAeFw0yNjA2MjExOTA5MDdaFw0zNjA2\n\ 498MTgxOTA5MDdaMBgxFjAUBgNVBAMMDWFuZW1vbmUuYWRtaW4wWTATBgcqhkjOPQIB\n\ 499BggqhkjOPQMBBwNCAASFLKd70MtSGSyI2UjdpQyjaJrvXLofac41nI346wK0lC9G\n\ 500PjZH/NKqo1iwQn+UfZB7gotfezWrDmAUz5OgT6Rlo1MwUTAdBgNVHQ4EFgQUis7S\n\ 501XEFGpe4gQwWnzX/uzjpt274wHwYDVR0jBBgwFoAUis7SXEFGpe4gQwWnzX/uzjpt\n\ 502274wDwYDVR0TAQH/BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBhJgE4cMP5/FJw\n\ 503imc3fYQxOhm5nO59cfG06+0vuDIV1QIgMsKZjFjsch8rbLRNiJL5+bDmlgO7MD14\n\ 5040PAyPOyjb+w=\n\ 505-----END CERTIFICATE-----\n"; 506 const OPENSSL_PIN: &str = "EhmM1HyWzC54br06EDvoAaqt1q1h+je3vVFJTcZ9e1U="; 507 508 let der = rustls_pemfile::certs(&mut CERT_PEM.as_bytes()) 509 .next() 510 .unwrap() 511 .unwrap(); 512 assert_eq!( 513 SpkiPin::of_certificate(&der).unwrap(), 514 SpkiPin::from_base64(OPENSSL_PIN).unwrap(), 515 "of_certificate must hash the same SubjectPublicKeyInfo bytes as openssl pkey -pubin -outform DER | dgst -sha256" 516 ); 517 } 518 519 #[test] 520 fn an_spki_pin_of_the_wrong_length_is_rejected() { 521 let encoded = base64::engine::general_purpose::STANDARD.encode([9u8; 16]); 522 assert!(matches!( 523 SpkiPin::from_base64(&encoded), 524 Err(TlsError::SpkiPin(_)) 525 )); 526 } 527 528 #[test] 529 fn the_quic_alpn_set_offers_only_h3() { 530 let quic = build_quic_server_config( 531 test_support::resolver(), 532 test_support::limits(), 533 EarlyDataPolicy::Disabled, 534 ); 535 assert!(quic.is_ok()); 536 } 537 538 #[test] 539 fn the_default_provider_prefers_post_quantum_key_exchange() { 540 use rustls::NamedGroup; 541 542 let provider = aws_lc_rs::default_provider(); 543 let first = provider.kx_groups.first().expect("a key exchange group"); 544 assert_eq!( 545 first.name(), 546 NamedGroup::X25519MLKEM768, 547 "prefer-post-quantum must order X25519MLKEM768 first for both the TCP and QUIC configs" 548 ); 549 } 550 551 #[test] 552 fn a_mismatched_certificate_and_key_is_rejected() { 553 let first = test_support::self_signed(); 554 let second = test_support::self_signed(); 555 let mismatched = CertifiedKey::new(first.cert.clone(), second.key.clone()); 556 assert!(mismatched.keys_match().is_err()); 557 } 558 559 struct ClientIdentity { 560 ca_path: std::path::PathBuf, 561 chain: Vec<CertificateDer<'static>>, 562 key_der: Vec<u8>, 563 pin: SpkiPin, 564 } 565 566 fn issue_client_identity() -> ClientIdentity { 567 let ca_key = rcgen::KeyPair::generate().unwrap(); 568 let mut ca_params = rcgen::CertificateParams::new(Vec::<String>::new()).unwrap(); 569 ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained); 570 let ca_cert = ca_params.self_signed(&ca_key).unwrap(); 571 let issuer = rcgen::Issuer::new(ca_params, ca_key); 572 573 let client_key = rcgen::KeyPair::generate().unwrap(); 574 let client_params = rcgen::CertificateParams::new(vec!["admin.knot".to_string()]).unwrap(); 575 let client_cert = client_params.signed_by(&client_key, &issuer).unwrap(); 576 let client_der = client_cert.der().clone(); 577 578 let ca_path = std::env::temp_dir().join(format!( 579 "knot_edge_mtls_ca_{}_{:p}.pem", 580 std::process::id(), 581 &client_der as *const _ 582 )); 583 std::fs::write(&ca_path, ca_cert.pem()).unwrap(); 584 585 ClientIdentity { 586 ca_path, 587 pin: SpkiPin::of_certificate(&client_der).unwrap(), 588 chain: vec![client_der], 589 key_der: client_key.serialize_der(), 590 } 591 } 592 593 async fn mutual_handshake_succeeds( 594 server_config: ServerConfig, 595 identity: &ClientIdentity, 596 ) -> bool { 597 use rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer, ServerName}; 598 use tokio::net::{TcpListener, TcpStream}; 599 use tokio_rustls::{TlsAcceptor, TlsConnector}; 600 601 let acceptor = TlsAcceptor::from(Arc::new(server_config)); 602 let listener = TcpListener::bind("[::1]:0").await.unwrap(); 603 let addr = listener.local_addr().unwrap(); 604 let server = tokio::spawn(async move { 605 let (tcp, _) = listener.accept().await.unwrap(); 606 acceptor.accept(tcp).await.is_ok() 607 }); 608 609 let mut client_config = 610 rustls::ClientConfig::builder_with_provider(Arc::new(aws_lc_rs::default_provider())) 611 .with_safe_default_protocol_versions() 612 .unwrap() 613 .dangerous() 614 .with_custom_certificate_verifier(Arc::new(test_support::AcceptAnyServerCert)) 615 .with_client_auth_cert( 616 identity.chain.clone(), 617 PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(identity.key_der.clone())), 618 ) 619 .unwrap(); 620 client_config.alpn_protocols = vec![b"h2".to_vec()]; 621 let connector = TlsConnector::from(Arc::new(client_config)); 622 let tcp = TcpStream::connect(addr).await.unwrap(); 623 let client_ok = connector 624 .connect(ServerName::try_from("localhost").unwrap(), tcp) 625 .await 626 .is_ok(); 627 let server_ok = server.await.unwrap(); 628 client_ok && server_ok 629 } 630 631 async fn handshake_without_client_cert(server_config: ServerConfig) -> bool { 632 use rustls::pki_types::ServerName; 633 use tokio::net::{TcpListener, TcpStream}; 634 use tokio_rustls::{TlsAcceptor, TlsConnector}; 635 636 let acceptor = TlsAcceptor::from(Arc::new(server_config)); 637 let listener = TcpListener::bind("[::1]:0").await.unwrap(); 638 let addr = listener.local_addr().unwrap(); 639 let server = tokio::spawn(async move { 640 let (tcp, _) = listener.accept().await.unwrap(); 641 acceptor.accept(tcp).await.is_ok() 642 }); 643 644 let mut client_config = 645 rustls::ClientConfig::builder_with_provider(Arc::new(aws_lc_rs::default_provider())) 646 .with_safe_default_protocol_versions() 647 .unwrap() 648 .dangerous() 649 .with_custom_certificate_verifier(Arc::new(test_support::AcceptAnyServerCert)) 650 .with_no_client_auth(); 651 client_config.alpn_protocols = vec![b"h2".to_vec()]; 652 let connector = TlsConnector::from(Arc::new(client_config)); 653 let tcp = TcpStream::connect(addr).await.unwrap(); 654 let client_ok = connector 655 .connect(ServerName::try_from("localhost").unwrap(), tcp) 656 .await 657 .is_ok(); 658 let server_ok = server.await.unwrap(); 659 client_ok && server_ok 660 } 661 662 #[tokio::test] 663 async fn mtls_rejects_a_client_presenting_no_certificate() { 664 let identity = issue_client_identity(); 665 let config = build_mtls_server_config( 666 test_support::resolver(), 667 &identity.ca_path, 668 identity.pin.clone(), 669 ) 670 .unwrap(); 671 assert!( 672 !handshake_without_client_cert(config).await, 673 "the mandatory mTLS verifier must reject a client that presents no certificate" 674 ); 675 } 676 677 #[tokio::test] 678 async fn mtls_admits_the_pinned_admin_certificate() { 679 let identity = issue_client_identity(); 680 let config = build_mtls_server_config( 681 test_support::resolver(), 682 &identity.ca_path, 683 identity.pin.clone(), 684 ) 685 .unwrap(); 686 assert!( 687 mutual_handshake_succeeds(config, &identity).await, 688 "a client presenting the pinned admin certificate must complete the mTLS handshake" 689 ); 690 } 691 692 #[tokio::test] 693 async fn mtls_rejects_a_ca_trusted_client_whose_spki_is_not_pinned() { 694 let identity = issue_client_identity(); 695 let config = build_mtls_server_config( 696 test_support::resolver(), 697 &identity.ca_path, 698 SpkiPin([0u8; 32]), 699 ) 700 .unwrap(); 701 assert!( 702 !mutual_handshake_succeeds(config, &identity).await, 703 "a client trusted by the CA but failing the SPKI pin must be rejected" 704 ); 705 } 706}