Now let's take a silly one:
1use axum::Router;
2use axum::body::Body;
3use http::header::ALT_SVC;
4use http::{HeaderValue, Request, Response, StatusCode};
5
6const ALT_SVC_MAX_AGE_SECS: u32 = 86_400;
7
8pub fn alt_svc_header(port: u16) -> HeaderValue {
9 HeaderValue::from_str(&format!("h3=\":{port}\"; ma={ALT_SVC_MAX_AGE_SECS}"))
10 .expect("alt-svc header value is valid ascii")
11}
12
13pub fn with_alt_svc(app: Router, port: u16) -> Router {
14 let value = alt_svc_header(port);
15 app.layer(axum::middleware::map_response(
16 move |mut response: Response<Body>| {
17 let value = value.clone();
18 async move {
19 if response.status() != StatusCode::SWITCHING_PROTOCOLS {
20 response.headers_mut().insert(ALT_SVC, value);
21 }
22 response
23 }
24 },
25 ))
26}
27
28pub fn with_host_from_authority(app: Router) -> Router {
29 app.layer(axum::middleware::map_request(
30 |mut request: Request<Body>| async move {
31 let authority = request
32 .uri()
33 .authority()
34 .map(|authority| HeaderValue::from_str(authority.as_str()));
35 match (
36 request.headers().contains_key(http::header::HOST),
37 authority,
38 ) {
39 (false, Some(Ok(value))) => {
40 request.headers_mut().insert(http::header::HOST, value);
41 request
42 }
43 _ => request,
44 }
45 },
46 ))
47}
48
#[cfg(test)]
mod tests {
    use super::*;

    use axum::routing::get;
    use tower::ServiceExt;

    /// Response header the host-echo handler uses to report the `Host` value it received.
    const SEEN_HOST: &str = "x-seen-host";

    #[test]
    fn alt_svc_header_advertises_h3() {
        assert_eq!(
            alt_svc_header(443).to_str().unwrap(),
            "h3=\":443\"; ma=86400"
        );
    }

    #[tokio::test]
    async fn alt_svc_added_to_responses_except_switching_protocols() {
        let app = with_alt_svc(
            Router::new().route("/ok", get(|| async { "ok" })).route(
                "/upgrade",
                get(|| async {
                    Response::builder()
                        .status(StatusCode::SWITCHING_PROTOCOLS)
                        .body(Body::empty())
                        .unwrap()
                }),
            ),
            443,
        );

        let normal = app
            .clone()
            .oneshot(Request::get("/ok").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(
            normal.headers().get(ALT_SVC).and_then(|v| v.to_str().ok()),
            Some("h3=\":443\"; ma=86400")
        );

        let upgrade = app
            .oneshot(Request::get("/upgrade").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert!(
            upgrade.headers().get(ALT_SVC).is_none(),
            "101 responses must not carry Alt-Svc"
        );
    }

    /// Router on `/` whose handler copies the request's `Host` header (if any) into `SEEN_HOST`.
    fn host_echo_app() -> Router {
        with_host_from_authority(Router::new().route(
            "/",
            get(|headers: http::HeaderMap| async move {
                let mut response = Response::new(Body::empty());
                if let Some(host) = headers.get(http::header::HOST) {
                    response.headers_mut().insert(SEEN_HOST, host.clone());
                }
                response
            }),
        ))
    }

    /// Sends `request` through `app` and returns the `Host` value the handler observed.
    async fn seen_host(app: Router, request: Request<Body>) -> Option<String> {
        let response = app.oneshot(request).await.unwrap();
        response
            .headers()
            .get(SEEN_HOST)
            .and_then(|v| v.to_str().ok())
            .map(str::to_owned)
    }

    #[tokio::test]
    async fn host_filled_from_absolute_uri_authority() {
        let request = Request::get("http://example.com:8443/")
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            seen_host(host_echo_app(), request).await.as_deref(),
            Some("example.com:8443")
        );
    }

    #[tokio::test]
    async fn existing_host_header_is_kept() {
        let request = Request::get("http://example.com/")
            .header(http::header::HOST, "original.test")
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            seen_host(host_echo_app(), request).await.as_deref(),
            Some("original.test")
        );
    }

    #[tokio::test]
    async fn no_host_added_without_authority() {
        let request = Request::get("/").body(Body::empty()).unwrap();
        assert_eq!(seen_host(host_echo_app(), request).await, None);
    }
}