Microservice that brings two-factor authentication (2FA) to self-hosted PDSes
0

Configure Feed

Select the types of activity you want to include in your feed.

at main 18 kB View raw
1#![warn(clippy::unwrap_used)] 2use crate::gate::{get_gate, post_gate}; 3use crate::mailer::{Mailer, build_mailer_from_env}; 4use crate::oauth_provider::sign_in; 5use crate::xrpc::com_atproto_server::{ 6 create_account, create_session, describe_server, get_session, update_email, 7}; 8use anyhow::Result; 9use axum::{ 10 Router, 11 body::Body, 12 handler::Handler, 13 http::{Method, header}, 14 middleware as ax_middleware, 15 routing::get, 16 routing::post, 17}; 18use axum_template::engine::Engine; 19use handlebars::Handlebars; 20use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; 21use jacquard_common::types::did::Did; 22use jacquard_identity::{PublicResolver, resolver::PlcSource}; 23use rand::Rng; 24use rust_embed::RustEmbed; 25use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; 26use sqlx::{SqlitePool, sqlite::SqlitePoolOptions}; 27use std::path::Path; 28use std::sync::Arc; 29use std::time::Duration; 30use std::{env, net::SocketAddr}; 31use tower_governor::{ 32 GovernorLayer, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, 33}; 34use tower_http::cors::AllowHeaders; 35use tower_http::trace::{DefaultOnRequest, HttpMakeClassifier}; 36use tower_http::{ 37 compression::CompressionLayer, 38 cors::{Any, CorsLayer}, 39 trace::TraceLayer, 40}; 41use tracing::{Span, log}; 42use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 43 44mod auth; 45mod gate; 46pub mod helpers; 47pub mod mailer; 48mod middleware; 49mod oauth_provider; 50mod xrpc; 51 52type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; 53 54#[derive(RustEmbed)] 55#[folder = "email_templates"] 56#[include = "*.hbs"] 57struct EmailTemplates; 58 59#[derive(RustEmbed)] 60#[folder = "html_templates"] 61#[include = "*.hbs"] 62struct HtmlTemplates; 63 64/// Mostly the env variables that are used in the app 65#[derive(Clone, Debug)] 66pub struct AppConfig { 67 pds_base_url: String, 68 mailer_from: String, 69 email_subject: String, 70 
allow_only_migrations: bool, 71 use_captcha: bool, 72 //The url to redirect to after a successful captcha. Defaults to https://bsky.app, but you may have another social-app fork you rather your users use 73 //that need to capture this redirect url for creating an account 74 default_successful_redirect_url: String, 75 pds_service_did: Did<'static>, 76 gate_jwe_key: Vec<u8>, 77 captcha_success_redirects: Vec<String>, 78} 79 80impl AppConfig { 81 pub fn new() -> Self { 82 let pds_base_url = 83 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 84 let mailer_from = env::var("PDS_EMAIL_FROM_ADDRESS") 85 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file"); 86 //Hack not my favorite, but it does work 87 let allow_only_migrations = env::var("GATEKEEPER_ALLOW_ONLY_MIGRATIONS") 88 .map(|val| val.parse::<bool>().unwrap_or(false)) 89 .unwrap_or(false); 90 91 let use_captcha = env::var("GATEKEEPER_CREATE_ACCOUNT_CAPTCHA") 92 .map(|val| val.parse::<bool>().unwrap_or(false)) 93 .unwrap_or(false); 94 95 // PDS_SERVICE_DID is the did:web if set, if not it's PDS_HOSTNAME 96 let pds_service_did = 97 env::var("PDS_SERVICE_DID").unwrap_or_else(|_| match env::var("PDS_HOSTNAME") { 98 Ok(pds_hostname) => format!("did:web:{}", pds_hostname), 99 Err(_) => { 100 panic!("PDS_HOSTNAME or PDS_SERVICE_DID must be set in your pds.env file") 101 } 102 }); 103 104 let email_subject = env::var("GATEKEEPER_TWO_FACTOR_EMAIL_SUBJECT") 105 .unwrap_or("Sign in to Bluesky".to_string()); 106 107 // Load or generate JWE encryption key (32 bytes for AES-256) 108 let gate_jwe_key = env::var("GATEKEEPER_JWE_KEY") 109 .ok() 110 .and_then(|key_hex| hex::decode(key_hex).ok()) 111 .unwrap_or_else(|| { 112 // Generate a random 32-byte key if not provided 113 let key: Vec<u8> = (0..32).map(|_| rand::rng().random()).collect(); 114 log::warn!("WARNING: No GATEKEEPER_JWE_KEY found in the environment. 
Generated random key (hex): {}", hex::encode(&key)); 115 log::warn!("This is not strictly needed unless you scale PDS Gatekeeper. Will not also be able to verify tokens between reboots, but they are short lived (5mins)."); 116 key 117 }); 118 119 if gate_jwe_key.len() != 32 { 120 panic!( 121 "GATEKEEPER_JWE_KEY must be 32 bytes (64 hex characters) for AES-256 encryption" 122 ); 123 } 124 125 let captcha_success_redirects = match env::var("GATEKEEPER_CAPTCHA_SUCCESS_REDIRECTS") { 126 Ok(from_env) => from_env.split(",").map(|s| s.trim().to_string()).collect(), 127 Err(_) => { 128 vec![ 129 String::from("https://bsky.app"), 130 String::from("https://pdsmoover.com"), 131 String::from("https://blacksky.community"), 132 String::from("https://tektite.cc"), 133 ] 134 } 135 }; 136 137 AppConfig { 138 pds_base_url, 139 mailer_from, 140 email_subject, 141 allow_only_migrations, 142 use_captcha, 143 default_successful_redirect_url: env::var("GATEKEEPER_DEFAULT_CAPTCHA_REDIRECT") 144 .unwrap_or("https://bsky.app".to_string()), 145 pds_service_did: pds_service_did 146 .parse() 147 .expect("PDS_SERVICE_DID is not a valid did or could not infer from PDS_HOSTNAME"), 148 gate_jwe_key, 149 captcha_success_redirects, 150 } 151 } 152} 153 154#[derive(Clone)] 155pub struct AppState { 156 account_pool: SqlitePool, 157 pds_gatekeeper_pool: SqlitePool, 158 reverse_proxy_client: HyperUtilClient, 159 mailer: Arc<Mailer>, 160 template_engine: Engine<Handlebars<'static>>, 161 resolver: Arc<PublicResolver>, 162 handle_cache: auth::HandleCache, 163 app_config: AppConfig, 164} 165 166async fn root_handler() -> impl axum::response::IntoResponse { 167 let body = r" 168 169 ...oO _.--X~~OO~~X--._ ...oOO 170 _.-~ / \ II / \ ~-._ 171 [].-~ \ / \||/ \ / ~-.[] ...o 172 ...o _ ||/ \ / || \ / \|| _ 173 (_) |X X || X X| (_) 174 _-~-_ ||\ / \ || / \ /|| _-~-_ 175 ||||| || \ / \ /||\ / \ / || ||||| 176 | |_|| \ / \ / || \ / \ / ||_| | 177 | |~|| X X || X X ||~| | 178==============| | || / \ / \ || / \ / \ || 
| |============== 179______________| | || / \ / \||/ \ / \ || | |______________ 180 . . | | ||/ \ / || \ / \|| | | . . 181 / | | |X X || X X| | | / / 182 / . | | ||\ / \ || / \ /|| | | . / . 183. / | | || \ / \ /||\ / \ / || | | . . 184 . . | | || \ / \ / || \ / \ / || | | . 185 / | | || X X || X X || | | . / . / 186 / . | | || / \ / \ || / \ / \ || | | / 187 / | | || / \ / \||/ \ / \ || | | . / 188. . . | | ||/ \ / /||\ \ / \|| | | /. . 189 | |_|X X / II \ X X|_| | . . / 190==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |============== 191 "; 192 193 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 194 195 let banner = format!(" {body}\n{intro}"); 196 197 ( 198 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], 199 banner, 200 ) 201} 202 203#[tokio::main] 204async fn main() -> Result<(), Box<dyn std::error::Error>> { 205 let pds_env_location = 206 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); 207 208 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location)); 209 if let Err(e) = result_of_finding_pds_env { 210 log::error!( 211 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}" 212 ); 213 } 214 // Sets up after the pds.env file is loaded 215 setup_tracing(); 216 217 let pds_root = 218 env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file"); 219 let account_db_url = format!("{pds_root}/account.sqlite"); 220 221 let account_options = SqliteConnectOptions::new() 222 .journal_mode(SqliteJournalMode::Wal) 223 .filename(account_db_url) 224 .busy_timeout(Duration::from_secs(5)); 225 226 let account_pool = SqlitePoolOptions::new() 227 .max_connections(5) 228 .connect_with(account_options) 229 .await?; 230 231 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); 232 let options = SqliteConnectOptions::new() 233 .journal_mode(SqliteJournalMode::Wal) 234 
.filename(bells_db_url) 235 .create_if_missing(true) 236 .busy_timeout(Duration::from_secs(5)); 237 let pds_gatekeeper_pool = SqlitePoolOptions::new() 238 .max_connections(5) 239 .connect_with(options) 240 .await?; 241 242 // Run migrations for the extra database 243 // Note: the migrations are embedded at compile time from the given directory 244 // sqlx 245 sqlx::migrate!("./migrations") 246 .run(&pds_gatekeeper_pool) 247 .await?; 248 249 let client: HyperUtilClient = 250 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) 251 .build(HttpConnector::new()); 252 253 //Emailer set up 254 let mailer = Arc::new(build_mailer_from_env()?); 255 256 //Email templates setup 257 let mut hbs = Handlebars::new(); 258 259 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); 260 if let Ok(users_email_directory) = users_email_directory { 261 hbs.register_template_file( 262 "two_factor_code.hbs", 263 format!("{users_email_directory}/two_factor_code.hbs"), 264 )?; 265 } else { 266 let _ = hbs.register_embed_templates::<EmailTemplates>(); 267 } 268 269 let _ = hbs.register_embed_templates::<HtmlTemplates>(); 270 271 //Reads the PLC source from the pds env's or defaults to ol faithful 272 let plc_source_url = 273 env::var("PDS_DID_PLC_URL").unwrap_or_else(|_| "https://plc.directory".to_string()); 274 let plc_source = PlcSource::PlcDirectory { 275 base: plc_source_url.parse().unwrap(), 276 }; 277 let mut resolver = PublicResolver::default(); 278 resolver = resolver.with_plc_source(plc_source.clone()); 279 280 let state = AppState { 281 account_pool, 282 pds_gatekeeper_pool, 283 reverse_proxy_client: client, 284 mailer, 285 template_engine: Engine::from(hbs), 286 resolver: Arc::new(resolver), 287 handle_cache: auth::HandleCache::new(), 288 app_config: AppConfig::new(), 289 }; 290 291 // Rate limiting 292 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 
293 let captcha_governor_conf = GovernorConfigBuilder::default() 294 .per_second(60) 295 .burst_size(5) 296 .key_extractor(SmartIpKeyExtractor) 297 .finish() 298 .expect("failed to create governor config for create session. this should not happen and is a bug"); 299 300 // Create a second config with the same settings for the other endpoint 301 let sign_in_governor_conf = GovernorConfigBuilder::default() 302 .per_second(60) 303 .burst_size(5) 304 .key_extractor(SmartIpKeyExtractor) 305 .finish() 306 .expect( 307 "failed to create governor config for sign in. this should not happen and is a bug", 308 ); 309 310 let create_account_limiter_time: Option<String> = 311 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); 312 let create_account_limiter_burst: Option<String> = 313 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); 314 315 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally 316 let mut create_account_governor_conf = GovernorConfigBuilder::default(); 317 if create_account_limiter_time.is_some() { 318 let time = create_account_limiter_time 319 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") 320 .parse::<u64>() 321 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); 322 create_account_governor_conf.per_second(time); 323 } 324 325 if create_account_limiter_burst.is_some() { 326 let burst = create_account_limiter_burst 327 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") 328 .parse::<u32>() 329 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); 330 create_account_governor_conf.burst_size(burst); 331 } 332 333 let create_account_governor_conf = create_account_governor_conf 334 .key_extractor(SmartIpKeyExtractor) 335 .finish().expect( 336 "failed to create governor config for create account. 
this should not happen and is a bug", 337 ); 338 339 let captcha_governor_limiter = captcha_governor_conf.limiter().clone(); 340 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 341 let create_account_governor_limiter = create_account_governor_conf.limiter().clone(); 342 343 let sign_in_governor_layer = GovernorLayer::new(sign_in_governor_conf); 344 345 let interval = Duration::from_secs(60); 346 // a separate background task to clean up 347 std::thread::spawn(move || { 348 loop { 349 std::thread::sleep(interval); 350 captcha_governor_limiter.retain_recent(); 351 sign_in_governor_limiter.retain_recent(); 352 create_account_governor_limiter.retain_recent(); 353 } 354 }); 355 356 let cors = CorsLayer::new() 357 .allow_origin(Any) 358 .allow_methods([Method::GET, Method::OPTIONS, Method::POST]) 359 .allow_headers(AllowHeaders::mirror_request()); 360 361 let mut app = Router::new() 362 .route("/", get(root_handler)) 363 .route("/xrpc/com.atproto.server.getSession", get(get_session)) 364 .route( 365 "/xrpc/com.atproto.server.describeServer", 366 get(describe_server), 367 ) 368 .route( 369 "/xrpc/com.atproto.server.updateEmail", 370 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 371 ) 372 .route( 373 "/@atproto/oauth-provider/~api/sign-in", 374 post(sign_in).layer(sign_in_governor_layer.clone()), 375 ) 376 .route( 377 "/xrpc/com.atproto.server.createSession", 378 post(create_session.layer(sign_in_governor_layer)), 379 ) 380 .route( 381 "/xrpc/com.atproto.server.createAccount", 382 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)), 383 ); 384 385 if state.app_config.use_captcha { 386 app = app.route( 387 "/gate/signup", 388 get(get_gate).post(post_gate.layer(GovernorLayer::new(captcha_governor_conf))), 389 ); 390 } 391 392 let request_logging = env::var("GATEKEEPER_REQUEST_LOGGING") 393 .map(|v| v.eq_ignore_ascii_case("true") || v == "1") 394 .unwrap_or(false); 395 396 if request_logging { 397 
app = app.layer(request_trace_layer()); 398 } 399 400 let app = app 401 .layer(CompressionLayer::new()) 402 .layer(cors) 403 .with_state(state); 404 405 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()); 406 let port: u16 = env::var("GATEKEEPER_PORT") 407 .ok() 408 .and_then(|s| s.parse().ok()) 409 .unwrap_or(8080); 410 let addr: SocketAddr = format!("{host}:{port}") 411 .parse() 412 .expect("valid socket address"); 413 414 let listener = tokio::net::TcpListener::bind(addr).await?; 415 416 let server = axum::serve( 417 listener, 418 app.into_make_service_with_connect_info::<SocketAddr>(), 419 ) 420 .with_graceful_shutdown(shutdown_signal()); 421 422 if let Err(err) = server.await { 423 log::error!("server error:{err}"); 424 } 425 426 Ok(()) 427} 428 429fn setup_tracing() { 430 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); 431 let json = env::var("GATEKEEPER_LOG_FORMAT") 432 .map(|v| v.eq_ignore_ascii_case("json")) 433 .unwrap_or(false); 434 435 if json { 436 tracing_subscriber::registry() 437 .with(env_filter) 438 .with(fmt::layer().json()) 439 .init(); 440 } else { 441 tracing_subscriber::registry() 442 .with(env_filter) 443 .with(fmt::layer()) 444 .init(); 445 } 446} 447 448async fn shutdown_signal() { 449 // Wait for Ctrl+C 450 let ctrl_c = async { 451 tokio::signal::ctrl_c() 452 .await 453 .expect("failed to install Ctrl+C handler"); 454 }; 455 456 #[cfg(unix)] 457 let terminate = async { 458 use tokio::signal::unix::{SignalKind, signal}; 459 460 let mut sigterm = 461 signal(SignalKind::terminate()).expect("failed to install signal handler"); 462 sigterm.recv().await; 463 }; 464 465 #[cfg(not(unix))] 466 let terminate = std::future::pending::<()>(); 467 468 tokio::select! 
{ 469 _ = ctrl_c => {}, 470 _ = terminate => {}, 471 } 472} 473 474fn request_trace_layer() -> TraceLayer< 475 HttpMakeClassifier, 476 impl Fn(&axum::http::Request<Body>) -> Span + Clone, 477 DefaultOnRequest, 478 impl Fn(&axum::http::Response<Body>, Duration, &Span) + Clone, 479> { 480 TraceLayer::new_for_http() 481 .make_span_with(|req: &axum::http::Request<Body>| { 482 let headers = req.headers(); 483 tracing::info_span!("request", 484 method = %req.method(), 485 path = %req.uri().path(), 486 headers = %format!("{:?}", headers), 487 ) 488 }) 489 .on_response( 490 |resp: &axum::http::Response<Body>, latency: Duration, _span: &tracing::Span| { 491 tracing::info!( 492 status = resp.status().as_u16(), 493 latency_ms = latency.as_millis() as u64, 494 "response" 495 ); 496 }, 497 ) 498}