Microservice to bring 2FA to self-hosted PDSes
1use crate::xrpc::com_atproto_server::{create_session, get_session, update_email};
2use axum::middleware as ax_middleware;
3mod middleware;
4use axum::body::Body;
5use axum::handler::Handler;
6use axum::http::{Method, header};
7use axum::routing::post;
8use axum::{Router, routing::get};
9use axum_template::engine::Engine;
10use handlebars::Handlebars;
11use hyper_util::client::legacy::connect::HttpConnector;
12use hyper_util::rt::TokioExecutor;
13use lettre::{AsyncSmtpTransport, Tokio1Executor};
14use rust_embed::RustEmbed;
15use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
16use sqlx::{SqlitePool, sqlite::SqlitePoolOptions};
17use std::path::Path;
18use std::time::Duration;
19use std::{env, net::SocketAddr};
20use tower_governor::GovernorLayer;
21use tower_governor::governor::GovernorConfigBuilder;
22use tower_http::compression::CompressionLayer;
23use tower_http::cors::{Any, CorsLayer};
24use tracing::{error, log};
25use tracing_subscriber::{EnvFilter, fmt, prelude::*};
26
27mod xrpc;
28
29type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
30
31#[derive(RustEmbed)]
32#[folder = "email_templates"]
33#[include = "*.hbs"]
34struct EmailTemplates;
35
36#[derive(Clone)]
37struct AppState {
38 account_pool: SqlitePool,
39 pds_gatekeeper_pool: SqlitePool,
40 reverse_proxy_client: HyperUtilClient,
41 pds_base_url: String,
42 mailer: AsyncSmtpTransport<Tokio1Executor>,
43 mailer_from: String,
44 template_engine: Engine<Handlebars<'static>>,
45}
46
/// Handler for `GET /`: responds with an ASCII-art banner followed by a short description and
/// a link to the source repository, sent as `text/plain; charset=utf-8`.
// The banner art lives in a raw string literal; do not reformat the lines inside it.
47async fn root_handler() -> impl axum::response::IntoResponse {
48 let body = r"
49
50 ...oO _.--X~~OO~~X--._ ...oOO
51 _.-~ / \ II / \ ~-._
52 [].-~ \ / \||/ \ / ~-.[] ...o
53 ...o _ ||/ \ / || \ / \|| _
54 (_) |X X || X X| (_)
55 _-~-_ ||\ / \ || / \ /|| _-~-_
56 ||||| || \ / \ /||\ / \ / || |||||
57 | |_|| \ / \ / || \ / \ / ||_| |
58 | |~|| X X || X X ||~| |
59==============| | || / \ / \ || / \ / \ || | |==============
60______________| | || / \ / \||/ \ / \ || | |______________
61 . . | | ||/ \ / || \ / \|| | | . .
62 / | | |X X || X X| | | / /
63 / . | | ||\ / \ || / \ /|| | | . / .
64. / | | || \ / \ /||\ / \ / || | | . .
65 . . | | || \ / \ / || \ / \ / || | | .
66 / | | || X X || X X || | | . / . /
67 / . | | || / \ / \ || / \ / \ || | | /
68 / | | || / \ / \||/ \ / \ || | | . /
69. . . | | ||/ \ / /||\ \ / \|| | | /. .
70 | |_|X X / II \ X X|_| | . . /
71==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |==============
72 ";
73
// Text appended below the art: what this service is and where its code lives.
74 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";
75
// Art first, then a newline, then the intro text.
76 let banner = format!(" {}\n{}", body, intro);
77
// A (headers, body) tuple: sets the content type explicitly and returns the banner as the body.
78 (
79 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
80 banner,
81 )
82}
83
84#[tokio::main]
85async fn main() -> Result<(), Box<dyn std::error::Error>> {
86 setup_tracing();
87 //TODO prod
88 dotenvy::from_path(Path::new("./pds.env"))?;
89 let pds_root = env::var("PDS_DATA_DIRECTORY")?;
90 // let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data";
91 let account_db_url = format!("{}/account.sqlite", pds_root);
92 log::info!("accounts_db_url: {}", account_db_url);
93
94 let account_options = SqliteConnectOptions::new()
95 .journal_mode(SqliteJournalMode::Wal)
96 .filename(account_db_url);
97
98 let account_pool = SqlitePoolOptions::new()
99 .max_connections(5)
100 .connect_with(account_options)
101 .await?;
102
103 let bells_db_url = format!("{}/pds_gatekeeper.sqlite", pds_root);
104 let options = SqliteConnectOptions::new()
105 .journal_mode(SqliteJournalMode::Wal)
106 .filename(bells_db_url)
107 .create_if_missing(true);
108 let pds_gatekeeper_pool = SqlitePoolOptions::new()
109 .max_connections(5)
110 .connect_with(options)
111 .await?;
112
113 // Run migrations for the bells_and_whistles database
114 // Note: the migrations are embedded at compile time from the given directory
115 // sqlx
116 sqlx::migrate!("./migrations")
117 .run(&pds_gatekeeper_pool)
118 .await?;
119
120 let client: HyperUtilClient =
121 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
122 .build(HttpConnector::new());
123
124 //Emailer set up
125 let smtp_url =
126 env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file");
127 let sent_from = env::var("PDS_EMAIL_FROM_ADDRESS")
128 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file");
129 let mailer: AsyncSmtpTransport<Tokio1Executor> =
130 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build();
131 //Email templates setup
132 let mut hbs = Handlebars::new();
133 let _ = hbs.register_embed_templates::<EmailTemplates>();
134
135 let state = AppState {
136 account_pool,
137 pds_gatekeeper_pool,
138 reverse_proxy_client: client,
139 //TODO should be env prob
140 pds_base_url: "http://localhost:3000".to_string(),
141 mailer,
142 mailer_from: sent_from,
143 template_engine: Engine::from(hbs),
144 };
145
146 // Rate limiting
147 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
148 let governor_conf = GovernorConfigBuilder::default()
149 .per_second(60)
150 .burst_size(5)
151 .finish()
152 .unwrap();
153 let governor_limiter = governor_conf.limiter().clone();
154 let interval = Duration::from_secs(60);
155 // a separate background task to clean up
156 std::thread::spawn(move || {
157 loop {
158 std::thread::sleep(interval);
159 tracing::info!("rate limiting storage size: {}", governor_limiter.len());
160 governor_limiter.retain_recent();
161 }
162 });
163
164 let cors = CorsLayer::new()
165 .allow_origin(Any)
166 .allow_methods([Method::GET, Method::OPTIONS, Method::POST])
167 .allow_headers(Any);
168
169 let app = Router::new()
170 .route("/", get(root_handler))
171 .route(
172 "/xrpc/com.atproto.server.getSession",
173 get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)),
174 )
175 .route(
176 "/xrpc/com.atproto.server.updateEmail",
177 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
178 )
179 .route(
180 "/xrpc/com.atproto.server.createSession",
181 post(create_session.layer(GovernorLayer::new(governor_conf))),
182 )
183 .layer(CompressionLayer::new())
184 .layer(cors)
185 .with_state(state);
186
187 let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
188 let port: u16 = env::var("PORT")
189 .ok()
190 .and_then(|s| s.parse().ok())
191 .unwrap_or(8080);
192 let addr: SocketAddr = format!("{host}:{port}")
193 .parse()
194 .expect("valid socket address");
195
196 let listener = tokio::net::TcpListener::bind(addr).await?;
197
198 let server = axum::serve(
199 listener,
200 app.into_make_service_with_connect_info::<SocketAddr>(),
201 )
202 .with_graceful_shutdown(shutdown_signal());
203
204 if let Err(err) = server.await {
205 error!(error = %err, "server error");
206 }
207
208 Ok(())
209}
210
211fn setup_tracing() {
212 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
213 tracing_subscriber::registry()
214 .with(env_filter)
215 .with(fmt::layer())
216 .init();
217}
218
219async fn shutdown_signal() {
220 // Wait for Ctrl+C
221 let ctrl_c = async {
222 tokio::signal::ctrl_c()
223 .await
224 .expect("failed to install Ctrl+C handler");
225 };
226
227 #[cfg(unix)]
228 let terminate = async {
229 use tokio::signal::unix::{SignalKind, signal};
230
231 let mut sigterm =
232 signal(SignalKind::terminate()).expect("failed to install signal handler");
233 sigterm.recv().await;
234 };
235
236 #[cfg(not(unix))]
237 let terminate = std::future::pending::<()>();
238
239 tokio::select! {
240 _ = ctrl_c => {},
241 _ = terminate => {},
242 }
243}