Now let's take a silly one
0

Configure Feed

Select the types of activity you want to include in your feed.

at main 2.9 kB View raw
1use axum::Router; 2use axum::body::Body; 3use http::header::ALT_SVC; 4use http::{HeaderValue, Request, Response, StatusCode}; 5 6const ALT_SVC_MAX_AGE_SECS: u32 = 86_400; 7 8pub fn alt_svc_header(port: u16) -> HeaderValue { 9 HeaderValue::from_str(&format!("h3=\":{port}\"; ma={ALT_SVC_MAX_AGE_SECS}")) 10 .expect("alt-svc header value is valid ascii") 11} 12 13pub fn with_alt_svc(app: Router, port: u16) -> Router { 14 let value = alt_svc_header(port); 15 app.layer(axum::middleware::map_response( 16 move |mut response: Response<Body>| { 17 let value = value.clone(); 18 async move { 19 if response.status() != StatusCode::SWITCHING_PROTOCOLS { 20 response.headers_mut().insert(ALT_SVC, value); 21 } 22 response 23 } 24 }, 25 )) 26} 27 28pub fn with_host_from_authority(app: Router) -> Router { 29 app.layer(axum::middleware::map_request( 30 |mut request: Request<Body>| async move { 31 let authority = request 32 .uri() 33 .authority() 34 .map(|authority| HeaderValue::from_str(authority.as_str())); 35 match ( 36 request.headers().contains_key(http::header::HOST), 37 authority, 38 ) { 39 (false, Some(Ok(value))) => { 40 request.headers_mut().insert(http::header::HOST, value); 41 request 42 } 43 _ => request, 44 } 45 }, 46 )) 47} 48 49#[cfg(test)] 50mod tests { 51 use super::*; 52 53 use axum::routing::get; 54 use tower::ServiceExt; 55 56 #[test] 57 fn alt_svc_header_advertises_h3() { 58 assert_eq!( 59 alt_svc_header(443).to_str().unwrap(), 60 "h3=\":443\"; ma=86400" 61 ); 62 } 63 64 #[tokio::test] 65 async fn alt_svc_added_to_responses_except_switching_protocols() { 66 let app = with_alt_svc( 67 Router::new().route("/ok", get(|| async { "ok" })).route( 68 "/upgrade", 69 get(|| async { 70 Response::builder() 71 .status(StatusCode::SWITCHING_PROTOCOLS) 72 .body(Body::empty()) 73 .unwrap() 74 }), 75 ), 76 443, 77 ); 78 79 let normal = app 80 .clone() 81 .oneshot(Request::get("/ok").body(Body::empty()).unwrap()) 82 .await 83 .unwrap(); 84 assert_eq!( 85 normal.headers().get(ALT_SVC).and_then(|v| v.to_str().ok()), 86 Some("h3=\":443\"; ma=86400") 87 ); 88 89 let upgrade = app 90 .oneshot(Request::get("/upgrade").body(Body::empty()).unwrap()) 91 .await 92 .unwrap(); 93 assert!( 94 upgrade.headers().get(ALT_SVC).is_none(), 95 "101 responses must not carry Alt-Svc" 96 ); 97 } 98}