Personal ATProto tools.
1//! Serving requests as a labeler.
2
3use axum::{
4 body::Bytes,
5 extract::ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
6 response::IntoResponse,
7 routing::any,
8 Router,
9};
10use axum_extra::TypedHeader;
11use headers::{Header, HeaderName, HeaderValue};
12use sqlx::{sqlite::{SqliteConnectOptions, SqliteJournalMode}, SqlitePool};
13use std::str::FromStr;
14use std::ops::ControlFlow;
15use std::net::SocketAddr;
16use axum::extract::connect_info::ConnectInfo;
17use axum::extract::ws::CloseFrame;
18use futures::{sink::SinkExt, stream::StreamExt};
19use axum::extract::Query;
20use axum::{Json, http::StatusCode, routing::get};
21use serde::Deserialize;
22use tower_http::decompression::RequestDecompressionLayer;
23use tower_http::compression::CompressionLayer;
24
25use crate::{types::{AssignedLabelResponse, AssignedLabelResponseWrapper, SignatureBytes, SignatureEnum, SubscribeLabelsLabels, UriParams}, webrequest::Agent};
26
27
28/// Launch the web server to respond to label inquiries.
29#[tracing::instrument]
30pub async fn main_webserve() {
31 let app = Router::new()
32 .route("/xrpc/com.atproto.label.subscribeLabels", any(subscribe_labels))
33 .route("/xrpc/com.atproto.label.queryLabels", get(query_labels))
34 .route("/xrpc/app.bsky.actor.getProfile", get(get_profile))
35 .layer(RequestDecompressionLayer::new())
36 .layer(CompressionLayer::new().deflate(true));
37 let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
38 .await
39 .expect("Expected to bind to 0.0.0.0:3000 but failed.");
40 tracing::debug!("listening on {}", listener.local_addr().expect("Expected to get local address but failed."));
41 axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await.expect("Expected to be able to use axum::serve but failed.");
42}
43
44/// Querys by DID. \
45async fn query_labels(Query(params): Query<UriParams>) -> impl IntoResponse {
46 tracing::debug!("Querying labels: {:?}", params.uriPatterns);
47 drop(dotenvy::dotenv().expect("Failed to load .env file"));
48 let self_did = dotenvy::var("SELF_DID").expect("Expected to be able to get the SELF_DID from the environment, but failed");
49 let src = jetstream_oxide::exports::Did::new(self_did).expect("Expected to be able to create a valid DID but failed");
50 if let Some(uri_patterns) = params.uriPatterns {
51 let pool_opts = SqliteConnectOptions::from_str("sqlite://prod.db?mode=ro").expect("Expected to be able to configure the database, but failed.")
52 .journal_mode(SqliteJournalMode::Wal)
53 .read_only(true);
54 let pool = SqlitePool::connect_with(pool_opts).await.expect("Expected to be able to connect to the database at sqlite://prod.db but failed.");
55 let pattern = uri_patterns.replace("%", "").replace("_", "\\_"); // .replaceAll(/%/g, "").replaceAll(/_/g, "\\_");
56 let star_index = pattern.find('*');
57 let limit = params.limit.unwrap_or(50);
58 let cursor = params.cursor.unwrap_or_else(|| "0".to_owned());
59 if let Some(star_index) = star_index {
60 if star_index != pattern.len() - 1 {
61 return (StatusCode::BAD_REQUEST, Json(AssignedLabelResponseWrapper {
62 cursor: "0".to_owned(), // TODO: Other servers don't respond with a cursor in this scenario.
63 labels: Vec::<AssignedLabelResponse>::new()
64 }));
65 }
66 let labels = sqlx::query!(
67 r#"
68 SELECT seq, uri "uri: String", val "val: String", neg, cts "cts: String", sig
69 FROM profile_labels
70 WHERE seq > ?
71 LIMIT ?
72 "#,
73 cursor,
74 limit
75 )
76 .fetch_all(&pool)
77 .await
78 .expect("Expected to be able to fetch all labels from the database but failed.");
79 let smallest_cursor = labels.iter().map(|label| label.seq).min().unwrap_or(0);
80 return (StatusCode::OK, Json(AssignedLabelResponseWrapper {
81 cursor: smallest_cursor.to_string(),
82 labels: labels
83 .iter()
84 .map(|label| {
85 AssignedLabelResponse::reconstruct(
86 src.to_owned(),
87 label.uri.clone(),
88 label.val.clone(),
89 label.neg.unwrap_or(false),
90 label.cts.clone(),
91 SignatureEnum::Json(SignatureBytes::from_vec(label.sig.clone()).as_json_object()),
92 )
93 })
94 .collect(),
95 }));
96 }
97 let labels = sqlx::query!(
98 r#"
99 SELECT seq "seq: i64", uri "uri: String", val "val: String", neg, cts "cts: String", sig
100 FROM profile_labels WHERE uri = ? AND seq > ? LIMIT ?
101 "#,
102 uri_patterns,
103 cursor,
104 limit
105 )
106 .fetch_all(&pool)
107 .await
108 .expect("Expected to be able to fetch all missing labels from the database but failed.");
109 let largest_cursor = labels.iter().map(|label| label.seq).max().unwrap_or(0);
110 return (StatusCode::OK, Json(AssignedLabelResponseWrapper {
111 cursor: largest_cursor.to_string(),
112 labels: labels
113 .iter()
114 .map(|label| {
115 AssignedLabelResponse::reconstruct(
116 src.to_owned(),
117 label.uri.clone(),
118 label.val.clone(),
119 label.neg.unwrap_or(false),
120 label.cts.clone(),
121 SignatureEnum::Json(SignatureBytes::from_vec(label.sig.clone()).as_json_object()),
122 )
123 })
124 .collect(),
125 }));
126 }
127 (
128 StatusCode::OK,
129 Json(AssignedLabelResponseWrapper {
130 cursor: "0".to_owned(),
131 labels: Vec::<AssignedLabelResponse>::new()
132 }),
133 )
134}
135/// Querys by profile name.
136async fn get_profile(Query(params): Query<UriParams>) -> impl IntoResponse {
137 if let Some(actor) = ¶ms.actor {
138 let mut agent = Agent::default();
139 if let Ok(profile) = agent.get_profile(actor).await {
140 return (StatusCode::OK, Json(profile));
141 }
142 }
143 (StatusCode::OK, Json(serde_json::json!({})))
144}
145
146/// Query parameters for subscribing to labels.
147#[derive(Deserialize)]
148struct SubscribeLabelsQueryParams {
149 /// The last known event seq number to backfill from.
150 cursor: Option<u64>,
151}
152
153
154#[derive(Debug)]
155struct XForwardedFor(String);
156
157impl Header for XForwardedFor {
158 fn name() -> &'static HeaderName {
159 &http::header::FORWARDED
160 }
161
162 fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
163 where
164 I: Iterator<Item = &'i HeaderValue>,
165 {
166 let value = values
167 .next()
168 .ok_or_else(headers::Error::invalid)?;
169
170 // We are only interested in the first IP address in the list.
171 let ip = value
172 .to_str()
173 .map_err(|_| headers::Error::invalid())?
174 .split(',')
175 .next()
176 .ok_or_else(headers::Error::invalid)?;
177
178 Ok(Self(ip.to_owned()))
179 }
180
181 fn encode<E>(&self, values: &mut E)
182 where
183 E: Extend<HeaderValue>,
184 {
185 let value = HeaderValue::from_str(&self.0).expect("Expected to be able to convert the X-Forwarded-For header to a string but failed.");
186 values.extend(std::iter::once(value));
187 }
188}
189
190/// The handler for the HTTP request.
191///
192/// This gets called when the HTTP request lands at the start
193/// of websocket negotiation. After this completes, the actual switching from HTTP to
194/// websocket protocol will occur.
195/// This is the last point where we can extract TCP/IP metadata such as IP address of the client
196/// as well as things from HTTP headers such as user-agent of the browser etc.
197async fn subscribe_labels(
198 ws: WebSocketUpgrade,
199 user_agent: Option<TypedHeader<headers::UserAgent>>,
200 x_forwarded_for: Option<TypedHeader<XForwardedFor>>,
201 ConnectInfo(connection_address): ConnectInfo<SocketAddr>,
202 Query(params): Query<SubscribeLabelsQueryParams>,
203) -> impl IntoResponse {
204 let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
205 user_agent.to_string()
206 } else {
207 String::from("Unknown browser")
208 };
209 // Check X-Forwarded-For header to get the apparent IP address of the client
210 // TODO: This header can be spoofed, and should only be trusted from a trusted proxy.
211 let apparent_ip = if let Some(TypedHeader(x_forwarded_for)) = x_forwarded_for {
212 SocketAddr::new(x_forwarded_for.0.parse().expect("Expected to be able to parse the X-Forwarded-For header as a socket address but failed."),
213 connection_address.port())
214 } else {
215 connection_address
216 };
217 tracing::debug!("`{user_agent}` at {apparent_ip} connected.");
218 let pool_opts = SqliteConnectOptions::from_str("sqlite://prod.db?mode=ro").expect("Expected to be able to configure the database, but failed.")
219 .journal_mode(SqliteJournalMode::Wal)
220 .read_only(true);
221 let pool = SqlitePool::connect_with(pool_opts).await.expect("Expected to be able to connect to the database at sqlite://prod.db but failed.");
222 let cursor = params.cursor.unwrap_or(
223 get_current_cursor_count(&pool)
224 .await.expect("Expected to be able to get the current cursor count but failed.")
225 .try_into().expect("Expected to be able to convert the current cursor count to a u64 but failed.")
226 ) as i64;
227 // finalize the upgrade process by returning upgrade callback.
228 // we can customize the callback by sending additional info such as address.
229 ws.on_upgrade(move |socket| handle_socket(socket, apparent_ip, cursor, pool))
230}
231
232async fn get_current_cursor_count(
233 pool: &SqlitePool,
234) -> Result<i64, sqlx::Error> {
235 let current_cursor_count = sqlx::query!(
236 r#"
237 SELECT seq FROM profile_labels ORDER BY seq DESC LIMIT 1
238 "#
239 )
240 .fetch_one(pool)
241 .await?;
242 Ok(current_cursor_count.seq)
243}
244
245/// Actual websocket statemachine (one will be spawned per connection)
246async fn handle_socket(socket: WebSocket, who: SocketAddr, cursor: i64, pool: SqlitePool) {
247 let _ = websocket_context(socket, who, cursor, pool).await;
248 // returning from the handler closes the websocket connection
249 tracing::debug!("Websocket context {who} destroyed");
250}
251
252/// Get all missed messages, based on cursor.
253async fn get_missed_messages(
254 pool: &SqlitePool,
255 cursor: i64,
256) -> Result<Vec<SubscribeLabelsLabels>, sqlx::Error> {
257 let missed_messages = sqlx::query!(
258 r#"
259 SELECT seq, uri "uri: String", val "val: String", neg "neg: bool", cts "cts: String", sig
260 FROM profile_labels WHERE seq > ?
261 "#,
262 cursor
263 )
264 .fetch_all(pool)
265 .await?;
266 drop(dotenvy::dotenv().expect("Failed to load .env file"));
267 let self_did = dotenvy::var("SELF_DID").expect("Expected to be able to get the SELF_DID from the environment, but failed");
268 let src = jetstream_oxide::exports::Did::new(self_did).expect("Expected to be able to create a valid DID but failed");
269 Ok(missed_messages
270 .iter()
271 .map(|label| {
272 SubscribeLabelsLabels {
273 seq: label.seq,
274 labels: vec![AssignedLabelResponse::reconstruct(
275 src.to_owned(),
276 label.uri.clone(),
277 label.val.clone(),
278 label.neg.unwrap_or(false),
279 label.cts.clone(),
280 SignatureEnum::Bytes(SignatureBytes::from_vec(label.sig.clone())),
281 )],
282 }
283 })
284 .collect())
285}
286
287async fn websocket_context(mut socket: WebSocket, who: SocketAddr, mut cursor: i64, pool: SqlitePool) -> ControlFlow<()> {
288 ws_send(
289 &mut socket,
290 who,
291 Message::Ping(Bytes::from_static(&[1, 2, 3]))
292 ).await?;
293 tracing::info!("{who} connected with cursor {cursor}");
294 let current_cursor_count = get_current_cursor_count(&pool).await.unwrap_or_default();
295 tracing::debug!("Current cursor count: {current_cursor_count}");
296 if cursor < current_cursor_count {
297 let missed_messages = get_missed_messages(&pool, cursor).await.expect("Expected to be able to get missed messages but failed.");
298 for message in missed_messages {
299 tracing::info!("Sending missed message to {who}: {:?}", message);
300 let message_header: Vec<u8> = Bytes::from_static(b"\xa2atg#labelsbop\x01").into();
301 let message_body = serde_cbor::to_vec(&message).expect("Expected to be able to serialize message to CBOR but failed.");
302 let message_combined = [message_header, message_body].concat();
303 let message_finished = Message::Binary(message_combined.into());
304 ws_send(
305 &mut socket,
306 who,
307 message_finished
308 ).await?;
309 }
310 cursor = current_cursor_count;
311 }
312 // By splitting socket we can send and receive at the same time. In this example we will send
313 // unsolicited messages to client based on some sort of server's internal event (i.e .timer).
314 let (mut sender, mut receiver) = socket.split();
315 // Spawn a task that will push several messages to the client (does not matter what client does)
316 let mut send_task = tokio::spawn(async move {
317 const PING_INTERVAL: std::time::Duration = std::time::Duration::from_secs(60);
318 const BROADCAST_INTERVAL: std::time::Duration = std::time::Duration::from_secs(20);
319 let mut last_ping = tokio::time::Instant::now();
320 let mut last_broadcast = tokio::time::Instant::now();
321 let mut n_msg = 0;
322 loop {
323 tokio::select! {
324 _ = tokio::time::sleep_until(last_ping + PING_INTERVAL) => {
325 tracing::debug!("Sending ping to {who}...");
326 if ws_send(
327 &mut sender,
328 who,
329 Message::Ping(Bytes::from_static(&[1, 2, 3]))
330 ).await.is_break() {
331 tracing::warn!("Client {who} failed to respond to ping");
332 break;
333 }
334 tracing::debug!("Sent ping to {who}");
335 last_ping = tokio::time::Instant::now();
336 },
337 _ = tokio::time::sleep_until(last_broadcast + BROADCAST_INTERVAL) => {
338 tracing::debug!("Polling for new messages to send to {who}...");
339 let current_cursor_count = get_current_cursor_count(&pool).await.unwrap_or_default();
340 if cursor < current_cursor_count {
341 let missed_messages = get_missed_messages(&pool, cursor).await;
342 if missed_messages.is_err() {
343 tracing::warn!("Error getting missed messages: {missed_messages:?}");
344 last_broadcast = tokio::time::Instant::now();
345 continue;
346 }
347 for message in missed_messages.expect("Expected to be able to get missed messages but failed.") {
348 let seq = message.seq;
349 let neg = message.labels[0].neg;
350 let uri = message.labels[0].uri.clone();
351 let val = message.labels[0].val.clone();
352 let prefix = if neg { "Negation" } else { "Emitting" };
353 tracing::info!("{prefix} label {seq} to {who}: {uri} {val}");
354 let message_header: Vec<u8> = Bytes::from_static(b"\xa2atg#labelsbop\x01").into();
355 let message_body = serde_cbor::to_vec(&message).expect("Expected to be able to serialize message to CBOR but failed.");
356 let message_combined = [message_header, message_body].concat();
357 let message_finished = Message::Binary(message_combined.into());
358 if ws_send(
359 &mut sender,
360 who,
361 message_finished
362 ).await.is_break() {
363 tracing::warn!("Client {who} failed to receive missed message");
364 break;
365 }
366 n_msg += 1;
367 }
368 cursor = current_cursor_count;
369 }
370 tracing::debug!("Finished poll for {who}");
371 last_broadcast = tokio::time::Instant::now();
372 }
373 }
374 }
375 tracing::info!("Sending close to {who}...");
376 ws_close(sender).await;
377 n_msg
378 });
379
380 // This second task will receive messages from client and print them on server console
381 let mut recv_task = tokio::spawn(async move {
382 let mut cnt = 0;
383 while let Some(Ok(msg)) = receiver.next().await {
384 cnt += 1;
385 // print message and break if instructed to do so
386 if process_message(msg, who).is_break() {
387 break;
388 }
389 }
390 cnt
391 });
392
393 // If any one of the tasks exit, abort the other.
394 tokio::select! {
395 rv_a = (&mut send_task) => {
396 match rv_a {
397 Ok(a) => tracing::info!("{a} messages sent to {who}"),
398 Err(a) => tracing::warn!("Error sending messages {a:?}")
399 }
400 recv_task.abort();
401 },
402 rv_b = (&mut recv_task) => {
403 match rv_b {
404 Ok(b) => tracing::info!("Received {b} messages"),
405 Err(b) => tracing::warn!("Error receiving messages {b:?}")
406 }
407 send_task.abort();
408 }
409 }
410 ControlFlow::Continue(())
411}
412
413async fn ws_send(
414 socket: &mut (impl SinkExt<Message> + Unpin),
415 who: SocketAddr,
416 msg: Message,
417) -> ControlFlow<(), ()> {
418 if socket
419 .send(msg)
420 .await
421 .is_err()
422 {
423 tracing::warn!("client {who} abruptly disconnected");
424 return ControlFlow::Break(());
425 }
426 ControlFlow::Continue(())
427}
428
429async fn ws_close(mut sender: futures::stream::SplitSink<WebSocket, Message>) {
430 if let Err(e) = sender
431 .send(Message::Close(Some(CloseFrame {
432 code: axum::extract::ws::close_code::NORMAL,
433 reason: Utf8Bytes::from_static("Goodbye"),
434 })))
435 .await
436 {
437 tracing::warn!("Could not send Close due to {e}, probably it is ok?");
438 }
439}
440
441/// helper to print contents of messages to stdout. Has special treatment for Close.
442#[allow(clippy::cognitive_complexity)]
443fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> {
444 match msg {
445 Message::Text(t) => {
446 tracing::debug!(">>> {who} sent str: {t:?}");
447 }
448 Message::Binary(d) => {
449 tracing::debug!(">>> {} sent {} bytes: {:?}", who, d.len(), d);
450 }
451 Message::Close(c) => {
452 if let Some(cf) = c {
453 tracing::debug!(
454 ">>> {} sent close with code {} and reason `{}`",
455 who, cf.code, cf.reason
456 );
457 } else {
458 tracing::debug!(">>> {who} somehow sent close message without CloseFrame");
459 }
460 return ControlFlow::Break(());
461 }
462
463 Message::Pong(v) => {
464 tracing::debug!(">>> {who} sent pong with {v:?}");
465 }
466 // You should never need to manually handle Message::Ping, as axum's websocket library
467 // will do so for you automagically by replying with Pong and copying the v according to
468 // spec. But if you need the contents of the pings you can see them here.
469 Message::Ping(v) => {
470 tracing::debug!(">>> {who} sent ping with {v:?}");
471 }
472 }
473 ControlFlow::Continue(())
474}
475
476/// WIP: fetch likes from app.bsky.feed.like
477/// https://shimeji.us-east.host.bsky.network/xrpc/com.atproto.repo.listRecords?repo=did:plc:jrtgsidnmxaen4offglr5lsh&collection=app.bsky.feed.like&limit=100
478///
479/// then return them as a custom feed
480/// request path is at "/xrpc/app.bsky.feed.getFeedSkeleton?<feed>&<limit>&<cursor>"
481#[allow(dead_code)]
482async fn get_feed_skeleton(Query(params): Query<UriParams>) -> impl IntoResponse {
483 if let Some(_uri_patterns) = ¶ms.uriPatterns {
484 let mut _agent = Agent::default();
485
486 }
487 (StatusCode::OK, Json(serde_json::json!({})))
488}