Monorepo for Tangled
tangled.org
1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::sync::Mutex;
4use std::time::Duration;
5
6use bobbin_runtime::Clock;
7use bobbin_types::knot_acl::KnotHostKey;
8use tokio::time::Instant;
9
10use crate::client::{KnotClient, knot_endpoint};
11
/// Capability name looked for in a knot's advertised capability list.
const KNOT_ACL_CAPABILITY: &str = "knot-acl";
/// Minimum wait before re-probing a knot that answered but did not list the capability.
const LEGACY_REPROBE_INTERVAL: Duration = Duration::from_secs(5 * 60);
/// Minimum wait before re-probing a knot whose capability probe failed.
const ERROR_REPROBE_INTERVAL: Duration = Duration::from_secs(60);
15
16struct ProbeRecord {
17 at: Instant,
18 retry_after: Duration,
19}
20
21pub struct CapabilityGate {
22 client: KnotClient,
23 clock: Arc<dyn Clock>,
24 dev: bool,
25 allow_private: bool,
26 native: Mutex<HashSet<KnotHostKey>>,
27 last_probe: Mutex<HashMap<KnotHostKey, ProbeRecord>>,
28}
29
30impl CapabilityGate {
31 pub fn new(client: KnotClient, clock: Arc<dyn Clock>, dev: bool, allow_private: bool) -> Self {
32 Self {
33 client,
34 clock,
35 dev,
36 allow_private,
37 native: Mutex::new(HashSet::new()),
38 last_probe: Mutex::new(HashMap::new()),
39 }
40 }
41
42 pub fn is_native(&self, host: &KnotHostKey) -> bool {
43 self.native.lock().unwrap().contains(host)
44 }
45
46 pub async fn has_knot_acl(&self, host: &KnotHostKey) -> bool {
47 if self.is_native(host) {
48 return true;
49 }
50 let now = self.clock.now_instant();
51 if self.throttled(host, now) {
52 return false;
53 }
54 match self.probe(host).await {
55 Ok(true) => {
56 self.native.lock().unwrap().insert(host.clone());
57 true
58 }
59 Ok(false) => {
60 self.mark(host, now, LEGACY_REPROBE_INTERVAL);
61 false
62 }
63 Err(err) => {
64 tracing::warn!(host = %host, error = %err, "knot capability probe failed");
65 self.mark(host, now, ERROR_REPROBE_INTERVAL);
66 false
67 }
68 }
69 }
70
71 async fn probe(&self, host: &KnotHostKey) -> Result<bool, crate::client::KnotClientError> {
72 let endpoint = knot_endpoint(host.as_str(), self.dev, self.allow_private)?;
73 let caps = self.client.capabilities(&endpoint).await?;
74 Ok(caps.iter().any(|cap| cap == KNOT_ACL_CAPABILITY))
75 }
76
77 fn throttled(&self, host: &KnotHostKey, now: Instant) -> bool {
78 self.last_probe
79 .lock()
80 .unwrap()
81 .get(host)
82 .is_some_and(|rec| now.saturating_duration_since(rec.at) < rec.retry_after)
83 }
84
85 fn mark(&self, host: &KnotHostKey, now: Instant, retry_after: Duration) {
86 self.last_probe.lock().unwrap().insert(
87 host.clone(),
88 ProbeRecord {
89 at: now,
90 retry_after,
91 },
92 );
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    use bobbin_runtime::{ReqwestHttp, SleepFuture, UnixMicros};
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// XRPC route that the capability probe requests from the knot.
    const VERSION_ROUTE: &str = "/xrpc/sh.tangled.knot.version";

    /// Deterministic clock: time moves only when a test calls `advance`, and every
    /// sleep resolves immediately.
    struct ManualClock {
        origin: Instant,
        elapsed_us: AtomicU64,
    }

    impl ManualClock {
        fn new() -> Self {
            let origin = Instant::now();
            Self {
                origin,
                elapsed_us: AtomicU64::new(0),
            }
        }

        /// Moves the clock forward by `step`.
        fn advance(&self, step: Duration) {
            let micros = step.as_micros() as u64;
            self.elapsed_us.fetch_add(micros, Ordering::SeqCst);
        }

        /// Microseconds advanced since construction.
        fn offset_us(&self) -> u64 {
            self.elapsed_us.load(Ordering::SeqCst)
        }
    }

    impl Clock for ManualClock {
        fn now_unix_micros(&self) -> UnixMicros {
            UnixMicros::new(self.offset_us())
        }

        fn now_instant(&self) -> Instant {
            self.origin + Duration::from_micros(self.offset_us())
        }

        fn sleep(&self, _: Duration) -> SleepFuture {
            Box::pin(async {})
        }

        fn sleep_until(&self, _: Instant) -> SleepFuture {
            Box::pin(async {})
        }
    }

    /// Builds a gate (dev mode, private addresses allowed) together with the
    /// `host:port` key that addresses `server`.
    fn build_gate(server: &MockServer, clock: Arc<dyn Clock>) -> (CapabilityGate, KnotHostKey) {
        let parsed = url::Url::parse(&server.uri()).unwrap();
        let authority = format!("{}:{}", parsed.host_str().unwrap(), parsed.port().unwrap());
        let http = ReqwestHttp::shared(reqwest::Client::new());
        let gate = CapabilityGate::new(KnotClient::new(http), clock, true, true);
        (gate, KnotHostKey::new(&authority))
    }

    /// Serves a 200 version response listing `caps`, expecting exactly `times` requests.
    async fn mount_version(server: &MockServer, caps: serde_json::Value, times: u64) {
        let body = json!({ "version": "1.0.0 (cafe)", "capabilities": caps });
        Mock::given(method("GET"))
            .and(path(VERSION_ROUTE))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .expect(times)
            .mount(server)
            .await;
    }

    #[tokio::test]
    async fn declares_knot_acl() {
        let server = MockServer::start().await;
        mount_version(&server, json!(["knot-acl"]), 1).await;
        let (gate, host) = build_gate(&server, Arc::new(ManualClock::new()));

        let supported = gate.has_knot_acl(&host).await;
        assert!(supported);
        assert!(gate.is_native(&host));
    }

    #[tokio::test]
    async fn legacy_knot_without_capability() {
        let server = MockServer::start().await;
        mount_version(&server, json!([]), 1).await;
        let (gate, host) = build_gate(&server, Arc::new(ManualClock::new()));

        let supported = gate.has_knot_acl(&host).await;
        assert!(!supported);
        assert!(!gate.is_native(&host));
    }

    #[tokio::test]
    async fn native_is_latched_and_survives_probe_error() {
        let server = MockServer::start().await;
        mount_version(&server, json!(["knot-acl"]), 1).await;
        let clock = Arc::new(ManualClock::new());
        let (gate, host) = build_gate(&server, clock.clone());
        assert!(gate.has_knot_acl(&host).await);

        // With every mock removed, any further probe would fail.
        server.reset().await;
        clock.advance(LEGACY_REPROBE_INTERVAL + Duration::from_secs(1));
        let still_supported = gate.has_knot_acl(&host).await;
        assert!(still_supported, "latched native never re-probes");
        assert!(gate.is_native(&host));
    }

    #[tokio::test]
    async fn legacy_throttled_then_reprobed_after_interval() {
        let server = MockServer::start().await;
        mount_version(&server, json!([]), 2).await;
        let clock = Arc::new(ManualClock::new());
        let (gate, host) = build_gate(&server, clock.clone());

        // First probe.
        assert!(!gate.has_knot_acl(&host).await);
        // Still inside the legacy window: served from the throttle, no request.
        clock.advance(Duration::from_secs(60));
        assert!(!gate.has_knot_acl(&host).await);
        // Past the window: second (and last expected) probe.
        clock.advance(LEGACY_REPROBE_INTERVAL);
        assert!(!gate.has_knot_acl(&host).await);
    }

    #[tokio::test]
    async fn legacy_upgrade_is_detected_on_reprobe() {
        let server = MockServer::start().await;
        let legacy = ResponseTemplate::new(200).set_body_json(json!({ "version": "1.0.0" }));
        let upgraded = ResponseTemplate::new(200)
            .set_body_json(json!({ "version": "1.1.0", "capabilities": ["knot-acl"] }));
        // Mount order matters: the one-shot legacy answer is served first, after which
        // requests fall through to the upgraded answer.
        Mock::given(method("GET"))
            .and(path(VERSION_ROUTE))
            .respond_with(legacy)
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path(VERSION_ROUTE))
            .respond_with(upgraded)
            .mount(&server)
            .await;
        let clock = Arc::new(ManualClock::new());
        let (gate, host) = build_gate(&server, clock.clone());

        assert!(!gate.has_knot_acl(&host).await);
        clock.advance(LEGACY_REPROBE_INTERVAL + Duration::from_secs(1));
        assert!(gate.has_knot_acl(&host).await);
        assert!(gate.is_native(&host));
    }

    #[tokio::test]
    async fn probe_error_throttled_briefly_then_reprobed() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path(VERSION_ROUTE))
            .respond_with(ResponseTemplate::new(503))
            .expect(2)
            .mount(&server)
            .await;
        let clock = Arc::new(ManualClock::new());
        let (gate, host) = build_gate(&server, clock.clone());

        assert!(!gate.has_knot_acl(&host).await);
        clock.advance(Duration::from_secs(1));
        let throttled_answer = gate.has_knot_acl(&host).await;
        assert!(!throttled_answer, "error reprobe is throttled");
        clock.advance(ERROR_REPROBE_INTERVAL);
        assert!(!gate.has_knot_acl(&host).await);
    }
}