Now let's take a silly one
1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::time::Duration;
4
5use axum::Router;
6use axum::body::Body;
7use axum::extract::ConnectInfo;
8use bytes::{Buf, Bytes};
9use futures::{FutureExt, StreamExt};
10use http::{Request, Response, Version};
11use quinn::{Endpoint, Incoming};
12use tokio::sync::Semaphore;
13use tokio_util::sync::CancellationToken;
14use tokio_util::task::TaskTracker;
15use tower::ServiceExt;
16
17use rustls::server::ResolvesServerCert;
18
19use crate::limits::ListenLimits;
20use crate::tls::{self, TlsError};
21use crate::zerortt::{EarlyData, EarlyDataPolicy};
22
23fn early_data(confirmed: bool) -> EarlyData {
24 match confirmed {
25 true => EarlyData::No,
26 false => EarlyData::Yes,
27 }
28}
29
/// Upper bound `serve_http3` waits for spawned connection/request tasks once shutdown begins.
const LISTENER_DRAIN_GRACE: Duration = Duration::from_secs(30);
/// Upper bound `serve_http3` waits for the endpoint to go idle after closing it.
const ENDPOINT_DRAIN_GRACE: Duration = Duration::from_secs(5);
/// How long a single connection keeps serving in-flight requests after its GOAWAY
/// before the QUIC connection is force-closed.
const CONNECTION_DRAIN_GRACE: Duration = Duration::from_secs(10);
33
34pub fn build_endpoint(
35 addr: SocketAddr,
36 resolver: Arc<dyn ResolvesServerCert>,
37 limits: ListenLimits,
38 early_data: EarlyDataPolicy,
39) -> Result<Endpoint, EndpointError> {
40 let server_config = tls::build_quic_server_config(resolver, limits, early_data)?;
41 let endpoint = Endpoint::server(server_config, addr)?;
42 Ok(endpoint)
43}
44
45#[derive(Debug, thiserror::Error)]
46pub enum EndpointError {
47 #[error(transparent)]
48 Tls(#[from] TlsError),
49 #[error("binding quic socket: {0}")]
50 Bind(#[from] std::io::Error),
51}
52
53pub async fn serve_http3(
54 endpoint: Endpoint,
55 app: Router,
56 limits: ListenLimits,
57 shutdown: CancellationToken,
58) {
59 let tracker = TaskTracker::new();
60 let connections = Arc::new(Semaphore::new(limits.max_connections()));
61 loop {
62 let incoming = tokio::select! {
63 () = shutdown.cancelled() => break,
64 incoming = endpoint.accept() => incoming,
65 };
66 let Some(incoming) = incoming else { break };
67 let Ok(permit) = Arc::clone(&connections).try_acquire_owned() else {
68 incoming.refuse();
69 continue;
70 };
71 let app = app.clone();
72 let conn_shutdown = shutdown.clone();
73 let conn_tracker = tracker.clone();
74 tracker.spawn(async move {
75 let _permit = permit;
76 if let Err(error) = serve_connection(incoming, app, conn_shutdown, conn_tracker).await {
77 tracing::debug!("h3 connection ended: {error}");
78 }
79 });
80 }
81 tracker.close();
82 let _ = tokio::time::timeout(LISTENER_DRAIN_GRACE, tracker.wait()).await;
83 endpoint.close(0u32.into(), b"shutdown");
84 let _ = tokio::time::timeout(ENDPOINT_DRAIN_GRACE, endpoint.wait_idle()).await;
85}
86
87async fn serve_connection(
88 incoming: Incoming,
89 app: Router,
90 shutdown: CancellationToken,
91 tracker: TaskTracker,
92) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
93 let (conn, confirmation, mut confirmed) = match incoming.accept()?.into_0rtt() {
94 Ok((conn, accepted)) => (conn, accepted.map(|_| ()).left_future(), false),
95 Err(connecting) => (
96 connecting.await?,
97 std::future::pending::<()>().right_future(),
98 true,
99 ),
100 };
101 tokio::pin!(confirmation);
102 let remote = conn.remote_address();
103 let quic = conn.clone();
104 let mut h3_conn =
105 h3::server::Connection::<_, Bytes>::new(h3_quinn::Connection::new(conn)).await?;
106
107 let drain_deadline = tokio::time::sleep(CONNECTION_DRAIN_GRACE);
108 tokio::pin!(drain_deadline);
109 let mut draining = false;
110 loop {
111 tokio::select! {
112 biased;
113 () = shutdown.cancelled(), if !draining => {
114 draining = true;
115 drain_deadline
116 .as_mut()
117 .reset(tokio::time::Instant::now() + CONNECTION_DRAIN_GRACE);
118 let _ = h3_conn.shutdown(0).await;
119 }
120 () = &mut drain_deadline, if draining => {
121 tracing::debug!("h3 connection from {remote} drain timed out, closing");
122 quic.close(0u32.into(), b"drain timeout");
123 break;
124 }
125 resolved = h3_conn.accept() => match resolved {
126 Ok(Some(resolver)) => {
127 let app = app.clone();
128 let early = early_data(confirmed);
129 tracker.spawn(async move {
130 if let Err(error) = serve_request(resolver, app, remote, early).await {
131 tracing::debug!("h3 request from {remote} failed: {error}");
132 }
133 });
134 }
135 Ok(None) => break,
136 Err(error) => {
137 tracing::debug!("h3 accept from {remote} error: {error}");
138 break;
139 }
140 },
141 () = &mut confirmation, if !confirmed => {
142 confirmed = true;
143 }
144 }
145 }
146 Ok(())
147}
148
149async fn serve_request(
150 resolver: h3::server::RequestResolver<h3_quinn::Connection, Bytes>,
151 app: Router,
152 remote: SocketAddr,
153 early: EarlyData,
154) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
155 let (request, stream) = resolver.resolve_request().await?;
156 let (mut send, recv) = stream.split();
157
158 let (mut parts, ()) = request.into_parts();
159 parts.version = Version::HTTP_3;
160 parts.extensions.insert(ConnectInfo(remote));
161 parts.extensions.insert(early);
162 let request = Request::from_parts(parts, request_body(recv));
163
164 let response = match app.oneshot(request).await {
165 Ok(response) => response,
166 Err(infallible) => match infallible {},
167 };
168
169 let (parts, body) = response.into_parts();
170 send.send_response(Response::from_parts(parts, ())).await?;
171
172 let mut data = body.into_data_stream();
173 while let Some(chunk) = data.next().await {
174 match chunk {
175 Ok(bytes) if bytes.has_remaining() => send.send_data(bytes).await?,
176 Ok(_) => {}
177 Err(error) => {
178 tracing::debug!("h3 response body to {remote} errored: {error}");
179 send.stop_stream(h3::error::Code::H3_INTERNAL_ERROR);
180 return Ok(());
181 }
182 }
183 }
184 send.finish().await?;
185 Ok(())
186}
187
188struct RecvGuard {
189 stream: h3::server::RequestStream<h3_quinn::RecvStream, Bytes>,
190 ended: bool,
191}
192
193impl Drop for RecvGuard {
194 fn drop(&mut self) {
195 if !self.ended {
196 self.stream.stop_sending(h3::error::Code::H3_NO_ERROR);
197 }
198 }
199}
200
201fn request_body(recv: h3::server::RequestStream<h3_quinn::RecvStream, Bytes>) -> Body {
202 let guard = RecvGuard {
203 stream: recv,
204 ended: false,
205 };
206 let stream = futures::stream::unfold(Some(guard), |state| async move {
207 let mut guard = state?;
208 match guard.stream.recv_data().await {
209 Ok(Some(mut buf)) => {
210 let bytes = buf.copy_to_bytes(buf.remaining());
211 Some((Ok::<Bytes, std::io::Error>(bytes), Some(guard)))
212 }
213 Ok(None) => {
214 guard.ended = true;
215 None
216 }
217 Err(error) => {
218 guard.ended = true;
219 Some((Err(std::io::Error::other(error.to_string())), None))
220 }
221 }
222 });
223 Body::from_stream(stream)
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    use axum::routing::get;
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::crypto::aws_lc_rs;
    use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
    use rustls::{DigitallySignedStruct, SignatureScheme};

    /// Test-only verifier: accepts any server certificate, but still verifies handshake
    /// signatures against aws-lc-rs's supported algorithms.
    #[derive(Debug)]
    struct AcceptAnyServerCert;

    impl ServerCertVerifier for AcceptAnyServerCert {
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, rustls::Error> {
            Ok(ServerCertVerified::assertion())
        }

        fn verify_tls12_signature(
            &self,
            message: &[u8],
            cert: &CertificateDer<'_>,
            dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, rustls::Error> {
            rustls::crypto::verify_tls12_signature(
                message,
                cert,
                dss,
                &aws_lc_rs::default_provider().signature_verification_algorithms,
            )
        }

        fn verify_tls13_signature(
            &self,
            message: &[u8],
            cert: &CertificateDer<'_>,
            dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, rustls::Error> {
            rustls::crypto::verify_tls13_signature(
                message,
                cert,
                dss,
                &aws_lc_rs::default_provider().signature_verification_algorithms,
            )
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            aws_lc_rs::default_provider()
                .signature_verification_algorithms
                .supported_schemes()
        }
    }

    /// Client endpoint using the default aws-lc-rs provider.
    fn client_endpoint() -> Endpoint {
        client_endpoint_with_provider(aws_lc_rs::default_provider())
    }

    /// TLS 1.3-only, ALPN `h3` client endpoint on an ephemeral IPv6 loopback port, built on
    /// `provider` and trusting any server certificate.
    fn client_endpoint_with_provider(provider: rustls::crypto::CryptoProvider) -> Endpoint {
        let mut crypto = rustls::ClientConfig::builder_with_provider(Arc::new(provider))
            .with_protocol_versions(&[&rustls::version::TLS13])
            .unwrap()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(AcceptAnyServerCert))
            .with_no_client_auth();
        crypto.alpn_protocols = vec![b"h3".to_vec()];
        let quic = quinn::crypto::rustls::QuicClientConfig::try_from(crypto).unwrap();
        let mut endpoint = Endpoint::client("[::1]:0".parse().unwrap()).unwrap();
        endpoint.set_default_client_config(quinn::ClientConfig::new(Arc::new(quic)));
        endpoint
    }

    /// Default provider restricted to classical X25519 key exchange.
    fn classical_only_provider() -> rustls::crypto::CryptoProvider {
        let mut provider = aws_lc_rs::default_provider();
        provider.kx_groups = vec![aws_lc_rs::kx_group::X25519];
        provider
    }

    /// Binds a server on an ephemeral IPv6 loopback port and serves `app` over HTTP/3 in
    /// the background. Returns the bound address and the token that shuts the server down.
    fn spawn_server(
        app: Router,
        limits: ListenLimits,
        policy: EarlyDataPolicy,
    ) -> (SocketAddr, CancellationToken) {
        let endpoint = build_endpoint(
            "[::1]:0".parse().unwrap(),
            tls::test_support::resolver(),
            limits,
            policy,
        )
        .unwrap();
        let addr = endpoint.local_addr().unwrap();
        let shutdown = CancellationToken::new();
        tokio::spawn(serve_http3(
            endpoint,
            crate::altsvc::with_host_from_authority(app),
            limits,
            shutdown.clone(),
        ));
        (addr, shutdown)
    }

    /// Sends `GET https://localhost{path}` over a fresh HTTP/3 connection from `client`.
    ///
    /// `before_response` is awaited after the request is fully sent but before the response
    /// is read. Returns the response status and the full body.
    async fn fetch(
        client: &Endpoint,
        addr: SocketAddr,
        path: &str,
        before_response: impl std::future::Future<Output = ()>,
    ) -> (http::StatusCode, Vec<u8>) {
        let conn = client.connect(addr, "localhost").unwrap().await.unwrap();
        let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(conn))
            .await
            .unwrap();
        let drive =
            tokio::spawn(async move { std::future::poll_fn(|cx| driver.poll_close(cx)).await });

        let request = http::Request::get(format!("https://localhost{path}"))
            .body(())
            .unwrap();
        let mut stream = send_request.send_request(request).await.unwrap();
        stream.finish().await.unwrap();
        before_response.await;

        let status = stream.recv_response().await.unwrap().status();
        let mut body = Vec::new();
        while let Some(mut chunk) = stream.recv_data().await.unwrap() {
            let bytes = chunk.copy_to_bytes(chunk.remaining());
            body.extend_from_slice(&bytes);
        }
        drive.abort();
        (status, body)
    }

    #[test]
    fn an_unconfirmed_handshake_is_early_data_and_a_confirmed_one_is_not() {
        assert_eq!(
            early_data(false),
            EarlyData::Yes,
            "data before handshake confirmation is early data, fail closed"
        );
        assert_eq!(early_data(true), EarlyData::No);
    }

    #[tokio::test]
    async fn an_h3_get_roundtrips_through_the_router() {
        let app = Router::new().route(
            "/",
            get(|ConnectInfo(peer): ConnectInfo<SocketAddr>| async move { peer.to_string() }),
        );
        let (addr, shutdown) =
            spawn_server(app, tls::test_support::limits(), EarlyDataPolicy::Disabled);

        let client = client_endpoint();
        let client_port = client.local_addr().unwrap().port();
        let (status, body) = fetch(&client, addr, "/", async {}).await;
        assert_eq!(status, 200);

        let reported: SocketAddr = String::from_utf8(body).unwrap().parse().unwrap();
        assert_eq!(
            reported.port(),
            client_port,
            "the h3 handler must see the QUIC remote address via ConnectInfo"
        );

        shutdown.cancel();
    }

    #[tokio::test]
    async fn an_h3_request_is_tagged_with_the_h3_protocol() {
        use axum::middleware::from_fn;

        let app = Router::new()
            .route(
                "/proto",
                get(|req: axum::extract::Request| async move {
                    req.extensions()
                        .get::<crate::protocol::NegotiatedProtocol>()
                        .map(|protocol| protocol.as_str())
                        .unwrap_or("missing")
                        .to_string()
                }),
            )
            .layer(from_fn(crate::protocol::tag));
        let (addr, shutdown) =
            spawn_server(app, tls::test_support::limits(), EarlyDataPolicy::Disabled);

        let client = client_endpoint();
        let (status, body) = fetch(&client, addr, "/proto", async {}).await;
        assert_eq!(status, 200);
        assert_eq!(
            String::from_utf8(body).unwrap(),
            "h3",
            "a request served over QUIC must be tagged as the h3 negotiated protocol"
        );

        shutdown.cancel();
    }

    #[tokio::test]
    async fn an_early_data_enabled_endpoint_still_serves_through_the_zero_rtt_path() {
        let app = Router::new().route("/", get(|| async { "ok" }));
        let (addr, shutdown) =
            spawn_server(app, tls::test_support::limits(), EarlyDataPolicy::Enabled);

        let client = client_endpoint();
        let (status, _) = fetch(&client, addr, "/", async {}).await;
        assert_eq!(
            status,
            200,
            "a server with early data enabled must serve requests through the 0-RTT acceptance path"
        );

        shutdown.cancel();
    }

    #[tokio::test]
    async fn a_classical_only_h3_client_completes_over_x25519() {
        let app = Router::new().route("/", get(|| async { "ok" }));
        let (addr, shutdown) =
            spawn_server(app, tls::test_support::limits(), EarlyDataPolicy::Disabled);

        let client = client_endpoint_with_provider(classical_only_provider());
        let (status, _) = fetch(&client, addr, "/", async {}).await;
        assert_eq!(
            status,
            200,
            "a QUIC client without ML-KEM must still complete the h3 handshake over classical X25519"
        );

        shutdown.cancel();
    }

    #[tokio::test]
    async fn an_in_flight_h3_request_finishes_during_drain() {
        let app = Router::new().route(
            "/slow",
            get(|| async {
                tokio::time::sleep(Duration::from_millis(300)).await;
                "drained-clean"
            }),
        );
        let (addr, shutdown) =
            spawn_server(app, tls::test_support::limits(), EarlyDataPolicy::Disabled);

        let client = client_endpoint();
        let (status, body) = fetch(&client, addr, "/slow", async {
            // Start the drain while the handler is still sleeping, so the request is in flight.
            tokio::time::sleep(Duration::from_millis(100)).await;
            shutdown.cancel();
        })
        .await;
        assert_eq!(
            status,
            200,
            "an in-flight h3 request must complete through the graceful drain"
        );
        assert_eq!(body, b"drained-clean");
    }

    #[tokio::test]
    async fn a_slow_h3_response_survives_the_idle_timeout() {
        use std::num::{NonZeroU32, NonZeroU64};

        let limits = ListenLimits::new(
            NonZeroU64::new(1_000).unwrap(),
            NonZeroU64::new(2_000).unwrap(),
            NonZeroU32::new(64).unwrap(),
        );
        let app = Router::new().route(
            "/",
            get(|| async {
                tokio::time::sleep(Duration::from_secs(3)).await;
                "ok"
            }),
        );
        let (addr, shutdown) = spawn_server(app, limits, EarlyDataPolicy::Disabled);

        let client = client_endpoint();
        let (status, body) = fetch(&client, addr, "/", async {}).await;
        assert_eq!(
            status,
            200,
            "keep-alive must hold the connection through a response slower than the idle timeout"
        );
        assert_eq!(body, b"ok");

        shutdown.cancel();
    }
}