Microservice to bring 2FA to self-hosted PDSes
1#![warn(clippy::unwrap_used)]
2use crate::gate::{get_gate, post_gate};
3use crate::mailer::{Mailer, build_mailer_from_env};
4use crate::oauth_provider::sign_in;
5use crate::xrpc::com_atproto_server::{
6 create_account, create_session, describe_server, get_session, update_email,
7};
8use anyhow::Result;
9use axum::{
10 Router,
11 body::Body,
12 handler::Handler,
13 http::{Method, header},
14 middleware as ax_middleware,
15 routing::get,
16 routing::post,
17};
18use axum_template::engine::Engine;
19use handlebars::Handlebars;
20use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
21use jacquard_common::types::did::Did;
22use jacquard_identity::{PublicResolver, resolver::PlcSource};
23use rand::Rng;
24use rust_embed::RustEmbed;
25use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
26use sqlx::{SqlitePool, sqlite::SqlitePoolOptions};
27use std::path::Path;
28use std::sync::Arc;
29use std::time::Duration;
30use std::{env, net::SocketAddr};
31use tower_governor::{
32 GovernorLayer, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor,
33};
34use tower_http::cors::AllowHeaders;
35use tower_http::trace::{DefaultOnRequest, HttpMakeClassifier};
36use tower_http::{
37 compression::CompressionLayer,
38 cors::{Any, CorsLayer},
39 trace::TraceLayer,
40};
41use tracing::{Span, log};
42use tracing_subscriber::{EnvFilter, fmt, prelude::*};
43
44mod auth;
45mod gate;
46pub mod helpers;
47pub mod mailer;
48mod middleware;
49mod oauth_provider;
50mod xrpc;
51
/// Alias for the hyper-util legacy client parameterised with `HttpConnector` and axum's `Body`.
52type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
53
/// Embedded `*.hbs` files from the `email_templates` folder.
54#[derive(RustEmbed)]
55#[folder = "email_templates"]
56#[include = "*.hbs"]
57struct EmailTemplates;
58
/// Embedded `*.hbs` files from the `html_templates` folder.
59#[derive(RustEmbed)]
60#[folder = "html_templates"]
61#[include = "*.hbs"]
62struct HtmlTemplates;
63
64/// Mostly the env variables that are used in the app
65#[derive(Clone, Debug)]
66pub struct AppConfig {
67 pds_base_url: String,
68 mailer_from: String,
69 email_subject: String,
70 allow_only_migrations: bool,
71 use_captcha: bool,
72 //The url to redirect to after a successful captcha. Defaults to https://bsky.app, but you may have another social-app fork you rather your users use
73 //that need to capture this redirect url for creating an account
74 default_successful_redirect_url: String,
75 pds_service_did: Did<'static>,
76 gate_jwe_key: Vec<u8>,
77 captcha_success_redirects: Vec<String>,
78}
79
80impl AppConfig {
81 pub fn new() -> Self {
82 let pds_base_url =
83 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string());
84 let mailer_from = env::var("PDS_EMAIL_FROM_ADDRESS")
85 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file");
86 //Hack not my favorite, but it does work
87 let allow_only_migrations = env::var("GATEKEEPER_ALLOW_ONLY_MIGRATIONS")
88 .map(|val| val.parse::<bool>().unwrap_or(false))
89 .unwrap_or(false);
90
91 let use_captcha = env::var("GATEKEEPER_CREATE_ACCOUNT_CAPTCHA")
92 .map(|val| val.parse::<bool>().unwrap_or(false))
93 .unwrap_or(false);
94
95 // PDS_SERVICE_DID is the did:web if set, if not it's PDS_HOSTNAME
96 let pds_service_did =
97 env::var("PDS_SERVICE_DID").unwrap_or_else(|_| match env::var("PDS_HOSTNAME") {
98 Ok(pds_hostname) => format!("did:web:{}", pds_hostname),
99 Err(_) => {
100 panic!("PDS_HOSTNAME or PDS_SERVICE_DID must be set in your pds.env file")
101 }
102 });
103
104 let email_subject = env::var("GATEKEEPER_TWO_FACTOR_EMAIL_SUBJECT")
105 .unwrap_or("Sign in to Bluesky".to_string());
106
107 // Load or generate JWE encryption key (32 bytes for AES-256)
108 let gate_jwe_key = env::var("GATEKEEPER_JWE_KEY")
109 .ok()
110 .and_then(|key_hex| hex::decode(key_hex).ok())
111 .unwrap_or_else(|| {
112 // Generate a random 32-byte key if not provided
113 let key: Vec<u8> = (0..32).map(|_| rand::rng().random()).collect();
114 log::warn!("WARNING: No GATEKEEPER_JWE_KEY found in the environment. Generated random key (hex): {}", hex::encode(&key));
115 log::warn!("This is not strictly needed unless you scale PDS Gatekeeper. Will not also be able to verify tokens between reboots, but they are short lived (5mins).");
116 key
117 });
118
119 if gate_jwe_key.len() != 32 {
120 panic!(
121 "GATEKEEPER_JWE_KEY must be 32 bytes (64 hex characters) for AES-256 encryption"
122 );
123 }
124
125 let captcha_success_redirects = match env::var("GATEKEEPER_CAPTCHA_SUCCESS_REDIRECTS") {
126 Ok(from_env) => from_env.split(",").map(|s| s.trim().to_string()).collect(),
127 Err(_) => {
128 vec![
129 String::from("https://bsky.app"),
130 String::from("https://pdsmoover.com"),
131 String::from("https://blacksky.community"),
132 String::from("https://tektite.cc"),
133 ]
134 }
135 };
136
137 AppConfig {
138 pds_base_url,
139 mailer_from,
140 email_subject,
141 allow_only_migrations,
142 use_captcha,
143 default_successful_redirect_url: env::var("GATEKEEPER_DEFAULT_CAPTCHA_REDIRECT")
144 .unwrap_or("https://bsky.app".to_string()),
145 pds_service_did: pds_service_did
146 .parse()
147 .expect("PDS_SERVICE_DID is not a valid did or could not infer from PDS_HOSTNAME"),
148 gate_jwe_key,
149 captcha_success_redirects,
150 }
151 }
152}
153
/// State shared with the request handlers; built once in `main` and attached with `with_state`.
154#[derive(Clone)]
155pub struct AppState {
    // Pool opened on `account.sqlite` inside `PDS_DATA_DIRECTORY`.
156 account_pool: SqlitePool,
    // Pool opened on `pds_gatekeeper.sqlite` (created if missing; migrations run against it at startup).
157 pds_gatekeeper_pool: SqlitePool,
    // Hyper client built in `main`; presumably used to forward requests to the PDS - confirm in the handlers.
158 reverse_proxy_client: HyperUtilClient,
    // Mailer produced by `build_mailer_from_env`.
159 mailer: Arc<Mailer>,
    // Handlebars engine holding the registered email and HTML templates.
160 template_engine: Engine<Handlebars<'static>>,
    // Resolver configured with the PLC source taken from `PDS_DID_PLC_URL`.
161 resolver: Arc<PublicResolver>,
    // Handle cache from the `auth` module.
162 handle_cache: auth::HandleCache,
    // Settings derived from the environment.
163 app_config: AppConfig,
164}
165
/// Handler for `/`: responds with a plain-text ASCII-art banner followed by a short
/// description and a link to the source repository.
166async fn root_handler() -> impl axum::response::IntoResponse {
167 let body = r"
168
169 ...oO _.--X~~OO~~X--._ ...oOO
170 _.-~ / \ II / \ ~-._
171 [].-~ \ / \||/ \ / ~-.[] ...o
172 ...o _ ||/ \ / || \ / \|| _
173 (_) |X X || X X| (_)
174 _-~-_ ||\ / \ || / \ /|| _-~-_
175 ||||| || \ / \ /||\ / \ / || |||||
176 | |_|| \ / \ / || \ / \ / ||_| |
177 | |~|| X X || X X ||~| |
178==============| | || / \ / \ || / \ / \ || | |==============
179______________| | || / \ / \||/ \ / \ || | |______________
180 . . | | ||/ \ / || \ / \|| | | . .
181 / | | |X X || X X| | | / /
182 / . | | ||\ / \ || / \ /|| | | . / .
183. / | | || \ / \ /||\ / \ / || | | . .
184 . . | | || \ / \ / || \ / \ / || | | .
185 / | | || X X || X X || | | . / . /
186 / . | | || / \ / \ || / \ / \ || | | /
187 / | | || / \ / \||/ \ / \ || | | . /
188. . . | | ||/ \ / /||\ \ / \|| | | /. .
189 | |_|X X / II \ X X|_| | . . /
190==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |==============
191 ";
192
193 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";
194
    // Art first, then the introduction text.
195 let banner = format!(" {body}\n{intro}");
196
    // Explicit content type so the response is served as UTF-8 plain text.
197 (
198 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
199 banner,
200 )
201}
202
203#[tokio::main]
204async fn main() -> Result<(), Box<dyn std::error::Error>> {
205 let pds_env_location =
206 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string());
207
208 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location));
209 if let Err(e) = result_of_finding_pds_env {
210 log::error!(
211 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}"
212 );
213 }
214 // Sets up after the pds.env file is loaded
215 setup_tracing();
216
217 let pds_root =
218 env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file");
219 let account_db_url = format!("{pds_root}/account.sqlite");
220
221 let account_options = SqliteConnectOptions::new()
222 .journal_mode(SqliteJournalMode::Wal)
223 .filename(account_db_url)
224 .busy_timeout(Duration::from_secs(5));
225
226 let account_pool = SqlitePoolOptions::new()
227 .max_connections(5)
228 .connect_with(account_options)
229 .await?;
230
231 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite");
232 let options = SqliteConnectOptions::new()
233 .journal_mode(SqliteJournalMode::Wal)
234 .filename(bells_db_url)
235 .create_if_missing(true)
236 .busy_timeout(Duration::from_secs(5));
237 let pds_gatekeeper_pool = SqlitePoolOptions::new()
238 .max_connections(5)
239 .connect_with(options)
240 .await?;
241
242 // Run migrations for the extra database
243 // Note: the migrations are embedded at compile time from the given directory
244 // sqlx
245 sqlx::migrate!("./migrations")
246 .run(&pds_gatekeeper_pool)
247 .await?;
248
249 let client: HyperUtilClient =
250 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
251 .build(HttpConnector::new());
252
253 //Emailer set up
254 let mailer = Arc::new(build_mailer_from_env()?);
255
256 //Email templates setup
257 let mut hbs = Handlebars::new();
258
259 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY");
260 if let Ok(users_email_directory) = users_email_directory {
261 hbs.register_template_file(
262 "two_factor_code.hbs",
263 format!("{users_email_directory}/two_factor_code.hbs"),
264 )?;
265 } else {
266 let _ = hbs.register_embed_templates::<EmailTemplates>();
267 }
268
269 let _ = hbs.register_embed_templates::<HtmlTemplates>();
270
271 //Reads the PLC source from the pds env's or defaults to ol faithful
272 let plc_source_url =
273 env::var("PDS_DID_PLC_URL").unwrap_or_else(|_| "https://plc.directory".to_string());
274 let plc_source = PlcSource::PlcDirectory {
275 base: plc_source_url.parse().unwrap(),
276 };
277 let mut resolver = PublicResolver::default();
278 resolver = resolver.with_plc_source(plc_source.clone());
279
280 let state = AppState {
281 account_pool,
282 pds_gatekeeper_pool,
283 reverse_proxy_client: client,
284 mailer,
285 template_engine: Engine::from(hbs),
286 resolver: Arc::new(resolver),
287 handle_cache: auth::HandleCache::new(),
288 app_config: AppConfig::new(),
289 };
290
291 // Rate limiting
292 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
293 let captcha_governor_conf = GovernorConfigBuilder::default()
294 .per_second(60)
295 .burst_size(5)
296 .key_extractor(SmartIpKeyExtractor)
297 .finish()
298 .expect("failed to create governor config for create session. this should not happen and is a bug");
299
300 // Create a second config with the same settings for the other endpoint
301 let sign_in_governor_conf = GovernorConfigBuilder::default()
302 .per_second(60)
303 .burst_size(5)
304 .key_extractor(SmartIpKeyExtractor)
305 .finish()
306 .expect(
307 "failed to create governor config for sign in. this should not happen and is a bug",
308 );
309
310 let create_account_limiter_time: Option<String> =
311 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok();
312 let create_account_limiter_burst: Option<String> =
313 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok();
314
315 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally
316 let mut create_account_governor_conf = GovernorConfigBuilder::default();
317 if create_account_limiter_time.is_some() {
318 let time = create_account_limiter_time
319 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set")
320 .parse::<u64>()
321 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer");
322 create_account_governor_conf.per_second(time);
323 }
324
325 if create_account_limiter_burst.is_some() {
326 let burst = create_account_limiter_burst
327 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set")
328 .parse::<u32>()
329 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer");
330 create_account_governor_conf.burst_size(burst);
331 }
332
333 let create_account_governor_conf = create_account_governor_conf
334 .key_extractor(SmartIpKeyExtractor)
335 .finish().expect(
336 "failed to create governor config for create account. this should not happen and is a bug",
337 );
338
339 let captcha_governor_limiter = captcha_governor_conf.limiter().clone();
340 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
341 let create_account_governor_limiter = create_account_governor_conf.limiter().clone();
342
343 let sign_in_governor_layer = GovernorLayer::new(sign_in_governor_conf);
344
345 let interval = Duration::from_secs(60);
346 // a separate background task to clean up
347 std::thread::spawn(move || {
348 loop {
349 std::thread::sleep(interval);
350 captcha_governor_limiter.retain_recent();
351 sign_in_governor_limiter.retain_recent();
352 create_account_governor_limiter.retain_recent();
353 }
354 });
355
356 let cors = CorsLayer::new()
357 .allow_origin(Any)
358 .allow_methods([Method::GET, Method::OPTIONS, Method::POST])
359 .allow_headers(AllowHeaders::mirror_request());
360
361 let mut app = Router::new()
362 .route("/", get(root_handler))
363 .route("/xrpc/com.atproto.server.getSession", get(get_session))
364 .route(
365 "/xrpc/com.atproto.server.describeServer",
366 get(describe_server),
367 )
368 .route(
369 "/xrpc/com.atproto.server.updateEmail",
370 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
371 )
372 .route(
373 "/@atproto/oauth-provider/~api/sign-in",
374 post(sign_in).layer(sign_in_governor_layer.clone()),
375 )
376 .route(
377 "/xrpc/com.atproto.server.createSession",
378 post(create_session.layer(sign_in_governor_layer)),
379 )
380 .route(
381 "/xrpc/com.atproto.server.createAccount",
382 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)),
383 );
384
385 if state.app_config.use_captcha {
386 app = app.route(
387 "/gate/signup",
388 get(get_gate).post(post_gate.layer(GovernorLayer::new(captcha_governor_conf))),
389 );
390 }
391
392 let request_logging = env::var("GATEKEEPER_REQUEST_LOGGING")
393 .map(|v| v.eq_ignore_ascii_case("true") || v == "1")
394 .unwrap_or(false);
395
396 if request_logging {
397 app = app.layer(request_trace_layer());
398 }
399
400 let app = app
401 .layer(CompressionLayer::new())
402 .layer(cors)
403 .with_state(state);
404
405 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string());
406 let port: u16 = env::var("GATEKEEPER_PORT")
407 .ok()
408 .and_then(|s| s.parse().ok())
409 .unwrap_or(8080);
410 let addr: SocketAddr = format!("{host}:{port}")
411 .parse()
412 .expect("valid socket address");
413
414 let listener = tokio::net::TcpListener::bind(addr).await?;
415
416 let server = axum::serve(
417 listener,
418 app.into_make_service_with_connect_info::<SocketAddr>(),
419 )
420 .with_graceful_shutdown(shutdown_signal());
421
422 if let Err(err) = server.await {
423 log::error!("server error:{err}");
424 }
425
426 Ok(())
427}
428
429fn setup_tracing() {
430 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
431 let json = env::var("GATEKEEPER_LOG_FORMAT")
432 .map(|v| v.eq_ignore_ascii_case("json"))
433 .unwrap_or(false);
434
435 if json {
436 tracing_subscriber::registry()
437 .with(env_filter)
438 .with(fmt::layer().json())
439 .init();
440 } else {
441 tracing_subscriber::registry()
442 .with(env_filter)
443 .with(fmt::layer())
444 .init();
445 }
446}
447
448async fn shutdown_signal() {
449 // Wait for Ctrl+C
450 let ctrl_c = async {
451 tokio::signal::ctrl_c()
452 .await
453 .expect("failed to install Ctrl+C handler");
454 };
455
456 #[cfg(unix)]
457 let terminate = async {
458 use tokio::signal::unix::{SignalKind, signal};
459
460 let mut sigterm =
461 signal(SignalKind::terminate()).expect("failed to install signal handler");
462 sigterm.recv().await;
463 };
464
465 #[cfg(not(unix))]
466 let terminate = std::future::pending::<()>();
467
468 tokio::select! {
469 _ = ctrl_c => {},
470 _ = terminate => {},
471 }
472}
473
474fn request_trace_layer() -> TraceLayer<
475 HttpMakeClassifier,
476 impl Fn(&axum::http::Request<Body>) -> Span + Clone,
477 DefaultOnRequest,
478 impl Fn(&axum::http::Response<Body>, Duration, &Span) + Clone,
479> {
480 TraceLayer::new_for_http()
481 .make_span_with(|req: &axum::http::Request<Body>| {
482 let headers = req.headers();
483 tracing::info_span!("request",
484 method = %req.method(),
485 path = %req.uri().path(),
486 headers = %format!("{:?}", headers),
487 )
488 })
489 .on_response(
490 |resp: &axum::http::Response<Body>, latency: Duration, _span: &tracing::Span| {
491 tracing::info!(
492 status = resp.status().as_u16(),
493 latency_ms = latency.as_millis() as u64,
494 "response"
495 );
496 },
497 )
498}