Microservice to bring 2FA to self-hosted PDSes
0

Configure Feed

Select the types of activity you want to include in your feed.

1use crate::xrpc::com_atproto_server::{create_session, get_session, update_email}; 2use axum::middleware as ax_middleware; 3mod middleware; 4use axum::body::Body; 5use axum::handler::Handler; 6use axum::http::{Method, header}; 7use axum::routing::post; 8use axum::{Router, routing::get}; 9use axum_template::engine::Engine; 10use handlebars::Handlebars; 11use hyper_util::client::legacy::connect::HttpConnector; 12use hyper_util::rt::TokioExecutor; 13use lettre::{AsyncSmtpTransport, Tokio1Executor}; 14use rust_embed::RustEmbed; 15use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; 16use sqlx::{SqlitePool, sqlite::SqlitePoolOptions}; 17use std::path::Path; 18use std::time::Duration; 19use std::{env, net::SocketAddr}; 20use tower_governor::GovernorLayer; 21use tower_governor::governor::GovernorConfigBuilder; 22use tower_http::compression::CompressionLayer; 23use tower_http::cors::{Any, CorsLayer}; 24use tracing::{error, log}; 25use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 26 27mod xrpc; 28 29type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; 30 31#[derive(RustEmbed)] 32#[folder = "email_templates"] 33#[include = "*.hbs"] 34struct EmailTemplates; 35 36#[derive(Clone)] 37struct AppState { 38 account_pool: SqlitePool, 39 pds_gatekeeper_pool: SqlitePool, 40 reverse_proxy_client: HyperUtilClient, 41 pds_base_url: String, 42 mailer: AsyncSmtpTransport<Tokio1Executor>, 43 mailer_from: String, 44 template_engine: Engine<Handlebars<'static>>, 45} 46 47async fn root_handler() -> impl axum::response::IntoResponse { 48 let body = r" 49 50 ...oO _.--X~~OO~~X--._ ...oOO 51 _.-~ / \ II / \ ~-._ 52 [].-~ \ / \||/ \ / ~-.[] ...o 53 ...o _ ||/ \ / || \ / \|| _ 54 (_) |X X || X X| (_) 55 _-~-_ ||\ / \ || / \ /|| _-~-_ 56 ||||| || \ / \ /||\ / \ / || ||||| 57 | |_|| \ / \ / || \ / \ / ||_| | 58 | |~|| X X || X X ||~| | 59==============| | || / \ / \ || / \ / \ || | |============== 60______________| | || / \ / \||/ \ / \ || | |______________ 61 
. . | | ||/ \ / || \ / \|| | | . . 62 / | | |X X || X X| | | / / 63 / . | | ||\ / \ || / \ /|| | | . / . 64. / | | || \ / \ /||\ / \ / || | | . . 65 . . | | || \ / \ / || \ / \ / || | | . 66 / | | || X X || X X || | | . / . / 67 / . | | || / \ / \ || / \ / \ || | | / 68 / | | || / \ / \||/ \ / \ || | | . / 69. . . | | ||/ \ / /||\ \ / \|| | | /. . 70 | |_|X X / II \ X X|_| | . . / 71==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |============== 72 "; 73 74 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 75 76 let banner = format!(" {}\n{}", body, intro); 77 78 ( 79 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], 80 banner, 81 ) 82} 83 84#[tokio::main] 85async fn main() -> Result<(), Box<dyn std::error::Error>> { 86 setup_tracing(); 87 //TODO prod 88 dotenvy::from_path(Path::new("./pds.env"))?; 89 let pds_root = env::var("PDS_DATA_DIRECTORY")?; 90 // let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data"; 91 let account_db_url = format!("{}/account.sqlite", pds_root); 92 log::info!("accounts_db_url: {}", account_db_url); 93 94 let account_options = SqliteConnectOptions::new() 95 .journal_mode(SqliteJournalMode::Wal) 96 .filename(account_db_url); 97 98 let account_pool = SqlitePoolOptions::new() 99 .max_connections(5) 100 .connect_with(account_options) 101 .await?; 102 103 let bells_db_url = format!("{}/pds_gatekeeper.sqlite", pds_root); 104 let options = SqliteConnectOptions::new() 105 .journal_mode(SqliteJournalMode::Wal) 106 .filename(bells_db_url) 107 .create_if_missing(true); 108 let pds_gatekeeper_pool = SqlitePoolOptions::new() 109 .max_connections(5) 110 .connect_with(options) 111 .await?; 112 113 // Run migrations for the bells_and_whistles database 114 // Note: the migrations are embedded at compile time from the given directory 115 // sqlx 116 sqlx::migrate!("./migrations") 117 .run(&pds_gatekeeper_pool) 118 .await?; 119 120 let client: HyperUtilClient = 
121 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) 122 .build(HttpConnector::new()); 123 124 //Emailer set up 125 let smtp_url = 126 env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file"); 127 let sent_from = env::var("PDS_EMAIL_FROM_ADDRESS") 128 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file"); 129 let mailer: AsyncSmtpTransport<Tokio1Executor> = 130 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build(); 131 //Email templates setup 132 let mut hbs = Handlebars::new(); 133 let _ = hbs.register_embed_templates::<EmailTemplates>(); 134 135 let state = AppState { 136 account_pool, 137 pds_gatekeeper_pool, 138 reverse_proxy_client: client, 139 //TODO should be env prob 140 pds_base_url: "http://localhost:3000".to_string(), 141 mailer, 142 mailer_from: sent_from, 143 template_engine: Engine::from(hbs), 144 }; 145 146 // Rate limiting 147 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 
148 let governor_conf = GovernorConfigBuilder::default() 149 .per_second(60) 150 .burst_size(5) 151 .finish() 152 .unwrap(); 153 let governor_limiter = governor_conf.limiter().clone(); 154 let interval = Duration::from_secs(60); 155 // a separate background task to clean up 156 std::thread::spawn(move || { 157 loop { 158 std::thread::sleep(interval); 159 tracing::info!("rate limiting storage size: {}", governor_limiter.len()); 160 governor_limiter.retain_recent(); 161 } 162 }); 163 164 let cors = CorsLayer::new() 165 .allow_origin(Any) 166 .allow_methods([Method::GET, Method::OPTIONS, Method::POST]) 167 .allow_headers(Any); 168 169 let app = Router::new() 170 .route("/", get(root_handler)) 171 .route( 172 "/xrpc/com.atproto.server.getSession", 173 get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)), 174 ) 175 .route( 176 "/xrpc/com.atproto.server.updateEmail", 177 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 178 ) 179 .route( 180 "/xrpc/com.atproto.server.createSession", 181 post(create_session.layer(GovernorLayer::new(governor_conf))), 182 ) 183 .layer(CompressionLayer::new()) 184 .layer(cors) 185 .with_state(state); 186 187 let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 188 let port: u16 = env::var("PORT") 189 .ok() 190 .and_then(|s| s.parse().ok()) 191 .unwrap_or(8080); 192 let addr: SocketAddr = format!("{host}:{port}") 193 .parse() 194 .expect("valid socket address"); 195 196 let listener = tokio::net::TcpListener::bind(addr).await?; 197 198 let server = axum::serve( 199 listener, 200 app.into_make_service_with_connect_info::<SocketAddr>(), 201 ) 202 .with_graceful_shutdown(shutdown_signal()); 203 204 if let Err(err) = server.await { 205 error!(error = %err, "server error"); 206 } 207 208 Ok(()) 209} 210 211fn setup_tracing() { 212 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); 213 tracing_subscriber::registry() 214 .with(env_filter) 215 
.with(fmt::layer()) 216 .init(); 217} 218 219async fn shutdown_signal() { 220 // Wait for Ctrl+C 221 let ctrl_c = async { 222 tokio::signal::ctrl_c() 223 .await 224 .expect("failed to install Ctrl+C handler"); 225 }; 226 227 #[cfg(unix)] 228 let terminate = async { 229 use tokio::signal::unix::{SignalKind, signal}; 230 231 let mut sigterm = 232 signal(SignalKind::terminate()).expect("failed to install signal handler"); 233 sigterm.recv().await; 234 }; 235 236 #[cfg(not(unix))] 237 let terminate = std::future::pending::<()>(); 238 239 tokio::select! { 240 _ = ctrl_c => {}, 241 _ = terminate => {}, 242 } 243}