Now let's take a silly one
1use std::io::BufReader;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use arc_swap::ArcSwap;
6use base64::Engine;
7use quinn::crypto::rustls::QuicServerConfig;
8use rustls::RootCertStore;
9use rustls::ServerConfig;
10use rustls::crypto::aws_lc_rs;
11use rustls::pki_types::{CertificateDer, PrivateKeyDer, UnixTime};
12use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
13use rustls::server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier};
14use rustls::sign::CertifiedKey;
15use sha2::{Digest, Sha256};
16use tokio_util::sync::CancellationToken;
17
18use crate::limits::ListenLimits;
19use crate::zerortt::EarlyDataPolicy;
20
/// ALPN protocol identifier used for ACME TLS-ALPN-01 challenges (`acme-tls/1`),
/// re-exported from `rustls_acme`.
///
/// Pass it to [`build_tls_server_config`] as an extra ALPN entry; it is offered
/// after `h2` and `http/1.1` so ordinary clients never negotiate it.
pub const ACME_TLS_ALPN: &[u8] = rustls_acme::acme::ACME_TLS_ALPN_NAME;
22
/// Cap on concurrently open bidirectional QUIC streams per connection; also
/// exposed through `max_concurrent_streams()`.
const MAX_CONCURRENT_STREAMS: u32 = 256;
/// Per-stream QUIC flow-control receive window: 8 MiB.
const STREAM_RECEIVE_WINDOW: u32 = 8 << 20;
/// Connection-wide QUIC flow-control receive window: 32 MiB.
const CONNECTION_RECEIVE_WINDOW: u32 = 32 << 20;
26
27#[derive(Debug, thiserror::Error)]
28pub enum TlsError {
29 #[error("reading {path}: {source}")]
30 Read {
31 path: String,
32 source: std::io::Error,
33 },
34 #[error("parsing {path}: {message}")]
35 Parse { path: String, message: String },
36 #[error("no certificates found in {0}")]
37 NoCertificates(String),
38 #[error("no private key found in {0}")]
39 NoPrivateKey(String),
40 #[error("unusable private key: {0}")]
41 SigningKey(String),
42 #[error("building server config: {0}")]
43 Config(String),
44 #[error("certificate and private key do not match: {0}")]
45 KeyMismatch(String),
46 #[error("session ticketer: {0}")]
47 Ticketer(String),
48 #[error("client certificate verifier for {path}: {message}")]
49 ClientVerifier { path: String, message: String },
50 #[error("admin SPKI pin: {0}")]
51 SpkiPin(String),
52}
53
/// SHA-256 digest of a certificate's DER-encoded SubjectPublicKeyInfo, used to
/// pin the admin client identity.
#[derive(Clone)]
pub struct SpkiPin([u8; 32]);

impl PartialEq for SpkiPin {
    /// Compares the two digests by OR-ing together the XOR of every byte pair,
    /// without an early exit, so the loop does not stop at the first differing
    /// byte.
    fn eq(&self, other: &Self) -> bool {
        let mut difference = 0u8;
        for (mine, theirs) in self.0.iter().zip(other.0.iter()) {
            difference |= mine ^ theirs;
        }
        difference == 0
    }
}

