Microservice to bring 2FA to self-hosted PDSes
1use crate::helpers::json_error_response;
2use axum::extract::Request;
3use axum::http::header::AUTHORIZATION;
4use axum::http::{HeaderMap, StatusCode};
5use axum::middleware::Next;
6use axum::response::IntoResponse;
7use jwt_compact::alg::{Hs256, Hs256Key};
8use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError};
9use serde::{Deserialize, Serialize};
10use std::env;
11use tracing::log;
12
/// DID of the authenticated caller, attached to the request's extensions by
/// `extract_did` so downstream handlers can read it.
///
/// `None` represents a request with no known DID; the visible middleware currently
/// always inserts `Some(sub)` taken from the verified token.
#[derive(Debug, Clone)]
pub struct Did(pub Option<String>);
15
/// Authentication scheme named in an `Authorization` request header.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum AuthScheme {
    /// `Authorization: Bearer <token>`.
    Bearer,
    /// `Authorization: DPoP <token>` (the DPoP scheme from RFC 9449).
    DPoP,
}
21
22#[derive(Serialize, Deserialize)]
23pub struct TokenClaims {
24 pub sub: String,
25}
26
27pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse {
28 let auth = extract_auth(req.headers());
29
30 match auth {
31 Ok(auth_opt) => {
32 match auth_opt {
33 None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
34 .expect("Error creating an error response"),
35 Some((scheme, token_str)) => {
36 // For Bearer, validate JWT and extract DID from `sub`.
37 // For DPoP, we currently only pass through and do not validate here; insert None DID.
38 // match scheme {
39 // AuthScheme::Bearer => {
40 let token = UntrustedToken::new(&token_str);
41 if token.is_err() {
42 return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
43 .expect("Error creating an error response");
44 }
45 let parsed_token = token.expect("Already checked for error");
46 let claims: Result<Claims<TokenClaims>, ValidationError> =
47 parsed_token.deserialize_claims_unchecked();
48 if claims.is_err() {
49 return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
50 .expect("Error creating an error response");
51 }
52
53 let key = Hs256Key::new(
54 env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"),
55 );
56 let token: Result<Token<TokenClaims>, ValidationError> =
57 Hs256.validator(&key).validate(&parsed_token);
58 if token.is_err() {
59 return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
60 .expect("Error creating an error response");
61 }
62 let token = token.expect("Already checked for error,");
63 // Not going to worry about expiration since it still goes to the PDS
64 req.extensions_mut()
65 .insert(Did(Some(token.claims().custom.sub.clone())));
66 // }
67 // AuthScheme::DPoP => {
68 // // No DID extraction from DPoP here; leave None
69 // req.extensions_mut().insert(Did(None));
70 // }
71 // }
72
73 next.run(req).await
74 }
75 }
76 }
77 Err(err) => {
78 log::error!("Error extracting token: {err}");
79 json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
80 .expect("Error creating an error response")
81 }
82 }
83}
84
85fn extract_auth(headers: &HeaderMap) -> Result<Option<(AuthScheme, String)>, String> {
86 match headers.get(axum::http::header::AUTHORIZATION) {
87 None => Ok(None),
88 Some(hv) => {
89 match hv.to_str() {
90 Err(_) => Err("Authorization header is not valid".into()),
91 Ok(s) => {
92 // Accept forms like: "Bearer <token>" or "DPoP <token>" (case-sensitive for the scheme here)
93 let mut parts = s.splitn(2, ' ');
94 match (parts.next(), parts.next()) {
95 (Some("Bearer"), Some(tok)) if !tok.is_empty() =>
96 Ok(Some((AuthScheme::Bearer, tok.to_string()))),
97 (Some("DPoP"), Some(tok)) if !tok.is_empty() =>
98 Ok(Some((AuthScheme::DPoP, tok.to_string()))),
99 _ => Err("Authorization header must be in format 'Bearer <token>' or 'DPoP <token>'".into()),
100 }
101 }
102 }
103 }
104 }
105}