A better Rust ATProto crate
1//! # Axum helpers for jacquard XRPC server implementations.
2//!
//! ## Usage
4//!
5//! This crate provides server-side helpers for wiring generated Jacquard XRPC
6//! endpoint marker types into axum routers. [`ExtractXrpc`] decodes query
7//! parameters or procedure bodies into owned request values, defaulting to
8//! `DefaultStr`-backed generated types, and [`XrpcResponse`] encodes endpoint
9//! outputs with the content type declared by the endpoint response marker.
10//!
11//! ```no_run
12//! use axum::Router;
13//! use jacquard::api::com_atproto::identity::resolve_handle::{
14//! ResolveHandleOutput, ResolveHandleRequest,
15//! };
16//! use jacquard::types::string::Did;
17//! use jacquard_axum::{ExtractXrpc, IntoRouter, XrpcResponse};
18//! use miette::{IntoDiagnostic, Result};
19//!
20//! async fn handle_resolve(
21//! ExtractXrpc(req): ExtractXrpc<ResolveHandleRequest>,
22//! ) -> XrpcResponse<ResolveHandleRequest> {
23//! let _handle = req.handle;
24//! XrpcResponse(ResolveHandleOutput {
25//! did: Did::new_static("did:plc:test").unwrap(),
26//! extra_data: None,
27//! })
28//! }
29//!
30//! #[tokio::main]
31//! async fn main() -> Result<()> {
32//! let app = Router::new()
33//! .route("/", axum::routing::get(|| async { "hello world!" }))
34//! .merge(ResolveHandleRequest::into_router(handle_resolve));
35//!
36//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
37//! .await
38//! .into_diagnostic()?;
39//! axum::serve(listener, app).await.unwrap();
40//! Ok(())
41//! }
42//! ```
43//!
//! The extractor uses the [`XrpcEndpoint`] trait to decide how to decode the request:
45//!
46//! - Query endpoints deserialize from query string parameters.
47//! - Procedure endpoints deserialize from request bodies and preserve custom
48//! encodings through [`XrpcRequest::decode_body`].
49//!
//! Deserialization errors are rejected with a 400 Bad Request whose JSON body
//! follows the XRPC error format, for example
//! `{"error": "InvalidRequest", "message": "failed to decode request: ..."}`.
52
53pub mod did_web;
54pub mod oauth;
55#[cfg(feature = "service-auth")]
56pub mod service_auth;
57
58use std::borrow::Cow;
59
60use axum::{
61 Json, Router,
62 body::Bytes,
63 extract::{FromRequest, Request},
64 http::{StatusCode, header::CONTENT_TYPE},
65 response::{IntoResponse, Response},
66};
67use jacquard::{
68 BosStr, CowStr, DefaultStr, IntoStatic,
69 xrpc::{XrpcEndpoint, XrpcError, XrpcMethod, XrpcRequest, XrpcResp},
70};
71use serde::{Deserialize, de::DeserializeOwned};
72use serde_json::{Value, json};
73
74/// Axum extractor for XRPC requests.
75///
76/// Deserializes incoming requests based on the endpoint's method type and
77/// returns the request type ready for handler logic.
78pub struct ExtractXrpc<R: XrpcEndpoint, B: BosStr = DefaultStr>(pub R::Request<B>);
79
80impl<R, B, State> FromRequest<State> for ExtractXrpc<R, B>
81where
82 R: XrpcEndpoint,
83 B: XrpcExtractBacking<R>,
84 State: Send + Sync,
85{
86 type Rejection = Response;
87
88 fn from_request(
89 req: Request,
90 state: &State,
91 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
92 async {
93 match R::METHOD {
94 XrpcMethod::Procedure(_) => {
95 let body = Bytes::from_request(req, state)
96 .await
97 .map_err(IntoResponse::into_response)?;
98 B::decode_body(&body).map(ExtractXrpc)
99 }
100 XrpcMethod::Query => {
101 let query = req.uri().query().unwrap_or("");
102 B::decode_query(query).map(ExtractXrpc)
103 }
104 }
105 }
106 }
107}
108
109/// Backing-specific XRPC request extraction policy.
110///
111/// Implementations choose the decode backing separately from the handler-visible
112/// backing to avoid overlapping axum extractor impls.
113pub trait XrpcExtractBacking<R>: private::Sealed + BosStr + Sized + 'static
114where
115 R: XrpcEndpoint,
116{
117 /// Decodes a query request into this handler-visible backing.
118 fn decode_query(query: &str) -> Result<R::Request<Self>, Response>;
119
120 /// Decodes a procedure body into this handler-visible backing.
121 fn decode_body(body: &[u8]) -> Result<R::Request<Self>, Response>;
122}
123
124macro_rules! impl_owned_extract_backing {
125 ($backing:ty) => {
126 impl private::Sealed for $backing {}
127
128 impl<R> XrpcExtractBacking<R> for $backing
129 where
130 R: XrpcEndpoint,
131 R::Request<Self>: DeserializeOwned,
132 {
133 fn decode_query(query: &str) -> Result<R::Request<Self>, Response> {
134 serde_html_form::from_str::<R::Request<Self>>(query)
135 .map_err(|err| invalid_request(format!("failed to decode request: {err}")))
136 }
137
138 fn decode_body(body: &[u8]) -> Result<R::Request<Self>, Response> {
139 <R::Request<Self> as XrpcRequest>::decode_body(body)
140 .map_err(|err| invalid_request(format!("failed to decode request: {err}")))
141 }
142 }
143 };
144}
145
146impl_owned_extract_backing!(DefaultStr);
147impl_owned_extract_backing!(String);
148
149impl private::Sealed for CowStr<'static> {}
150
151impl<R> XrpcExtractBacking<R> for CowStr<'static>
152where
153 R: XrpcEndpoint,
154 for<'de> R::Request<CowStr<'de>>: Deserialize<'de>,
155 for<'a> R::Request<CowStr<'a>>: IntoStatic<Output = R::Request<CowStr<'static>>>,
156{
157 fn decode_query(query: &str) -> Result<R::Request<Self>, Response> {
158 serde_html_form::from_str::<R::Request<CowStr<'_>>>(query)
159 .map(IntoStatic::into_static)
160 .map_err(|err| invalid_request(format!("failed to decode request: {err}")))
161 }
162
163 fn decode_body(body: &[u8]) -> Result<R::Request<Self>, Response> {
164 <R::Request<CowStr<'_>> as XrpcRequest>::decode_body(body)
165 .map(IntoStatic::into_static)
166 .map_err(|err| invalid_request(format!("failed to decode request: {err}")))
167 }
168}
169
170impl private::Sealed for Cow<'static, str> {}
171
172impl<R> XrpcExtractBacking<R> for Cow<'static, str>
173where
174 R: XrpcEndpoint,
175 for<'de> R::Request<Cow<'de, str>>: Deserialize<'de>,
176 for<'a> R::Request<Cow<'a, str>>: IntoStatic<Output = R::Request<Cow<'static, str>>>,
177{
178 fn decode_query(query: &str) -> Result<R::Request<Self>, Response> {
179 serde_html_form::from_str::<R::Request<Cow<'_, str>>>(query)
180 .map(IntoStatic::into_static)
181 .map_err(|err| invalid_request(format!("failed to decode request: {err}")))
182 }
183
184 fn decode_body(body: &[u8]) -> Result<R::Request<Self>, Response> {
185 <R::Request<Cow<'_, str>> as XrpcRequest>::decode_body(body)
186 .map(IntoStatic::into_static)
187 .map_err(|err| invalid_request(format!("failed to decode request: {err}")))
188 }
189}
190
191mod private {
192 pub trait Sealed {}
193}
194
195/// Typed axum response wrapper for XRPC endpoint outputs.
196pub struct XrpcResponse<R: XrpcEndpoint, B: BosStr = DefaultStr>(
197 pub <R::Response as XrpcResp>::Output<B>,
198);
199
200impl<R, B> XrpcResponse<R, B>
201where
202 R: XrpcEndpoint,
203 B: BosStr,
204{
205 /// Creates a typed XRPC response from an endpoint output value.
206 pub fn new(value: <R::Response as XrpcResp>::Output<B>) -> Self {
207 Self(value)
208 }
209}
210
211impl<R, B> IntoResponse for XrpcResponse<R, B>
212where
213 R: XrpcEndpoint,
214 B: BosStr,
215 <R::Response as XrpcResp>::Output<B>: serde::Serialize,
216{
217 fn into_response(self) -> Response {
218 match <R::Response as XrpcResp>::encode_output(&self.0) {
219 Ok(body) => (
220 StatusCode::OK,
221 [(CONTENT_TYPE, <R::Response as XrpcResp>::ENCODING)],
222 body,
223 )
224 .into_response(),
225 Err(err) => internal_server_error_response(format!("failed to encode response: {err}")),
226 }
227 }
228}
229
230/// Conversion trait to turn an XRPC endpoint marker and a handler into a router.
231///
232/// This trait is implemented for endpoint marker types (`R: XrpcEndpoint`). It
233/// registers `R::PATH` and maps `R::METHOD` to GET for query endpoints or POST
234/// for procedure endpoints. The endpoint-associated call style intentionally
235/// avoids turbofish while keeping the endpoint type explicit to the compiler.
236pub trait IntoRouter {
237 /// Creates an axum router that invokes `handler` for this endpoint marker.
238 fn into_router<HandlerArgs, State, Handler>(handler: Handler) -> Router<State>
239 where
240 HandlerArgs: 'static,
241 State: Clone + Send + Sync + 'static,
242 Handler: axum::handler::Handler<HandlerArgs, State>;
243}
244
245impl<R> IntoRouter for R
246where
247 R: XrpcEndpoint,
248{
249 fn into_router<HandlerArgs, State, Handler>(handler: Handler) -> Router<State>
250 where
251 HandlerArgs: 'static,
252 State: Clone + Send + Sync + 'static,
253 Handler: axum::handler::Handler<HandlerArgs, State>,
254 {
255 Router::new().route(
256 R::PATH,
257 (match R::METHOD {
258 XrpcMethod::Query => axum::routing::get,
259 XrpcMethod::Procedure(_) => axum::routing::post,
260 })(handler),
261 )
262 }
263}
264
265/// Axum-compatible generic XRPC error response.
266///
267/// Use this type when no generated endpoint error value is available, such as
268/// local infrastructure failures in a handler.
269#[derive(Debug, Clone)]
270pub struct GenericXrpcErrorResponse {
271 pub status: StatusCode,
272 error: String,
273 message: Option<String>,
274}
275
276impl GenericXrpcErrorResponse {
277 /// Creates an internal server error response.
278 pub fn internal_server_error() -> Self {
279 Self::new(
280 StatusCode::INTERNAL_SERVER_ERROR,
281 "InternalServerError",
282 Some("internal server error"),
283 )
284 }
285
286 /// Creates a generic XRPC error response with a custom message.
287 pub fn internal_server_error_with_message(message: impl Into<String>) -> Self {
288 Self::new(
289 StatusCode::INTERNAL_SERVER_ERROR,
290 "InternalServerError",
291 Some(message),
292 )
293 }
294
295 /// Creates a generic XRPC error response.
296 pub fn new(
297 status: StatusCode,
298 error: impl Into<String>,
299 message: Option<impl Into<String>>,
300 ) -> Self {
301 Self {
302 status,
303 error: error.into(),
304 message: message.map(Into::into),
305 }
306 }
307}
308
309impl IntoResponse for GenericXrpcErrorResponse {
310 fn into_response(self) -> Response {
311 let mut body = json!({ "error": self.error });
312 if let Some(message) = self.message {
313 body["message"] = json!(message);
314 }
315 (self.status, Json(body)).into_response()
316 }
317}
318
319/// Axum-compatible typed XRPC error wrapper.
320///
321/// Implements [`IntoResponse`] for generated endpoint error enums and common
322/// XRPC client-side error variants.
323#[derive(Debug, thiserror::Error, miette::Diagnostic)]
324#[error("XRPC error: {error}")]
325pub struct XrpcErrorResponse<E>
326where
327 E: std::error::Error,
328{
329 pub status: StatusCode,
330 #[diagnostic_source]
331 pub error: XrpcError<E>,
332}
333
334impl<E> XrpcErrorResponse<E>
335where
336 E: std::error::Error,
337{
338 /// Creates a new `XrpcErrorResponse` from the given status code and error.
339 pub fn new(status: StatusCode, error: XrpcError<E>) -> Self {
340 Self { status, error }
341 }
342
343 /// Changes the status code of the error response.
344 pub fn with_status(self, status: StatusCode) -> Self {
345 Self {
346 status,
347 error: self.error,
348 }
349 }
350}
351
352impl<E> IntoResponse for XrpcErrorResponse<E>
353where
354 E: std::error::Error + serde::Serialize,
355{
356 fn into_response(self) -> Response {
357 let json = match self.error {
358 XrpcError::Xrpc(error) => typed_error_json(&error),
359 XrpcError::Auth(auth_error) => json!({
360 "error": "Authentication",
361 "message": format!("{auth_error}")
362 }),
363 XrpcError::Generic(generic) => typed_error_json(&generic),
364 XrpcError::Decode(error) => json!({
365 "error": "InvalidRequest",
366 "message": format!("failed to decode request: {error}")
367 }),
368 _ => json!({
369 "error": "InternalServerError",
370 "message": "unknown error"
371 }),
372 };
373 (self.status, Json(json)).into_response()
374 }
375}
376
377impl<E> From<XrpcError<E>> for XrpcErrorResponse<E>
378where
379 E: std::error::Error,
380{
381 fn from(value: XrpcError<E>) -> Self {
382 Self {
383 status: StatusCode::INTERNAL_SERVER_ERROR,
384 error: value,
385 }
386 }
387}
388
389impl<E> From<XrpcErrorResponse<E>> for XrpcError<E>
390where
391 E: std::error::Error,
392{
393 fn from(value: XrpcErrorResponse<E>) -> Self {
394 value.error
395 }
396}
397
398fn invalid_request(message: impl Into<String>) -> Response {
399 (
400 StatusCode::BAD_REQUEST,
401 Json(json!({
402 "error": "InvalidRequest",
403 "message": message.into()
404 })),
405 )
406 .into_response()
407}
408
409fn internal_server_error_response(message: impl Into<String>) -> Response {
410 (
411 StatusCode::INTERNAL_SERVER_ERROR,
412 Json(json!({
413 "error": "InternalServerError",
414 "message": message.into()
415 })),
416 )
417 .into_response()
418}
419
420fn typed_error_json(error: &(impl std::error::Error + serde::Serialize)) -> Value {
421 serde_json::to_value(error).unwrap_or_else(|_| {
422 json!({
423 "error": "InternalServerError",
424 "message": format!("{error}")
425 })
426 })
427}