impl Eq for SpkiPin {}
68
69impl SpkiPin {
70 pub fn from_base64(encoded: &str) -> Result<Self, TlsError> {
71 let bytes = base64::engine::general_purpose::STANDARD
72 .decode(encoded.trim())
73 .map_err(|error| TlsError::SpkiPin(error.to_string()))?;
74 let array: [u8; 32] = bytes.try_into().map_err(|bytes: Vec<u8>| {
75 TlsError::SpkiPin(format!("expected 32 bytes, got {}", bytes.len()))
76 })?;
77 Ok(Self(array))
78 }
79
80 fn of_certificate(cert: &CertificateDer<'_>) -> Result<Self, rustls::Error> {
81 let (_, parsed) = x509_parser::parse_x509_certificate(cert.as_ref()).map_err(|error| {
82 rustls::Error::General(format!("parse client certificate: {error}"))
83 })?;
84 Ok(Self(
85 Sha256::digest(parsed.tbs_certificate.subject_pki.raw).into(),
86 ))
87 }
88}
89
90impl std::fmt::Debug for SpkiPin {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 f.debug_tuple("SpkiPin").finish_non_exhaustive()
93 }
94}
95
96pub(crate) fn fuzz_of_certificate(data: &[u8]) {
97 let cert = CertificateDer::from(data.to_vec());
98 let _ = SpkiPin::of_certificate(&cert);
99}
100
101pub struct ReloadableCertResolver {
102 current: ArcSwap<CertifiedKey>,
103}
104
105impl std::fmt::Debug for ReloadableCertResolver {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 f.debug_struct("ReloadableCertResolver")
108 .finish_non_exhaustive()
109 }
110}
111
112impl ReloadableCertResolver {
113 pub fn new(initial: CertifiedKey) -> Self {
114 Self {
115 current: ArcSwap::from_pointee(initial),
116 }
117 }
118
119 pub fn store(&self, key: CertifiedKey) {
120 self.current.store(Arc::new(key));
121 }
122}
123
124impl ResolvesServerCert for ReloadableCertResolver {
125 fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
126 Some(self.current.load_full())
127 }
128}
129
130pub fn spawn_cert_reload(
131 resolver: Arc<ReloadableCertResolver>,
132 cert_path: PathBuf,
133 key_path: PathBuf,
134 shutdown: CancellationToken,
135) {
136 #[cfg(unix)]
137 tokio::spawn(async move {
138 use tokio::signal::unix::{SignalKind, signal};
139 let mut hangup = match signal(SignalKind::hangup()) {
140 Ok(stream) => stream,
141 Err(error) => {
142 tracing::error!("install SIGHUP handler: {error}");
143 return;
144 }
145 };
146 loop {
147 tokio::select! {
148 () = shutdown.cancelled() => break,
149 received = hangup.recv() => {
150 if received.is_none() {
151 break;
152 }
153 match load_certified_key(&cert_path, &key_path) {
154 Ok(key) => {
155 resolver.store(key);
156 tracing::info!("reloaded TLS certificate on SIGHUP");
157 }
158 Err(error) => {
159 tracing::warn!("SIGHUP reload kept the existing certificate: {error}");
160 }
161 }
162 }
163 }
164 }
165 });
166
167 #[cfg(not(unix))]
168 let _ = (resolver, cert_path, key_path, shutdown);
169}
170
171pub fn load_certified_key(cert_path: &Path, key_path: &Path) -> Result<CertifiedKey, TlsError> {
172 let certs = load_certs(cert_path)?;
173 let key = load_private_key(key_path)?;
174 let signing_key = aws_lc_rs::sign::any_supported_type(&key)
175 .map_err(|error| TlsError::SigningKey(error.to_string()))?;
176 let certified = CertifiedKey::new(certs, signing_key);
177 certified
178 .keys_match()
179 .map_err(|error| TlsError::KeyMismatch(error.to_string()))?;
180 Ok(certified)
181}
182
183fn load_certs(path: &Path) -> Result<Vec<CertificateDer<'static>>, TlsError> {
184 let bytes = std::fs::read(path).map_err(|source| TlsError::Read {
185 path: path.display().to_string(),
186 source,
187 })?;
188 let mut reader = BufReader::new(bytes.as_slice());
189 let certs = rustls_pemfile::certs(&mut reader)
190 .collect::<Result<Vec<_>, _>>()
191 .map_err(|error| TlsError::Parse {
192 path: path.display().to_string(),
193 message: error.to_string(),
194 })?;
195 match certs.is_empty() {
196 true => Err(TlsError::NoCertificates(path.display().to_string())),
197 false => Ok(certs),
198 }
199}
200
201fn load_private_key(path: &Path) -> Result<PrivateKeyDer<'static>, TlsError> {
202 let bytes = std::fs::read(path).map_err(|source| TlsError::Read {
203 path: path.display().to_string(),
204 source,
205 })?;
206 let mut reader = BufReader::new(bytes.as_slice());
207 rustls_pemfile::private_key(&mut reader)
208 .map_err(|error| TlsError::Parse {
209 path: path.display().to_string(),
210 message: error.to_string(),
211 })?
212 .ok_or_else(|| TlsError::NoPrivateKey(path.display().to_string()))
213}
214
215fn ticketer() -> Result<Arc<dyn rustls::server::ProducesTickets>, TlsError> {
216 aws_lc_rs::Ticketer::new().map_err(|error| TlsError::Ticketer(error.to_string()))
217}
218
/// ALPN list for TCP listeners: `h2`, then `http/1.1`, then each `extra`
/// protocol in the order given. The base set never includes `h3`, which is
/// QUIC-only.
fn tcp_alpn(extra: &[&[u8]]) -> Vec<Vec<u8>> {
    const BASE: [&[u8]; 2] = [b"h2", b"http/1.1"];
    let mut protocols = Vec::with_capacity(BASE.len() + extra.len());
    protocols.extend(BASE.iter().map(|protocol| protocol.to_vec()));
    protocols.extend(extra.iter().map(|protocol| protocol.to_vec()));
    protocols
}
226
227pub fn build_tls_server_config(
228 resolver: Arc<dyn ResolvesServerCert>,
229 extra_alpn: &[&[u8]],
230) -> Result<ServerConfig, TlsError> {
231 let provider = Arc::new(aws_lc_rs::default_provider());
232 let mut config = ServerConfig::builder_with_provider(provider)
233 .with_safe_default_protocol_versions()
234 .map_err(|error| TlsError::Config(error.to_string()))?
235 .with_no_client_auth()
236 .with_cert_resolver(resolver);
237 config.alpn_protocols = tcp_alpn(extra_alpn);
238 config.ticketer = ticketer()?;
239 Ok(config)
240}
241
242pub fn build_mtls_server_config(
243 resolver: Arc<dyn ResolvesServerCert>,
244 client_ca_path: &Path,
245 pin: SpkiPin,
246) -> Result<ServerConfig, TlsError> {
247 let provider = Arc::new(aws_lc_rs::default_provider());
248 let roots = load_client_ca(client_ca_path)?;
249 let webpki =
250 WebPkiClientVerifier::builder_with_provider(Arc::new(roots), Arc::clone(&provider))
251 .build()
252 .map_err(|error| TlsError::ClientVerifier {
253 path: client_ca_path.display().to_string(),
254 message: error.to_string(),
255 })?;
256 let verifier = Arc::new(PinnedClientVerifier { inner: webpki, pin });
257 let mut config = ServerConfig::builder_with_provider(provider)
258 .with_safe_default_protocol_versions()
259 .map_err(|error| TlsError::Config(error.to_string()))?
260 .with_client_cert_verifier(verifier)
261 .with_cert_resolver(resolver);
262 config.alpn_protocols = tcp_alpn(&[]);
263 config.ticketer = ticketer()?;
264 Ok(config)
265}
266
267fn load_client_ca(path: &Path) -> Result<RootCertStore, TlsError> {
268 let certs = load_certs(path)?;
269 let mut roots = RootCertStore::empty();
270 let (added, _) = roots.add_parsable_certificates(certs);
271 match added {
272 0 => Err(TlsError::NoCertificates(path.display().to_string())),
273 _ => Ok(roots),
274 }
275}
276
277#[derive(Debug)]
278struct PinnedClientVerifier {
279 inner: Arc<dyn ClientCertVerifier>,
280 pin: SpkiPin,
281}
282
283impl ClientCertVerifier for PinnedClientVerifier {
284 fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] {
285 self.inner.root_hint_subjects()
286 }
287
288 fn offer_client_auth(&self) -> bool {
289 self.inner.offer_client_auth()
290 }
291
292 fn client_auth_mandatory(&self) -> bool {
293 self.inner.client_auth_mandatory()
294 }
295
296 fn verify_client_cert(
297 &self,
298 end_entity: &CertificateDer<'_>,
299 intermediates: &[CertificateDer<'_>],
300 now: UnixTime,
301 ) -> Result<ClientCertVerified, rustls::Error> {
302 let verified = self
303 .inner
304 .verify_client_cert(end_entity, intermediates, now)?;
305 match SpkiPin::of_certificate(end_entity)? == self.pin {
306 true => Ok(verified),
307 false => Err(rustls::Error::General(
308 "client certificate SPKI does not match the pinned admin identity".to_string(),
309 )),
310 }
311 }
312
313 fn verify_tls12_signature(
314 &self,
315 message: &[u8],
316 cert: &CertificateDer<'_>,
317 dss: &rustls::DigitallySignedStruct,
318 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
319 self.inner.verify_tls12_signature(message, cert, dss)
320 }
321
322 fn verify_tls13_signature(
323 &self,
324 message: &[u8],
325 cert: &CertificateDer<'_>,
326 dss: &rustls::DigitallySignedStruct,
327 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
328 self.inner.verify_tls13_signature(message, cert, dss)
329 }
330
331 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
332 self.inner.supported_verify_schemes()
333 }
334}
335
336pub fn build_quic_server_config(
337 resolver: Arc<dyn ResolvesServerCert>,
338 limits: ListenLimits,
339 early_data: EarlyDataPolicy,
340) -> Result<quinn::ServerConfig, TlsError> {
341 let provider = Arc::new(aws_lc_rs::default_provider());
342 let mut crypto = ServerConfig::builder_with_provider(provider)
343 .with_protocol_versions(&[&rustls::version::TLS13])
344 .map_err(|error| TlsError::Config(error.to_string()))?
345 .with_no_client_auth()
346 .with_cert_resolver(resolver);
347 crypto.alpn_protocols = vec![b"h3".to_vec()];
348 crypto.max_early_data_size = early_data.max_early_data_size();
349
350 let quic_crypto =
351 QuicServerConfig::try_from(crypto).map_err(|error| TlsError::Config(error.to_string()))?;
352 let mut config = quinn::ServerConfig::with_crypto(Arc::new(quic_crypto));
353
354 let mut transport = quinn::TransportConfig::default();
355 transport.max_concurrent_bidi_streams(quinn::VarInt::from_u32(MAX_CONCURRENT_STREAMS));
356 transport.stream_receive_window(quinn::VarInt::from_u32(STREAM_RECEIVE_WINDOW));
357 transport.receive_window(quinn::VarInt::from_u32(CONNECTION_RECEIVE_WINDOW));
358 let idle = limits.idle_timeout();
359 transport.max_idle_timeout(Some(
360 quinn::IdleTimeout::try_from(idle).map_err(|error| TlsError::Config(error.to_string()))?,
361 ));
362 transport.keep_alive_interval(Some(idle / 2));
363 config.transport_config(Arc::new(transport));
364 Ok(config)
365}
366
/// Returns `MAX_CONCURRENT_STREAMS`, the cap applied to concurrent
/// bidirectional streams in [`build_quic_server_config`], so other modules can
/// read the same value.
pub const fn max_concurrent_streams() -> u32 {
    MAX_CONCURRENT_STREAMS
}
370
#[cfg(test)]
pub(crate) mod test_support {
    //! Shared fixtures for TLS tests:
    //! - a throwaway self-signed certificate;
    //! - a resolver serving it;
    //! - fixed listen limits;
    //! - a client-side verifier that trusts any server certificate.

