A better Rust ATProto crate
1use axum::{Json, Router, http::StatusCode, response::IntoResponse};
2use axum_test::TestServer;
3use jacquard::xrpc::{XrpcEndpoint, XrpcMethod, XrpcRequest, XrpcResp};
4use jacquard_axum::{ExtractXrpc, IntoRouter, XrpcResponse};
5use jacquard_common::bos::{BosStr, DefaultStr};
6use jacquard_common::types::string::Did;
7use jacquard_derive::IntoStatic;
8use serde::{Deserialize, Serialize};
9use std::{borrow::Cow, collections::BTreeMap};
10
11#[derive(Debug, Clone, Serialize, Deserialize, IntoStatic)]
12#[serde(bound(deserialize = "S: serde::Deserialize<'de> + BosStr"))]
13struct TestQueryInput<S: BosStr = DefaultStr> {
14 did: Did<S>,
15 #[serde(default)]
16 limit: Option<u32>,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20#[serde(bound(deserialize = "S: serde::Deserialize<'de> + BosStr"))]
21struct TestQueryOutput<S: BosStr = DefaultStr> {
22 did: Did<S>,
23 #[serde(skip_serializing_if = "BTreeMap::is_empty", default)]
24 extra_data: BTreeMap<String, serde_json::Value>,
25}
26
27struct TestQueryResponse;
28struct TestQueryRequest;
29
30#[derive(Debug, Clone, Serialize, Deserialize, thiserror::Error)]
31#[error("test error")]
32struct TestError;
33
34impl XrpcResp for TestQueryResponse {
35 const NSID: &'static str = "com.example.test.query";
36 const ENCODING: &'static str = "application/json";
37 type Output<S: BosStr> = TestQueryOutput<S>;
38 type Err = TestError;
39}
40
41impl<S: BosStr> XrpcRequest for TestQueryInput<S> {
42 const NSID: &'static str = "com.example.test.query";
43 const METHOD: XrpcMethod = XrpcMethod::Query;
44 type Response = TestQueryResponse;
45}
46
47impl XrpcEndpoint for TestQueryRequest {
48 const PATH: &'static str = "/xrpc/com.example.test.query";
49 const METHOD: XrpcMethod = XrpcMethod::Query;
50 type Request<S: BosStr> = TestQueryInput<S>;
51 type Response = TestQueryResponse;
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, IntoStatic)]
55#[serde(bound(deserialize = "S: serde::Deserialize<'de> + BosStr"))]
56struct TestProcedureInput<S: BosStr = DefaultStr> {
57 did: Did<S>,
58 active: bool,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
62#[serde(bound(deserialize = "S: serde::Deserialize<'de> + BosStr"))]
63struct TestProcedureOutput<S: BosStr = DefaultStr> {
64 did: Did<S>,
65 active: bool,
66}
67
68struct TestProcedureResponse;
69struct TestProcedureRequest;
70
71impl XrpcResp for TestProcedureResponse {
72 const NSID: &'static str = "com.example.test.procedure";
73 const ENCODING: &'static str = "application/json";
74 type Output<S: BosStr> = TestProcedureOutput<S>;
75 type Err = TestError;
76}
77
78impl<S: BosStr> XrpcRequest for TestProcedureInput<S> {
79 const NSID: &'static str = "com.example.test.procedure";
80 const METHOD: XrpcMethod = XrpcMethod::Procedure("application/json");
81 type Response = TestProcedureResponse;
82}
83
84impl XrpcEndpoint for TestProcedureRequest {
85 const PATH: &'static str = "/xrpc/com.example.test.procedure";
86 const METHOD: XrpcMethod = XrpcMethod::Procedure("application/json");
87 type Request<S: BosStr> = TestProcedureInput<S>;
88 type Response = TestProcedureResponse;
89}
90
91async fn test_query_handler(ExtractXrpc(req): ExtractXrpc<TestQueryRequest>) -> impl IntoResponse {
92 Json(TestQueryOutput {
93 did: req.did,
94 extra_data: BTreeMap::new(),
95 })
96}
97
98async fn typed_query_handler(
99 ExtractXrpc(req): ExtractXrpc<TestQueryRequest>,
100) -> XrpcResponse<TestQueryRequest> {
101 XrpcResponse(TestQueryOutput {
102 did: req.did,
103 extra_data: BTreeMap::new(),
104 })
105}
106
107async fn string_query_handler(
108 ExtractXrpc(req): ExtractXrpc<TestQueryRequest, String>,
109) -> impl IntoResponse {
110 Json(TestQueryOutput {
111 did: req.did,
112 extra_data: BTreeMap::new(),
113 })
114}
115
116async fn cowstr_query_handler(
117 ExtractXrpc(req): ExtractXrpc<TestQueryRequest, jacquard::CowStr<'static>>,
118) -> impl IntoResponse {
119 Json(TestQueryOutput {
120 did: req.did,
121 extra_data: BTreeMap::new(),
122 })
123}
124
125async fn std_cow_query_handler(
126 ExtractXrpc(req): ExtractXrpc<TestQueryRequest, Cow<'static, str>>,
127) -> impl IntoResponse {
128 Json(TestQueryOutput {
129 did: req.did,
130 extra_data: BTreeMap::new(),
131 })
132}
133
134async fn test_procedure_handler(
135 ExtractXrpc(req): ExtractXrpc<TestProcedureRequest>,
136) -> impl IntoResponse {
137 Json(TestProcedureOutput {
138 did: req.did,
139 active: req.active,
140 })
141}
142
143async fn string_procedure_handler(
144 ExtractXrpc(req): ExtractXrpc<TestProcedureRequest, String>,
145) -> impl IntoResponse {
146 Json(TestProcedureOutput {
147 did: req.did,
148 active: req.active,
149 })
150}
151
152async fn cowstr_procedure_handler(
153 ExtractXrpc(req): ExtractXrpc<TestProcedureRequest, jacquard::CowStr<'static>>,
154) -> impl IntoResponse {
155 Json(TestProcedureOutput {
156 did: req.did,
157 active: req.active,
158 })
159}
160
161async fn std_cow_procedure_handler(
162 ExtractXrpc(req): ExtractXrpc<TestProcedureRequest, Cow<'static, str>>,
163) -> impl IntoResponse {
164 Json(TestProcedureOutput {
165 did: req.did,
166 active: req.active,
167 })
168}
169
170#[tokio::test]
171async fn test_url_encoded_did_in_query_params() {
172 let app = Router::new().merge(TestQueryRequest::into_router(test_query_handler));
173
174 let server = TestServer::new(app).unwrap();
175
176 let response = server
177 .get("/xrpc/com.example.test.query?did=did%3Aplc%3A123abc")
178 .await;
179
180 response.assert_status_ok();
181
182 let body_text = response.text();
183 let body: TestQueryOutput = serde_json::from_str(&body_text).unwrap();
184 assert_eq!(body.did.as_str(), "did:plc:123abc");
185}
186
187#[tokio::test]
188async fn test_unencoded_did_in_query_params() {
189 let app = Router::new().merge(TestQueryRequest::into_router(test_query_handler));
190
191 let server = TestServer::new(app).unwrap();
192
193 let response = server
194 .get("/xrpc/com.example.test.query?did=did:plc:123abc")
195 .await;
196
197 response.assert_status_ok();
198
199 let body_text = response.text();
200 let body: TestQueryOutput = serde_json::from_str(&body_text).unwrap();
201 assert_eq!(body.did.as_str(), "did:plc:123abc");
202}
203
204#[tokio::test]
205async fn test_multiple_params_with_encoded_did() {
206 let app = Router::new().merge(TestQueryRequest::into_router(test_query_handler));
207
208 let server = TestServer::new(app).unwrap();
209
210 let response = server
211 .get("/xrpc/com.example.test.query?did=did%3Aweb%3Aexample.com&limit=50")
212 .await;
213
214 response.assert_status_ok();
215
216 let body_text = response.text();
217 let body: TestQueryOutput = serde_json::from_str(&body_text).unwrap();
218 assert_eq!(body.did.as_str(), "did:web:example.com");
219}
220
221#[tokio::test]
222async fn test_string_extractor_decodes_query() {
223 let app = Router::new().merge(TestQueryRequest::into_router(string_query_handler));
224
225 let server = TestServer::new(app).unwrap();
226
227 let response = server
228 .get("/xrpc/com.example.test.query?did=did%3Aplc%3Astring")
229 .await;
230
231 response.assert_status_ok();
232 let body: TestQueryOutput = serde_json::from_str(&response.text()).unwrap();
233 assert_eq!(body.did.as_str(), "did:plc:string");
234}
235
236#[tokio::test]
237async fn test_cowstr_static_extractor_decodes_query() {
238 let app = Router::new().merge(TestQueryRequest::into_router(cowstr_query_handler));
239
240 let server = TestServer::new(app).unwrap();
241
242 let response = server
243 .get("/xrpc/com.example.test.query?did=did%3Aplc%3Acowstr")
244 .await;
245
246 response.assert_status_ok();
247 let body: TestQueryOutput = serde_json::from_str(&response.text()).unwrap();
248 assert_eq!(body.did.as_str(), "did:plc:cowstr");
249}
250
251#[tokio::test]
252async fn test_std_cow_static_extractor_decodes_query() {
253 let app = Router::new().merge(TestQueryRequest::into_router(std_cow_query_handler));
254
255 let server = TestServer::new(app).unwrap();
256
257 let response = server
258 .get("/xrpc/com.example.test.query?did=did%3Aplc%3Astd-cow")
259 .await;
260
261 response.assert_status_ok();
262 let body: TestQueryOutput = serde_json::from_str(&response.text()).unwrap();
263 assert_eq!(body.did.as_str(), "did:plc:std-cow");
264}
265
266#[tokio::test]
267async fn test_malformed_query_returns_xrpc_invalid_request() {
268 let app = Router::new().merge(TestQueryRequest::into_router(test_query_handler));
269
270 let server = TestServer::new(app).unwrap();
271
272 let response = server.get("/xrpc/com.example.test.query?limit=50").await;
273
274 response.assert_status_bad_request();
275 let body: serde_json::Value = serde_json::from_str(&response.text()).unwrap();
276 assert_eq!(body["error"], "InvalidRequest");
277}
278
279#[tokio::test]
280async fn test_procedure_post_decodes_body() {
281 let app = Router::new().merge(TestProcedureRequest::into_router(test_procedure_handler));
282
283 let server = TestServer::new(app).unwrap();
284
285 let response = server
286 .post("/xrpc/com.example.test.procedure")
287 .json(&serde_json::json!({
288 "did": "did:plc:procedure",
289 "active": true
290 }))
291 .await;
292
293 response.assert_status_ok();
294 let body: TestProcedureOutput = serde_json::from_str(&response.text()).unwrap();
295 assert_eq!(body.did.as_str(), "did:plc:procedure");
296 assert!(body.active);
297}
298
299#[tokio::test]
300async fn test_string_procedure_extractor_decodes_body() {
301 let app = Router::new().merge(TestProcedureRequest::into_router(string_procedure_handler));
302
303 let server = TestServer::new(app).unwrap();
304
305 let response = server
306 .post("/xrpc/com.example.test.procedure")
307 .json(&serde_json::json!({
308 "did": "did:plc:string-procedure",
309 "active": true
310 }))
311 .await;
312
313 response.assert_status_ok();
314 let body: TestProcedureOutput = serde_json::from_str(&response.text()).unwrap();
315 assert_eq!(body.did.as_str(), "did:plc:string-procedure");
316 assert!(body.active);
317}
318
319#[tokio::test]
320async fn test_cowstr_static_procedure_extractor_decodes_body() {
321 let app = Router::new().merge(TestProcedureRequest::into_router(cowstr_procedure_handler));
322
323 let server = TestServer::new(app).unwrap();
324
325 let response = server
326 .post("/xrpc/com.example.test.procedure")
327 .json(&serde_json::json!({
328 "did": "did:plc:cowstr-procedure",
329 "active": true
330 }))
331 .await;
332
333 response.assert_status_ok();
334 let body: TestProcedureOutput = serde_json::from_str(&response.text()).unwrap();
335 assert_eq!(body.did.as_str(), "did:plc:cowstr-procedure");
336 assert!(body.active);
337}
338
339#[tokio::test]
340async fn test_std_cow_static_procedure_extractor_decodes_body() {
341 let app = Router::new().merge(TestProcedureRequest::into_router(std_cow_procedure_handler));
342
343 let server = TestServer::new(app).unwrap();
344
345 let response = server
346 .post("/xrpc/com.example.test.procedure")
347 .json(&serde_json::json!({
348 "did": "did:plc:std-cow-procedure",
349 "active": true
350 }))
351 .await;
352
353 response.assert_status_ok();
354 let body: TestProcedureOutput = serde_json::from_str(&response.text()).unwrap();
355 assert_eq!(body.did.as_str(), "did:plc:std-cow-procedure");
356 assert!(body.active);
357}
358
359#[tokio::test]
360async fn test_get_to_procedure_route_is_rejected() {
361 let app = Router::new().merge(TestProcedureRequest::into_router(test_procedure_handler));
362
363 let server = TestServer::new(app).unwrap();
364
365 let response = server.get("/xrpc/com.example.test.procedure").await;
366
367 response.assert_status(StatusCode::METHOD_NOT_ALLOWED);
368}
369
370#[tokio::test]
371async fn test_malformed_procedure_body_returns_xrpc_invalid_request() {
372 let app = Router::new().merge(TestProcedureRequest::into_router(test_procedure_handler));
373
374 let server = TestServer::new(app).unwrap();
375
376 let response = server
377 .post("/xrpc/com.example.test.procedure")
378 .text("{not valid json")
379 .await;
380
381 response.assert_status_bad_request();
382 let body: serde_json::Value = serde_json::from_str(&response.text()).unwrap();
383 assert_eq!(body["error"], "InvalidRequest");
384}
385
386#[tokio::test]
387async fn test_xrpc_response_encodes_typed_output() {
388 let app = Router::new().merge(TestQueryRequest::into_router(typed_query_handler));
389
390 let server = TestServer::new(app).unwrap();
391
392 let response = server
393 .get("/xrpc/com.example.test.query?did=did%3Aplc%3Atyped")
394 .await;
395
396 response.assert_status_ok();
397 let content_type = response
398 .headers()
399 .get(axum::http::header::CONTENT_TYPE)
400 .unwrap()
401 .to_str()
402 .unwrap();
403 assert!(content_type.starts_with(TestQueryResponse::ENCODING));
404
405 let body: TestQueryOutput = serde_json::from_str(&response.text()).unwrap();
406 assert_eq!(body.did.as_str(), "did:plc:typed");
407}