Personal ATProto tools.
0

Configure Feed

Select the types of activity you want to include in your feed.

at main 21 kB View raw
1//! Serving requests as a labeler. 2 3use axum::{ 4 body::Bytes, 5 extract::ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade}, 6 response::IntoResponse, 7 routing::any, 8 Router, 9}; 10use axum_extra::TypedHeader; 11use headers::{Header, HeaderName, HeaderValue}; 12use sqlx::{sqlite::{SqliteConnectOptions, SqliteJournalMode}, SqlitePool}; 13use std::str::FromStr; 14use std::ops::ControlFlow; 15use std::net::SocketAddr; 16use axum::extract::connect_info::ConnectInfo; 17use axum::extract::ws::CloseFrame; 18use futures::{sink::SinkExt, stream::StreamExt}; 19use axum::extract::Query; 20use axum::{Json, http::StatusCode, routing::get}; 21use serde::Deserialize; 22use tower_http::decompression::RequestDecompressionLayer; 23use tower_http::compression::CompressionLayer; 24 25use crate::{types::{AssignedLabelResponse, AssignedLabelResponseWrapper, SignatureBytes, SignatureEnum, SubscribeLabelsLabels, UriParams}, webrequest::Agent}; 26 27 28/// Launch the web server to respond to label inquiries. 29#[tracing::instrument] 30pub async fn main_webserve() { 31 let app = Router::new() 32 .route("/xrpc/com.atproto.label.subscribeLabels", any(subscribe_labels)) 33 .route("/xrpc/com.atproto.label.queryLabels", get(query_labels)) 34 .route("/xrpc/app.bsky.actor.getProfile", get(get_profile)) 35 .layer(RequestDecompressionLayer::new()) 36 .layer(CompressionLayer::new().deflate(true)); 37 let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") 38 .await 39 .expect("Expected to bind to 0.0.0.0:3000 but failed."); 40 tracing::debug!("listening on {}", listener.local_addr().expect("Expected to get local address but failed.")); 41 axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await.expect("Expected to be able to use axum::serve but failed."); 42} 43 44/// Querys by DID. \ 45async fn query_labels(Query(params): Query<UriParams>) -> impl IntoResponse { 46 tracing::debug!("Querying labels: {:?}", params.uriPatterns); 47 drop(dotenvy::dotenv().expect("Failed to load .env file")); 48 let self_did = dotenvy::var("SELF_DID").expect("Expected to be able to get the SELF_DID from the environment, but failed"); 49 let src = jetstream_oxide::exports::Did::new(self_did).expect("Expected to be able to create a valid DID but failed"); 50 if let Some(uri_patterns) = params.uriPatterns { 51 let pool_opts = SqliteConnectOptions::from_str("sqlite://prod.db?mode=ro").expect("Expected to be able to configure the database, but failed.") 52 .journal_mode(SqliteJournalMode::Wal) 53 .read_only(true); 54 let pool = SqlitePool::connect_with(pool_opts).await.expect("Expected to be able to connect to the database at sqlite://prod.db but failed."); 55 let pattern = uri_patterns.replace("%", "").replace("_", "\\_"); // .replaceAll(/%/g, "").replaceAll(/_/g, "\\_"); 56 let star_index = pattern.find('*'); 57 let limit = params.limit.unwrap_or(50); 58 let cursor = params.cursor.unwrap_or_else(|| "0".to_owned()); 59 if let Some(star_index) = star_index { 60 if star_index != pattern.len() - 1 { 61 return (StatusCode::BAD_REQUEST, Json(AssignedLabelResponseWrapper { 62 cursor: "0".to_owned(), // TODO: Other servers don't respond with a cursor in this scenario. 63 labels: Vec::<AssignedLabelResponse>::new() 64 })); 65 } 66 let labels = sqlx::query!( 67 r#" 68 SELECT seq, uri "uri: String", val "val: String", neg, cts "cts: String", sig 69 FROM profile_labels 70 WHERE seq > ? 71 LIMIT ? 72 "#, 73 cursor, 74 limit 75 ) 76 .fetch_all(&pool) 77 .await 78 .expect("Expected to be able to fetch all labels from the database but failed."); 79 let smallest_cursor = labels.iter().map(|label| label.seq).min().unwrap_or(0); 80 return (StatusCode::OK, Json(AssignedLabelResponseWrapper { 81 cursor: smallest_cursor.to_string(), 82 labels: labels 83 .iter() 84 .map(|label| { 85 AssignedLabelResponse::reconstruct( 86 src.to_owned(), 87 label.uri.clone(), 88 label.val.clone(), 89 label.neg.unwrap_or(false), 90 label.cts.clone(), 91 SignatureEnum::Json(SignatureBytes::from_vec(label.sig.clone()).as_json_object()), 92 ) 93 }) 94 .collect(), 95 })); 96 } 97 let labels = sqlx::query!( 98 r#" 99 SELECT seq "seq: i64", uri "uri: String", val "val: String", neg, cts "cts: String", sig 100 FROM profile_labels WHERE uri = ? AND seq > ? LIMIT ? 101 "#, 102 uri_patterns, 103 cursor, 104 limit 105 ) 106 .fetch_all(&pool) 107 .await 108 .expect("Expected to be able to fetch all missing labels from the database but failed."); 109 let largest_cursor = labels.iter().map(|label| label.seq).max().unwrap_or(0); 110 return (StatusCode::OK, Json(AssignedLabelResponseWrapper { 111 cursor: largest_cursor.to_string(), 112 labels: labels 113 .iter() 114 .map(|label| { 115 AssignedLabelResponse::reconstruct( 116 src.to_owned(), 117 label.uri.clone(), 118 label.val.clone(), 119 label.neg.unwrap_or(false), 120 label.cts.clone(), 121 SignatureEnum::Json(SignatureBytes::from_vec(label.sig.clone()).as_json_object()), 122 ) 123 }) 124 .collect(), 125 })); 126 } 127 ( 128 StatusCode::OK, 129 Json(AssignedLabelResponseWrapper { 130 cursor: "0".to_owned(), 131 labels: Vec::<AssignedLabelResponse>::new() 132 }), 133 ) 134} 135/// Querys by profile name. 136async fn get_profile(Query(params): Query<UriParams>) -> impl IntoResponse { 137 if let Some(actor) = &params.actor { 138 let mut agent = Agent::default(); 139 if let Ok(profile) = agent.get_profile(actor).await { 140 return (StatusCode::OK, Json(profile)); 141 } 142 } 143 (StatusCode::OK, Json(serde_json::json!({}))) 144} 145 146/// Query parameters for subscribing to labels. 147#[derive(Deserialize)] 148struct SubscribeLabelsQueryParams { 149 /// The last known event seq number to backfill from. 150 cursor: Option<u64>, 151} 152 153 154#[derive(Debug)] 155struct XForwardedFor(String); 156 157impl Header for XForwardedFor { 158 fn name() -> &'static HeaderName { 159 &http::header::FORWARDED 160 } 161 162 fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error> 163 where 164 I: Iterator<Item = &'i HeaderValue>, 165 { 166 let value = values 167 .next() 168 .ok_or_else(headers::Error::invalid)?; 169 170 // We are only interested in the first IP address in the list. 171 let ip = value 172 .to_str() 173 .map_err(|_| headers::Error::invalid())? 174 .split(',') 175 .next() 176 .ok_or_else(headers::Error::invalid)?; 177 178 Ok(Self(ip.to_owned())) 179 } 180 181 fn encode<E>(&self, values: &mut E) 182 where 183 E: Extend<HeaderValue>, 184 { 185 let value = HeaderValue::from_str(&self.0).expect("Expected to be able to convert the X-Forwarded-For header to a string but failed."); 186 values.extend(std::iter::once(value)); 187 } 188} 189 190/// The handler for the HTTP request. 191/// 192/// This gets called when the HTTP request lands at the start 193/// of websocket negotiation. After this completes, the actual switching from HTTP to 194/// websocket protocol will occur. 195/// This is the last point where we can extract TCP/IP metadata such as IP address of the client 196/// as well as things from HTTP headers such as user-agent of the browser etc. 197async fn subscribe_labels( 198 ws: WebSocketUpgrade, 199 user_agent: Option<TypedHeader<headers::UserAgent>>, 200 x_forwarded_for: Option<TypedHeader<XForwardedFor>>, 201 ConnectInfo(connection_address): ConnectInfo<SocketAddr>, 202 Query(params): Query<SubscribeLabelsQueryParams>, 203) -> impl IntoResponse { 204 let user_agent = if let Some(TypedHeader(user_agent)) = user_agent { 205 user_agent.to_string() 206 } else { 207 String::from("Unknown browser") 208 }; 209 // Check X-Forwarded-For header to get the apparent IP address of the client 210 // TODO: This header can be spoofed, and should only be trusted from a trusted proxy. 211 let apparent_ip = if let Some(TypedHeader(x_forwarded_for)) = x_forwarded_for { 212 SocketAddr::new(x_forwarded_for.0.parse().expect("Expected to be able to parse the X-Forwarded-For header as a socket address but failed."), 213 connection_address.port()) 214 } else { 215 connection_address 216 }; 217 tracing::debug!("`{user_agent}` at {apparent_ip} connected."); 218 let pool_opts = SqliteConnectOptions::from_str("sqlite://prod.db?mode=ro").expect("Expected to be able to configure the database, but failed.") 219 .journal_mode(SqliteJournalMode::Wal) 220 .read_only(true); 221 let pool = SqlitePool::connect_with(pool_opts).await.expect("Expected to be able to connect to the database at sqlite://prod.db but failed."); 222 let cursor = params.cursor.unwrap_or( 223 get_current_cursor_count(&pool) 224 .await.expect("Expected to be able to get the current cursor count but failed.") 225 .try_into().expect("Expected to be able to convert the current cursor count to a u64 but failed.") 226 ) as i64; 227 // finalize the upgrade process by returning upgrade callback. 228 // we can customize the callback by sending additional info such as address. 229 ws.on_upgrade(move |socket| handle_socket(socket, apparent_ip, cursor, pool)) 230} 231 232async fn get_current_cursor_count( 233 pool: &SqlitePool, 234) -> Result<i64, sqlx::Error> { 235 let current_cursor_count = sqlx::query!( 236 r#" 237 SELECT seq FROM profile_labels ORDER BY seq DESC LIMIT 1 238 "# 239 ) 240 .fetch_one(pool) 241 .await?; 242 Ok(current_cursor_count.seq) 243} 244 245/// Actual websocket statemachine (one will be spawned per connection) 246async fn handle_socket(socket: WebSocket, who: SocketAddr, cursor: i64, pool: SqlitePool) { 247 let _ = websocket_context(socket, who, cursor, pool).await; 248 // returning from the handler closes the websocket connection 249 tracing::debug!("Websocket context {who} destroyed"); 250} 251 252/// Get all missed messages, based on cursor. 253async fn get_missed_messages( 254 pool: &SqlitePool, 255 cursor: i64, 256) -> Result<Vec<SubscribeLabelsLabels>, sqlx::Error> { 257 let missed_messages = sqlx::query!( 258 r#" 259 SELECT seq, uri "uri: String", val "val: String", neg "neg: bool", cts "cts: String", sig 260 FROM profile_labels WHERE seq > ? 261 "#, 262 cursor 263 ) 264 .fetch_all(pool) 265 .await?; 266 drop(dotenvy::dotenv().expect("Failed to load .env file")); 267 let self_did = dotenvy::var("SELF_DID").expect("Expected to be able to get the SELF_DID from the environment, but failed"); 268 let src = jetstream_oxide::exports::Did::new(self_did).expect("Expected to be able to create a valid DID but failed"); 269 Ok(missed_messages 270 .iter() 271 .map(|label| { 272 SubscribeLabelsLabels { 273 seq: label.seq, 274 labels: vec![AssignedLabelResponse::reconstruct( 275 src.to_owned(), 276 label.uri.clone(), 277 label.val.clone(), 278 label.neg.unwrap_or(false), 279 label.cts.clone(), 280 SignatureEnum::Bytes(SignatureBytes::from_vec(label.sig.clone())), 281 )], 282 } 283 }) 284 .collect()) 285} 286 287async fn websocket_context(mut socket: WebSocket, who: SocketAddr, mut cursor: i64, pool: SqlitePool) -> ControlFlow<()> { 288 ws_send( 289 &mut socket, 290 who, 291 Message::Ping(Bytes::from_static(&[1, 2, 3])) 292 ).await?; 293 tracing::info!("{who} connected with cursor {cursor}"); 294 let current_cursor_count = get_current_cursor_count(&pool).await.unwrap_or_default(); 295 tracing::debug!("Current cursor count: {current_cursor_count}"); 296 if cursor < current_cursor_count { 297 let missed_messages = get_missed_messages(&pool, cursor).await.expect("Expected to be able to get missed messages but failed."); 298 for message in missed_messages { 299 tracing::info!("Sending missed message to {who}: {:?}", message); 300 let message_header: Vec<u8> = Bytes::from_static(b"\xa2atg#labelsbop\x01").into(); 301 let message_body = serde_cbor::to_vec(&message).expect("Expected to be able to serialize message to CBOR but failed."); 302 let message_combined = [message_header, message_body].concat(); 303 let message_finished = Message::Binary(message_combined.into()); 304 ws_send( 305 &mut socket, 306 who, 307 message_finished 308 ).await?; 309 } 310 cursor = current_cursor_count; 311 } 312 // By splitting socket we can send and receive at the same time. In this example we will send 313 // unsolicited messages to client based on some sort of server's internal event (i.e .timer). 314 let (mut sender, mut receiver) = socket.split(); 315 // Spawn a task that will push several messages to the client (does not matter what client does) 316 let mut send_task = tokio::spawn(async move { 317 const PING_INTERVAL: std::time::Duration = std::time::Duration::from_secs(60); 318 const BROADCAST_INTERVAL: std::time::Duration = std::time::Duration::from_secs(20); 319 let mut last_ping = tokio::time::Instant::now(); 320 let mut last_broadcast = tokio::time::Instant::now(); 321 let mut n_msg = 0; 322 loop { 323 tokio::select! { 324 _ = tokio::time::sleep_until(last_ping + PING_INTERVAL) => { 325 tracing::debug!("Sending ping to {who}..."); 326 if ws_send( 327 &mut sender, 328 who, 329 Message::Ping(Bytes::from_static(&[1, 2, 3])) 330 ).await.is_break() { 331 tracing::warn!("Client {who} failed to respond to ping"); 332 break; 333 } 334 tracing::debug!("Sent ping to {who}"); 335 last_ping = tokio::time::Instant::now(); 336 }, 337 _ = tokio::time::sleep_until(last_broadcast + BROADCAST_INTERVAL) => { 338 tracing::debug!("Polling for new messages to send to {who}..."); 339 let current_cursor_count = get_current_cursor_count(&pool).await.unwrap_or_default(); 340 if cursor < current_cursor_count { 341 let missed_messages = get_missed_messages(&pool, cursor).await; 342 if missed_messages.is_err() { 343 tracing::warn!("Error getting missed messages: {missed_messages:?}"); 344 last_broadcast = tokio::time::Instant::now(); 345 continue; 346 } 347 for message in missed_messages.expect("Expected to be able to get missed messages but failed.") { 348 let seq = message.seq; 349 let neg = message.labels[0].neg; 350 let uri = message.labels[0].uri.clone(); 351 let val = message.labels[0].val.clone(); 352 let prefix = if neg { "Negation" } else { "Emitting" }; 353 tracing::info!("{prefix} label {seq} to {who}: {uri} {val}"); 354 let message_header: Vec<u8> = Bytes::from_static(b"\xa2atg#labelsbop\x01").into(); 355 let message_body = serde_cbor::to_vec(&message).expect("Expected to be able to serialize message to CBOR but failed."); 356 let message_combined = [message_header, message_body].concat(); 357 let message_finished = Message::Binary(message_combined.into()); 358 if ws_send( 359 &mut sender, 360 who, 361 message_finished 362 ).await.is_break() { 363 tracing::warn!("Client {who} failed to receive missed message"); 364 break; 365 } 366 n_msg += 1; 367 } 368 cursor = current_cursor_count; 369 } 370 tracing::debug!("Finished poll for {who}"); 371 last_broadcast = tokio::time::Instant::now(); 372 } 373 } 374 } 375 tracing::info!("Sending close to {who}..."); 376 ws_close(sender).await; 377 n_msg 378 }); 379 380 // This second task will receive messages from client and print them on server console 381 let mut recv_task = tokio::spawn(async move { 382 let mut cnt = 0; 383 while let Some(Ok(msg)) = receiver.next().await { 384 cnt += 1; 385 // print message and break if instructed to do so 386 if process_message(msg, who).is_break() { 387 break; 388 } 389 } 390 cnt 391 }); 392 393 // If any one of the tasks exit, abort the other. 394 tokio::select! { 395 rv_a = (&mut send_task) => { 396 match rv_a { 397 Ok(a) => tracing::info!("{a} messages sent to {who}"), 398 Err(a) => tracing::warn!("Error sending messages {a:?}") 399 } 400 recv_task.abort(); 401 }, 402 rv_b = (&mut recv_task) => { 403 match rv_b { 404 Ok(b) => tracing::info!("Received {b} messages"), 405 Err(b) => tracing::warn!("Error receiving messages {b:?}") 406 } 407 send_task.abort(); 408 } 409 } 410 ControlFlow::Continue(()) 411} 412 413async fn ws_send( 414 socket: &mut (impl SinkExt<Message> + Unpin), 415 who: SocketAddr, 416 msg: Message, 417) -> ControlFlow<(), ()> { 418 if socket 419 .send(msg) 420 .await 421 .is_err() 422 { 423 tracing::warn!("client {who} abruptly disconnected"); 424 return ControlFlow::Break(()); 425 } 426 ControlFlow::Continue(()) 427} 428 429async fn ws_close(mut sender: futures::stream::SplitSink<WebSocket, Message>) { 430 if let Err(e) = sender 431 .send(Message::Close(Some(CloseFrame { 432 code: axum::extract::ws::close_code::NORMAL, 433 reason: Utf8Bytes::from_static("Goodbye"), 434 }))) 435 .await 436 { 437 tracing::warn!("Could not send Close due to {e}, probably it is ok?"); 438 } 439} 440 441/// helper to print contents of messages to stdout. Has special treatment for Close. 442#[allow(clippy::cognitive_complexity)] 443fn process_message(msg: Message, who: SocketAddr) -> ControlFlow<(), ()> { 444 match msg { 445 Message::Text(t) => { 446 tracing::debug!(">>> {who} sent str: {t:?}"); 447 } 448 Message::Binary(d) => { 449 tracing::debug!(">>> {} sent {} bytes: {:?}", who, d.len(), d); 450 } 451 Message::Close(c) => { 452 if let Some(cf) = c { 453 tracing::debug!( 454 ">>> {} sent close with code {} and reason `{}`", 455 who, cf.code, cf.reason 456 ); 457 } else { 458 tracing::debug!(">>> {who} somehow sent close message without CloseFrame"); 459 } 460 return ControlFlow::Break(()); 461 } 462 463 Message::Pong(v) => { 464 tracing::debug!(">>> {who} sent pong with {v:?}"); 465 } 466 // You should never need to manually handle Message::Ping, as axum's websocket library 467 // will do so for you automagically by replying with Pong and copying the v according to 468 // spec. But if you need the contents of the pings you can see them here. 469 Message::Ping(v) => { 470 tracing::debug!(">>> {who} sent ping with {v:?}"); 471 } 472 } 473 ControlFlow::Continue(()) 474} 475 476/// WIP: fetch likes from app.bsky.feed.like 477/// https://shimeji.us-east.host.bsky.network/xrpc/com.atproto.repo.listRecords?repo=did:plc:jrtgsidnmxaen4offglr5lsh&collection=app.bsky.feed.like&limit=100 478/// 479/// then return them as a custom feed 480/// request path is at "/xrpc/app.bsky.feed.getFeedSkeleton?<feed>&<limit>&<cursor>" 481#[allow(dead_code)] 482async fn get_feed_skeleton(Query(params): Query<UriParams>) -> impl IntoResponse { 483 if let Some(_uri_patterns) = &params.uriPatterns { 484 let mut _agent = Agent::default(); 485 486 } 487 (StatusCode::OK, Json(serde_json::json!({}))) 488}