    use std::num::{NonZeroU32, NonZeroU64};

    use rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer};
    use rustls::sign::CertifiedKey;

    use super::*;
    use crate::limits::ListenLimits;

    /// Generates a new self-signed certificate for `localhost` and pairs it
    /// with its PKCS#8 key as a rustls [`CertifiedKey`].
    pub(crate) fn self_signed() -> CertifiedKey {
        let generated = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
        let chain = vec![generated.cert.der().clone()];
        let pkcs8 = PrivatePkcs8KeyDer::from(generated.signing_key.serialize_der());
        let signer = aws_lc_rs::sign::any_supported_type(&PrivateKeyDer::Pkcs8(pkcs8)).unwrap();
        CertifiedKey::new(chain, signer)
    }

    /// A shared [`ReloadableCertResolver`] seeded with a fresh [`self_signed`] key.
    pub(crate) fn resolver() -> Arc<ReloadableCertResolver> {
        let seed = self_signed();
        Arc::new(ReloadableCertResolver::new(seed))
    }

    /// Fixed listen limits for tests. The meaning of each argument is defined by
    /// `ListenLimits::new`.
    pub(crate) fn limits() -> ListenLimits {
        ListenLimits::new(
            NonZeroU64::new(5_000).unwrap(),
            NonZeroU64::new(30_000).unwrap(),
            NonZeroU32::new(64).unwrap(),
        )
    }

    /// Signature algorithms of the default aws-lc-rs provider, used by
    /// `AcceptAnyServerCert` to check handshake signatures.
    fn provider_algorithms() -> rustls::crypto::WebPkiSupportedAlgorithms {
        aws_lc_rs::default_provider().signature_verification_algorithms
    }

    /// Test-only client verifier. It accepts every server certificate without
    /// inspection but still verifies handshake signatures.
    #[derive(Debug)]
    pub(crate) struct AcceptAnyServerCert;

    impl rustls::client::danger::ServerCertVerifier for AcceptAnyServerCert {
        fn verify_server_cert(
            &self,
            _end_entity: &rustls::pki_types::CertificateDer<'_>,
            _intermediates: &[rustls::pki_types::CertificateDer<'_>],
            _server_name: &rustls::pki_types::ServerName<'_>,
            _ocsp_response: &[u8],
            _now: rustls::pki_types::UnixTime,
        ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
            Ok(rustls::client::danger::ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            message: &[u8],
            cert: &rustls::pki_types::CertificateDer<'_>,
            dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
            rustls::crypto::verify_tls12_signature(message, cert, dss, &provider_algorithms())
        }

        fn verify_tls13_signature(
            &self,
            message: &[u8],
            cert: &rustls::pki_types::CertificateDer<'_>,
            dss: &rustls::DigitallySignedStruct,
        ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
            rustls::crypto::verify_tls13_signature(message, cert, dss, &provider_algorithms())
        }

        fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
            provider_algorithms().supported_schemes()
        }
    }
}
453
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU64, Ordering};

    use super::*;

    /// A PEM fixture in the temp directory that is deleted when dropped, so
    /// fixtures do not accumulate even when an assertion fails.
    struct TempFile(PathBuf);

    impl TempFile {
        /// Writes `contents` to a path that is unique within this test process,
        /// so tests running in parallel never share a file.
        fn with_contents(label: &str, contents: &str) -> Self {
            static NEXT: AtomicU64 = AtomicU64::new(0);
            let serial = NEXT.fetch_add(1, Ordering::Relaxed);
            let path = std::env::temp_dir().join(format!(
                "knot_edge_tls_{label}_{}_{serial}.pem",
                std::process::id()
            ));
            std::fs::write(&path, contents).unwrap();
            Self(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: a leftover temp file must not fail the test.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn the_tcp_alpn_set_offers_h2_and_http1_but_not_h3() {
        let config = build_tls_server_config(test_support::resolver(), &[]).unwrap();
        assert_eq!(
            config.alpn_protocols,
            vec![b"h2".to_vec(), b"http/1.1".to_vec()],
            "h3 is QUIC-only and must never appear in the TCP ALPN set"
        );
    }

    #[test]
    fn the_acme_challenge_alpn_joins_only_when_requested() {
        let config = build_tls_server_config(test_support::resolver(), &[ACME_TLS_ALPN]).unwrap();
        assert_eq!(
            config.alpn_protocols,
            vec![b"h2".to_vec(), b"http/1.1".to_vec(), ACME_TLS_ALPN.to_vec()],
            "acme-tls/1 must trail h2 and http/1.1 so normal clients never select it"
        );
    }

    #[test]
    fn the_tcp_config_installs_an_enabled_session_ticketer() {
        let config = build_tls_server_config(test_support::resolver(), &[]).unwrap();
        assert!(
            config.ticketer.enabled(),
            "session resumption requires an enabled ticketer"
        );
    }

    #[test]
    fn an_spki_pin_round_trips_through_base64() {
        let encoded = base64::engine::general_purpose::STANDARD.encode([9u8; 32]);
        assert_eq!(SpkiPin::from_base64(&encoded).unwrap(), SpkiPin([9u8; 32]));
    }

    #[test]
    fn the_spki_pin_matches_the_standard_openssl_recipe() {
        const CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\n\
MIIBhDCCASugAwIBAgIUSkWE4CZvV8B8z9phedKo1PbahDUwCgYIKoZIzj0EAwIw\n\
GDEWMBQGA1UEAwwNYW5lbW9uZS5hZG1pbjAeFw0yNjA2MjExOTA5MDdaFw0zNjA2\n\
MTgxOTA5MDdaMBgxFjAUBgNVBAMMDWFuZW1vbmUuYWRtaW4wWTATBgcqhkjOPQIB\n\
BggqhkjOPQMBBwNCAASFLKd70MtSGSyI2UjdpQyjaJrvXLofac41nI346wK0lC9G\n\
PjZH/NKqo1iwQn+UfZB7gotfezWrDmAUz5OgT6Rlo1MwUTAdBgNVHQ4EFgQUis7S\n\
XEFGpe4gQwWnzX/uzjpt274wHwYDVR0jBBgwFoAUis7SXEFGpe4gQwWnzX/uzjpt\n\
274wDwYDVR0TAQH/BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBhJgE4cMP5/FJw\n\
imc3fYQxOhm5nO59cfG06+0vuDIV1QIgMsKZjFjsch8rbLRNiJL5+bDmlgO7MD14\n\
0PAyPOyjb+w=\n\
-----END CERTIFICATE-----\n";
        const OPENSSL_PIN: &str = "EhmM1HyWzC54br06EDvoAaqt1q1h+je3vVFJTcZ9e1U=";

        let der = rustls_pemfile::certs(&mut CERT_PEM.as_bytes())
            .next()
            .unwrap()
            .unwrap();
        assert_eq!(
            SpkiPin::of_certificate(&der).unwrap(),
            SpkiPin::from_base64(OPENSSL_PIN).unwrap(),
            "of_certificate must hash the same SubjectPublicKeyInfo bytes as openssl pkey -pubin -outform DER | dgst -sha256"
        );
    }

    #[test]
    fn an_spki_pin_of_the_wrong_length_is_rejected() {
        let encoded = base64::engine::general_purpose::STANDARD.encode([9u8; 16]);
        assert!(matches!(
            SpkiPin::from_base64(&encoded),
            Err(TlsError::SpkiPin(_))
        ));
    }

    #[test]
    fn the_quic_config_builds_with_early_data_disabled() {
        // quinn seals the rustls config (and with it the `h3` ALPN list) behind
        // a trait object, so only successful construction is checked here.
        let built = build_quic_server_config(
            test_support::resolver(),
            test_support::limits(),
            EarlyDataPolicy::Disabled,
        );
        if let Err(error) = built {
            panic!("the TLS 1.3 / h3 QUIC config must build: {error}");
        }
    }

    #[test]
    fn the_default_provider_prefers_post_quantum_key_exchange() {
        use rustls::NamedGroup;

        let provider = aws_lc_rs::default_provider();
        let first = provider.kx_groups.first().expect("a key exchange group");
        assert_eq!(
            first.name(),
            NamedGroup::X25519MLKEM768,
            "prefer-post-quantum must order X25519MLKEM768 first for both the TCP and QUIC configs"
        );
    }

    #[test]
    fn a_matching_certificate_and_key_load_from_pem_files() {
        let generated = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
        let cert = TempFile::with_contents("cert", &generated.cert.pem());
        let key = TempFile::with_contents("key", &generated.signing_key.serialize_pem());
        if let Err(error) = load_certified_key(cert.path(), key.path()) {
            panic!("a certificate and its own key must load: {error}");
        }
    }

    #[test]
    fn a_mismatched_certificate_and_key_is_rejected() {
        let cert_owner = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
        let stranger = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
        let cert = TempFile::with_contents("cert", &cert_owner.cert.pem());
        let key = TempFile::with_contents("key", &stranger.signing_key.serialize_pem());
        assert!(
            matches!(
                load_certified_key(cert.path(), key.path()),
                Err(TlsError::KeyMismatch(_))
            ),
            "load_certified_key must refuse a key that does not belong to the certificate"
        );
    }

    /// An admin client identity issued by a throwaway CA.
    struct ClientIdentity {
        /// PEM file holding the issuing CA; deleted when the identity is dropped.
        ca_file: TempFile,
        /// The client's certificate chain (end-entity only).
        chain: Vec<CertificateDer<'static>>,
        /// PKCS#8 DER encoding of the client's private key.
        key_der: Vec<u8>,
        /// SPKI pin of the client certificate.
        pin: SpkiPin,
    }

    fn issue_client_identity() -> ClientIdentity {
        let ca_key = rcgen::KeyPair::generate().unwrap();
        let mut ca_params = rcgen::CertificateParams::new(Vec::<String>::new()).unwrap();
        ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
        let ca_cert = ca_params.self_signed(&ca_key).unwrap();
        let issuer = rcgen::Issuer::new(ca_params, ca_key);

        let client_key = rcgen::KeyPair::generate().unwrap();
        let client_params = rcgen::CertificateParams::new(vec!["admin.knot".to_string()]).unwrap();
        let client_cert = client_params.signed_by(&client_key, &issuer).unwrap();
        let client_der = client_cert.der().clone();

        ClientIdentity {
            ca_file: TempFile::with_contents("mtls_ca", &ca_cert.pem()),
            pin: SpkiPin::of_certificate(&client_der).unwrap(),
            chain: vec![client_der],
            key_der: client_key.serialize_der(),
        }
    }

    /// Runs one TLS handshake against `server_config` over loopback TCP and
    /// reports whether both sides completed it.
    ///
    /// With `Some(identity)` the client presents that certificate chain; with
    /// `None` it presents no certificate.
    async fn handshake_succeeds(
        server_config: ServerConfig,
        identity: Option<&ClientIdentity>,
    ) -> bool {
        use rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer, ServerName};
        use tokio::net::{TcpListener, TcpStream};
        use tokio_rustls::{TlsAcceptor, TlsConnector};

        let acceptor = TlsAcceptor::from(Arc::new(server_config));
        // IPv4 loopback: IPv6 is disabled in some CI containers.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            acceptor.accept(tcp).await.is_ok()
        });

        let builder =
            rustls::ClientConfig::builder_with_provider(Arc::new(aws_lc_rs::default_provider()))
                .with_safe_default_protocol_versions()
                .unwrap()
                .dangerous()
                .with_custom_certificate_verifier(Arc::new(test_support::AcceptAnyServerCert));
        let mut client_config = match identity {
            Some(identity) => builder
                .with_client_auth_cert(
                    identity.chain.clone(),
                    PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(identity.key_der.clone())),
                )
                .unwrap(),
            None => builder.with_no_client_auth(),
        };
        client_config.alpn_protocols = vec![b"h2".to_vec()];
        let connector = TlsConnector::from(Arc::new(client_config));
        let tcp = TcpStream::connect(addr).await.unwrap();
        let client_ok = connector
            .connect(ServerName::try_from("localhost").unwrap(), tcp)
            .await
            .is_ok();
        let server_ok = server.await.unwrap();
        client_ok && server_ok
    }

    #[tokio::test]
    async fn mtls_rejects_a_client_presenting_no_certificate() {
        let identity = issue_client_identity();
        let config = build_mtls_server_config(
            test_support::resolver(),
            identity.ca_file.path(),
            identity.pin.clone(),
        )
        .unwrap();
        assert!(
            !handshake_succeeds(config, None).await,
            "the mandatory mTLS verifier must reject a client that presents no certificate"
        );
    }

    #[tokio::test]
    async fn mtls_admits_the_pinned_admin_certificate() {
        let identity = issue_client_identity();
        let config = build_mtls_server_config(
            test_support::resolver(),
            identity.ca_file.path(),
            identity.pin.clone(),
        )
        .unwrap();
        assert!(
            handshake_succeeds(config, Some(&identity)).await,
            "a client presenting the pinned admin certificate must complete the mTLS handshake"
        );
    }

    #[tokio::test]
    async fn mtls_rejects_a_ca_trusted_client_whose_spki_is_not_pinned() {
        let identity = issue_client_identity();
        let config = build_mtls_server_config(
            test_support::resolver(),
            identity.ca_file.path(),
            SpkiPin([0u8; 32]),
        )
        .unwrap();
        assert!(
            !handshake_succeeds(config, Some(&identity)).await,
            "a client trusted by the CA but failing the SPKI pin must be rejected"
        );
    }
}