Now let's take a silly one
0

Configure Feed

Select the types of activity you want to include in your feed.

at main 22 kB View raw
1use std::net::SocketAddr; 2use std::sync::Arc; 3use std::time::Duration; 4 5use axum::Router; 6use axum::body::Body; 7use axum::extract::ConnectInfo; 8use bytes::{Buf, Bytes}; 9use futures::{FutureExt, StreamExt}; 10use http::{Request, Response, Version}; 11use quinn::{Endpoint, Incoming}; 12use tokio::sync::Semaphore; 13use tokio_util::sync::CancellationToken; 14use tokio_util::task::TaskTracker; 15use tower::ServiceExt; 16 17use rustls::server::ResolvesServerCert; 18 19use crate::limits::ListenLimits; 20use crate::tls::{self, TlsError}; 21use crate::zerortt::{EarlyData, EarlyDataPolicy}; 22 23fn early_data(confirmed: bool) -> EarlyData { 24 match confirmed { 25 true => EarlyData::No, 26 false => EarlyData::Yes, 27 } 28} 29 30const LISTENER_DRAIN_GRACE: Duration = Duration::from_secs(30); 31const ENDPOINT_DRAIN_GRACE: Duration = Duration::from_secs(5); 32const CONNECTION_DRAIN_GRACE: Duration = Duration::from_secs(10); 33 34pub fn build_endpoint( 35 addr: SocketAddr, 36 resolver: Arc<dyn ResolvesServerCert>, 37 limits: ListenLimits, 38 early_data: EarlyDataPolicy, 39) -> Result<Endpoint, EndpointError> { 40 let server_config = tls::build_quic_server_config(resolver, limits, early_data)?; 41 let endpoint = Endpoint::server(server_config, addr)?; 42 Ok(endpoint) 43} 44 45#[derive(Debug, thiserror::Error)] 46pub enum EndpointError { 47 #[error(transparent)] 48 Tls(#[from] TlsError), 49 #[error("binding quic socket: {0}")] 50 Bind(#[from] std::io::Error), 51} 52 53pub async fn serve_http3( 54 endpoint: Endpoint, 55 app: Router, 56 limits: ListenLimits, 57 shutdown: CancellationToken, 58) { 59 let tracker = TaskTracker::new(); 60 let connections = Arc::new(Semaphore::new(limits.max_connections())); 61 loop { 62 let incoming = tokio::select! { 63 () = shutdown.cancelled() => break, 64 incoming = endpoint.accept() => incoming, 65 }; 66 let Some(incoming) = incoming else { break }; 67 let Ok(permit) = Arc::clone(&connections).try_acquire_owned() else { 68 incoming.refuse(); 69 continue; 70 }; 71 let app = app.clone(); 72 let conn_shutdown = shutdown.clone(); 73 let conn_tracker = tracker.clone(); 74 tracker.spawn(async move { 75 let _permit = permit; 76 if let Err(error) = serve_connection(incoming, app, conn_shutdown, conn_tracker).await { 77 tracing::debug!("h3 connection ended: {error}"); 78 } 79 }); 80 } 81 tracker.close(); 82 let _ = tokio::time::timeout(LISTENER_DRAIN_GRACE, tracker.wait()).await; 83 endpoint.close(0u32.into(), b"shutdown"); 84 let _ = tokio::time::timeout(ENDPOINT_DRAIN_GRACE, endpoint.wait_idle()).await; 85} 86 87async fn serve_connection( 88 incoming: Incoming, 89 app: Router, 90 shutdown: CancellationToken, 91 tracker: TaskTracker, 92) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 93 let (conn, confirmation, mut confirmed) = match incoming.accept()?.into_0rtt() { 94 Ok((conn, accepted)) => (conn, accepted.map(|_| ()).left_future(), false), 95 Err(connecting) => ( 96 connecting.await?, 97 std::future::pending::<()>().right_future(), 98 true, 99 ), 100 }; 101 tokio::pin!(confirmation); 102 let remote = conn.remote_address(); 103 let quic = conn.clone(); 104 let mut h3_conn = 105 h3::server::Connection::<_, Bytes>::new(h3_quinn::Connection::new(conn)).await?; 106 107 let drain_deadline = tokio::time::sleep(CONNECTION_DRAIN_GRACE); 108 tokio::pin!(drain_deadline); 109 let mut draining = false; 110 loop { 111 tokio::select! { 112 biased; 113 () = shutdown.cancelled(), if !draining => { 114 draining = true; 115 drain_deadline 116 .as_mut() 117 .reset(tokio::time::Instant::now() + CONNECTION_DRAIN_GRACE); 118 let _ = h3_conn.shutdown(0).await; 119 } 120 () = &mut drain_deadline, if draining => { 121 tracing::debug!("h3 connection from {remote} drain timed out, closing"); 122 quic.close(0u32.into(), b"drain timeout"); 123 break; 124 } 125 resolved = h3_conn.accept() => match resolved { 126 Ok(Some(resolver)) => { 127 let app = app.clone(); 128 let early = early_data(confirmed); 129 tracker.spawn(async move { 130 if let Err(error) = serve_request(resolver, app, remote, early).await { 131 tracing::debug!("h3 request from {remote} failed: {error}"); 132 } 133 }); 134 } 135 Ok(None) => break, 136 Err(error) => { 137 tracing::debug!("h3 accept from {remote} error: {error}"); 138 break; 139 } 140 }, 141 () = &mut confirmation, if !confirmed => { 142 confirmed = true; 143 } 144 } 145 } 146 Ok(()) 147} 148 149async fn serve_request( 150 resolver: h3::server::RequestResolver<h3_quinn::Connection, Bytes>, 151 app: Router, 152 remote: SocketAddr, 153 early: EarlyData, 154) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 155 let (request, stream) = resolver.resolve_request().await?; 156 let (mut send, recv) = stream.split(); 157 158 let (mut parts, ()) = request.into_parts(); 159 parts.version = Version::HTTP_3; 160 parts.extensions.insert(ConnectInfo(remote)); 161 parts.extensions.insert(early); 162 let request = Request::from_parts(parts, request_body(recv)); 163 164 let response = match app.oneshot(request).await { 165 Ok(response) => response, 166 Err(infallible) => match infallible {}, 167 }; 168 169 let (parts, body) = response.into_parts(); 170 send.send_response(Response::from_parts(parts, ())).await?; 171 172 let mut data = body.into_data_stream(); 173 while let Some(chunk) = data.next().await { 174 match chunk { 175 Ok(bytes) if bytes.has_remaining() => send.send_data(bytes).await?, 176 Ok(_) => {} 177 Err(error) => { 178 tracing::debug!("h3 response body to {remote} errored: {error}"); 179 send.stop_stream(h3::error::Code::H3_INTERNAL_ERROR); 180 return Ok(()); 181 } 182 } 183 } 184 send.finish().await?; 185 Ok(()) 186} 187 188struct RecvGuard { 189 stream: h3::server::RequestStream<h3_quinn::RecvStream, Bytes>, 190 ended: bool, 191} 192 193impl Drop for RecvGuard { 194 fn drop(&mut self) { 195 if !self.ended { 196 self.stream.stop_sending(h3::error::Code::H3_NO_ERROR); 197 } 198 } 199} 200 201fn request_body(recv: h3::server::RequestStream<h3_quinn::RecvStream, Bytes>) -> Body { 202 let guard = RecvGuard { 203 stream: recv, 204 ended: false, 205 }; 206 let stream = futures::stream::unfold(Some(guard), |state| async move { 207 let mut guard = state?; 208 match guard.stream.recv_data().await { 209 Ok(Some(mut buf)) => { 210 let bytes = buf.copy_to_bytes(buf.remaining()); 211 Some((Ok::<Bytes, std::io::Error>(bytes), Some(guard))) 212 } 213 Ok(None) => { 214 guard.ended = true; 215 None 216 } 217 Err(error) => { 218 guard.ended = true; 219 Some((Err(std::io::Error::other(error.to_string())), None)) 220 } 221 } 222 }); 223 Body::from_stream(stream) 224} 225 226#[cfg(test)] 227mod tests { 228 use super::*; 229 230 use axum::routing::get; 231 use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; 232 use rustls::crypto::aws_lc_rs; 233 234 #[test] 235 fn an_unconfirmed_handshake_is_early_data_and_a_confirmed_one_is_not() { 236 assert_eq!( 237 early_data(false), 238 EarlyData::Yes, 239 "data before handshake confirmation is early data, fail closed" 240 ); 241 assert_eq!(early_data(true), EarlyData::No); 242 } 243 use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; 244 use rustls::{DigitallySignedStruct, SignatureScheme}; 245 246 use crate::tls; 247 248 #[derive(Debug)] 249 struct AcceptAnyServerCert; 250 251 impl ServerCertVerifier for AcceptAnyServerCert { 252 fn verify_server_cert( 253 &self, 254 _end_entity: &CertificateDer<'_>, 255 _intermediates: &[CertificateDer<'_>], 256 _server_name: &ServerName<'_>, 257 _ocsp_response: &[u8], 258 _now: UnixTime, 259 ) -> Result<ServerCertVerified, rustls::Error> { 260 Ok(ServerCertVerified::assertion()) 261 } 262 263 fn verify_tls12_signature( 264 &self, 265 message: &[u8], 266 cert: &CertificateDer<'_>, 267 dss: &DigitallySignedStruct, 268 ) -> Result<HandshakeSignatureValid, rustls::Error> { 269 rustls::crypto::verify_tls12_signature( 270 message, 271 cert, 272 dss, 273 &aws_lc_rs::default_provider().signature_verification_algorithms, 274 ) 275 } 276 277 fn verify_tls13_signature( 278 &self, 279 message: &[u8], 280 cert: &CertificateDer<'_>, 281 dss: &DigitallySignedStruct, 282 ) -> Result<HandshakeSignatureValid, rustls::Error> { 283 rustls::crypto::verify_tls13_signature( 284 message, 285 cert, 286 dss, 287 &aws_lc_rs::default_provider().signature_verification_algorithms, 288 ) 289 } 290 291 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> { 292 aws_lc_rs::default_provider() 293 .signature_verification_algorithms 294 .supported_schemes() 295 } 296 } 297 298 fn client_endpoint() -> Endpoint { 299 client_endpoint_with_provider(aws_lc_rs::default_provider()) 300 } 301 302 fn client_endpoint_with_provider(provider: rustls::crypto::CryptoProvider) -> Endpoint { 303 let mut crypto = rustls::ClientConfig::builder_with_provider(Arc::new(provider)) 304 .with_protocol_versions(&[&rustls::version::TLS13]) 305 .unwrap() 306 .dangerous() 307 .with_custom_certificate_verifier(Arc::new(AcceptAnyServerCert)) 308 .with_no_client_auth(); 309 crypto.alpn_protocols = vec![b"h3".to_vec()]; 310 let quic = quinn::crypto::rustls::QuicClientConfig::try_from(crypto).unwrap(); 311 let mut endpoint = Endpoint::client("[::1]:0".parse().unwrap()).unwrap(); 312 endpoint.set_default_client_config(quinn::ClientConfig::new(Arc::new(quic))); 313 endpoint 314 } 315 316 fn classical_only_provider() -> rustls::crypto::CryptoProvider { 317 let mut provider = aws_lc_rs::default_provider(); 318 provider.kx_groups = vec![aws_lc_rs::kx_group::X25519]; 319 provider 320 } 321 322 #[tokio::test] 323 async fn an_h3_get_roundtrips_through_the_router() { 324 let app = Router::new().route( 325 "/", 326 get(|ConnectInfo(peer): ConnectInfo<SocketAddr>| async move { peer.to_string() }), 327 ); 328 let endpoint = build_endpoint( 329 "[::1]:0".parse().unwrap(), 330 tls::test_support::resolver(), 331 tls::test_support::limits(), 332 EarlyDataPolicy::Disabled, 333 ) 334 .unwrap(); 335 let addr = endpoint.local_addr().unwrap(); 336 let shutdown = CancellationToken::new(); 337 tokio::spawn(serve_http3( 338 endpoint, 339 crate::altsvc::with_host_from_authority(app), 340 tls::test_support::limits(), 341 shutdown.clone(), 342 )); 343 344 let client = client_endpoint(); 345 let client_port = client.local_addr().unwrap().port(); 346 let conn = client.connect(addr, "localhost").unwrap().await.unwrap(); 347 let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(conn)) 348 .await 349 .unwrap(); 350 let drive = 351 tokio::spawn(async move { std::future::poll_fn(|cx| driver.poll_close(cx)).await }); 352 353 let request = http::Request::get("https://localhost/").body(()).unwrap(); 354 let mut stream = send_request.send_request(request).await.unwrap(); 355 stream.finish().await.unwrap(); 356 assert_eq!(stream.recv_response().await.unwrap().status(), 200); 357 358 let mut body = Vec::new(); 359 while let Some(mut chunk) = stream.recv_data().await.unwrap() { 360 let bytes = chunk.copy_to_bytes(chunk.remaining()); 361 body.extend_from_slice(&bytes); 362 } 363 let reported: SocketAddr = String::from_utf8(body).unwrap().parse().unwrap(); 364 assert_eq!( 365 reported.port(), 366 client_port, 367 "the h3 handler must see the QUIC remote address via ConnectInfo" 368 ); 369 370 shutdown.cancel(); 371 drive.abort(); 372 } 373 374 #[tokio::test] 375 async fn an_h3_request_is_tagged_with_the_h3_protocol() { 376 use axum::middleware::from_fn; 377 378 let app = Router::new() 379 .route( 380 "/proto", 381 get(|req: axum::extract::Request| async move { 382 req.extensions() 383 .get::<crate::protocol::NegotiatedProtocol>() 384 .map(|protocol| protocol.as_str()) 385 .unwrap_or("missing") 386 .to_string() 387 }), 388 ) 389 .layer(from_fn(crate::protocol::tag)); 390 let endpoint = build_endpoint( 391 "[::1]:0".parse().unwrap(), 392 tls::test_support::resolver(), 393 tls::test_support::limits(), 394 EarlyDataPolicy::Disabled, 395 ) 396 .unwrap(); 397 let addr = endpoint.local_addr().unwrap(); 398 let shutdown = CancellationToken::new(); 399 tokio::spawn(serve_http3( 400 endpoint, 401 crate::altsvc::with_host_from_authority(app), 402 tls::test_support::limits(), 403 shutdown.clone(), 404 )); 405 406 let client = client_endpoint(); 407 let conn = client.connect(addr, "localhost").unwrap().await.unwrap(); 408 let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(conn)) 409 .await 410 .unwrap(); 411 let drive = 412 tokio::spawn(async move { std::future::poll_fn(|cx| driver.poll_close(cx)).await }); 413 414 let request = http::Request::get("https://localhost/proto") 415 .body(()) 416 .unwrap(); 417 let mut stream = send_request.send_request(request).await.unwrap(); 418 stream.finish().await.unwrap(); 419 assert_eq!(stream.recv_response().await.unwrap().status(), 200); 420 421 let mut body = Vec::new(); 422 while let Some(mut chunk) = stream.recv_data().await.unwrap() { 423 let bytes = chunk.copy_to_bytes(chunk.remaining()); 424 body.extend_from_slice(&bytes); 425 } 426 assert_eq!( 427 String::from_utf8(body).unwrap(), 428 "h3", 429 "a request served over QUIC must be tagged as the h3 negotiated protocol" 430 ); 431 432 shutdown.cancel(); 433 drive.abort(); 434 } 435 436 #[tokio::test] 437 async fn an_early_data_enabled_endpoint_still_serves_through_the_zero_rtt_path() { 438 let app = Router::new().route("/", get(|| async { "ok" })); 439 let endpoint = build_endpoint( 440 "[::1]:0".parse().unwrap(), 441 tls::test_support::resolver(), 442 tls::test_support::limits(), 443 EarlyDataPolicy::Enabled, 444 ) 445 .unwrap(); 446 let addr = endpoint.local_addr().unwrap(); 447 let shutdown = CancellationToken::new(); 448 tokio::spawn(serve_http3( 449 endpoint, 450 crate::altsvc::with_host_from_authority(app), 451 tls::test_support::limits(), 452 shutdown.clone(), 453 )); 454 455 let client = client_endpoint(); 456 let conn = client.connect(addr, "localhost").unwrap().await.unwrap(); 457 let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(conn)) 458 .await 459 .unwrap(); 460 let drive = 461 tokio::spawn(async move { std::future::poll_fn(|cx| driver.poll_close(cx)).await }); 462 463 let request = http::Request::get("https://localhost/").body(()).unwrap(); 464 let mut stream = send_request.send_request(request).await.unwrap(); 465 stream.finish().await.unwrap(); 466 assert_eq!( 467 stream.recv_response().await.unwrap().status(), 468 200, 469 "a server with early data enabled must serve requests through the 0-RTT acceptance path" 470 ); 471 472 shutdown.cancel(); 473 drive.abort(); 474 } 475 476 #[tokio::test] 477 async fn a_classical_only_h3_client_completes_over_x25519() { 478 let app = Router::new().route("/", get(|| async { "ok" })); 479 let endpoint = build_endpoint( 480 "[::1]:0".parse().unwrap(), 481 tls::test_support::resolver(), 482 tls::test_support::limits(), 483 EarlyDataPolicy::Disabled, 484 ) 485 .unwrap(); 486 let addr = endpoint.local_addr().unwrap(); 487 let shutdown = CancellationToken::new(); 488 tokio::spawn(serve_http3( 489 endpoint, 490 crate::altsvc::with_host_from_authority(app), 491 tls::test_support::limits(), 492 shutdown.clone(), 493 )); 494 495 let client = client_endpoint_with_provider(classical_only_provider()); 496 let conn = client.connect(addr, "localhost").unwrap().await.unwrap(); 497 let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(conn)) 498 .await 499 .unwrap(); 500 let drive = 501 tokio::spawn(async move { std::future::poll_fn(|cx| driver.poll_close(cx)).await }); 502 503 let request = http::Request::get("https://localhost/").body(()).unwrap(); 504 let mut stream = send_request.send_request(request).await.unwrap(); 505 stream.finish().await.unwrap(); 506 assert_eq!( 507 stream.recv_response().await.unwrap().status(), 508 200, 509 "a QUIC client without ML-KEM must still complete the h3 handshake over classical X25519" 510 ); 511 512 shutdown.cancel(); 513 drive.abort(); 514 } 515 516 #[tokio::test] 517 async fn an_in_flight_h3_request_finishes_during_drain() { 518 let app = Router::new().route( 519 "/slow", 520 get(|| async { 521 tokio::time::sleep(Duration::from_millis(300)).await; 522 "drained-clean" 523 }), 524 ); 525 let endpoint = build_endpoint( 526 "[::1]:0".parse().unwrap(), 527 tls::test_support::resolver(), 528 tls::test_support::limits(), 529 EarlyDataPolicy::Disabled, 530 ) 531 .unwrap(); 532 let addr = endpoint.local_addr().unwrap(); 533 let shutdown = CancellationToken::new(); 534 tokio::spawn(serve_http3( 535 endpoint, 536 crate::altsvc::with_host_from_authority(app), 537 tls::test_support::limits(), 538 shutdown.clone(), 539 )); 540 541 let client = client_endpoint(); 542 let conn = client.connect(addr, "localhost").unwrap().await.unwrap(); 543 let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(conn)) 544 .await 545 .unwrap(); 546 let drive = 547 tokio::spawn(async move { std::future::poll_fn(|cx| driver.poll_close(cx)).await }); 548 549 let request = http::Request::get("https://localhost/slow") 550 .body(()) 551 .unwrap(); 552 let mut stream = send_request.send_request(request).await.unwrap(); 553 stream.finish().await.unwrap(); 554 555 tokio::time::sleep(Duration::from_millis(100)).await; 556 shutdown.cancel(); 557 558 assert_eq!( 559 stream.recv_response().await.unwrap().status(), 560 200, 561 "an in-flight h3 request must complete through the graceful drain" 562 ); 563 let mut body = Vec::new(); 564 while let Some(mut chunk) = stream.recv_data().await.unwrap() { 565 let bytes = chunk.copy_to_bytes(chunk.remaining()); 566 body.extend_from_slice(&bytes); 567 } 568 assert_eq!(body, b"drained-clean"); 569 570 shutdown.cancel(); 571 drive.abort(); 572 } 573 574 #[tokio::test] 575 async fn a_slow_h3_response_survives_the_idle_timeout() { 576 use std::num::{NonZeroU32, NonZeroU64}; 577 578 let limits = ListenLimits::new( 579 NonZeroU64::new(1_000).unwrap(), 580 NonZeroU64::new(2_000).unwrap(), 581 NonZeroU32::new(64).unwrap(), 582 ); 583 let app = Router::new().route( 584 "/", 585 get(|| async { 586 tokio::time::sleep(Duration::from_secs(3)).await; 587 "ok" 588 }), 589 ); 590 let endpoint = build_endpoint( 591 "[::1]:0".parse().unwrap(), 592 tls::test_support::resolver(), 593 limits, 594 EarlyDataPolicy::Disabled, 595 ) 596 .unwrap(); 597 let addr = endpoint.local_addr().unwrap(); 598 let shutdown = CancellationToken::new(); 599 tokio::spawn(serve_http3( 600 endpoint, 601 crate::altsvc::with_host_from_authority(app), 602 limits, 603 shutdown.clone(), 604 )); 605 606 let client = client_endpoint(); 607 let conn = client.connect(addr, "localhost").unwrap().await.unwrap(); 608 let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(conn)) 609 .await 610 .unwrap(); 611 let drive = 612 tokio::spawn(async move { std::future::poll_fn(|cx| driver.poll_close(cx)).await }); 613 614 let request = http::Request::get("https://localhost/").body(()).unwrap(); 615 let mut stream = send_request.send_request(request).await.unwrap(); 616 stream.finish().await.unwrap(); 617 assert_eq!( 618 stream.recv_response().await.unwrap().status(), 619 200, 620 "keep-alive must hold the connection through a response slower than the idle timeout" 621 ); 622 623 let mut body = Vec::new(); 624 while let Some(mut chunk) = stream.recv_data().await.unwrap() { 625 let bytes = chunk.copy_to_bytes(chunk.remaining()); 626 body.extend_from_slice(&bytes); 627 } 628 assert_eq!(body, b"ok"); 629 630 shutdown.cancel(); 631 drive.abort(); 632 } 633}