Monorepo for Tangled
tangled.org
1use once_cell::sync::Lazy;
2use prost::Message as ProstMessage;
3use prost_reflect::DescriptorPool;
4use std::io;
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6
7pub mod v1 {
8 include!("gen/spindle/agent/v1/spindle.agent.v1.rs");
9}
10
11pub use v1::Message;
12
13pub static DESCRIPTOR_POOL: Lazy<DescriptorPool> = Lazy::new(|| {
14 let bytes = include_bytes!("gen/file_descriptor_set.bin");
15 DescriptorPool::decode(&bytes[..]).unwrap()
16});
17
18macro_rules! impl_reflect {
19 ($($t:ident),* $(,)?) => {
20 $(
21 impl prost_reflect::ReflectMessage for v1::$t {
22 fn descriptor(&self) -> prost_reflect::MessageDescriptor {
23 DESCRIPTOR_POOL
24 .get_message_by_name(concat!("spindle.agent.v1.", stringify!($t)))
25 .unwrap()
26 }
27 }
28 )*
29 };
30}
31
32impl_reflect!(
33 Hello,
34 Init,
35 ExecStart,
36 ExecStdout,
37 ExecStderr,
38 ExecExit,
39 ActivateConfig,
40 ActivateConfigResult,
41 BuiltPaths,
42 CacheDrain,
43 CacheDrainResult,
44 Poweroff,
45 PoweroffResult,
46 OpenDebugShell,
47 PtyData,
48 PtyResize,
49 Message,
50);
51
/// Agent protocol version number. Not referenced in this module; presumably
/// exchanged between host and agent — confirm against callers.
pub const PROTOCOL_VERSION: u32 = 1;

/// Default port number for the agent connection (presumably where the agent
/// listens — confirm against callers).
pub const DEFAULT_PORT: u32 = 10_240;

/// Largest permitted encoded message body, in bytes (1 MiB). Both
/// `write_message` and `read_message` reject frames larger than this.
pub const MAX_MESSAGE_BYTES: usize = 1 << 20;
55
/// Evaluates the arm for the first populated payload field of a message.
///
/// Each `field => body` pair is tried in the order written: if `$msg.field`
/// is `Some`, its contents are bound to a variable named `field` and the
/// macro yields `Some(body)`. If none of the listed fields is set, it yields
/// `None`.
///
/// The `ref` form binds by reference (`Some(ref field)`), leaving the message
/// untouched; the plain form binds by value (`Some(field)`).
#[macro_export]
macro_rules! on_payload {
    (ref $msg:expr, { $( $field:ident => $body:expr ),* $(,)? }) => {
        // Arms frequently ignore the binding (e.g. they return a constant),
        // so silence the unused-variable lint for the generated chain.
        #[allow(unused_variables)]
        $(if let Some(ref $field) = $msg.$field { Some($body) } else)* { None }
    };
    ($msg:expr, { $( $field:ident => $body:expr ),* $(,)? }) => {
        $(if let Some($field) = $msg.$field { Some($body) } else )* { None }
    };
}
66
67pub fn kind(msg: &Message) -> &'static str {
68 // todo(dawn): maybe eventually we should have a custom protoc plugin for
69 // generating an enum, right now not worth it, when we have more needs for
70 // it imo we can consider it again
71 on_payload!(ref msg, {
72 hello => "hello",
73 init => "init",
74 exec_start => "exec_start",
75 exec_stdout => "exec_stdout",
76 exec_stderr => "exec_stderr",
77 exec_exit => "exec_exit",
78 activate_config => "activate_config",
79 activate_config_result => "activate_config_result",
80 built_paths => "built_paths",
81 cache_drain => "cache_drain",
82 cache_drain_result => "cache_drain_result",
83 poweroff => "poweroff",
84 poweroff_result => "poweroff_result",
85 open_debug_shell => "open_debug_shell",
86 pty_data => "pty_data",
87 pty_resize => "pty_resize",
88 })
89 .unwrap_or_else(|| unreachable!("validated message has no payload"))
90}
91
/// Returns the contained error text, or an empty string when there is none.
///
/// `None` and `Some("")` both produce `""`; any other string is returned
/// unchanged.
pub fn error_or_empty(error: Option<String>) -> String {
    // An empty `Some` already equals the default, so filtering it out first
    // would be a no-op.
    error.unwrap_or_default()
}
95
96pub async fn write_message<W: AsyncWrite + Unpin>(writer: &mut W, msg: &Message) -> io::Result<()> {
97 if let Err(err) = prost_protovalidate::validate(msg) {
98 return Err(io::Error::new(
99 io::ErrorKind::InvalidData,
100 format!("validate agent message: {err}"),
101 ));
102 }
103
104 let mut data = Vec::new();
105 msg.encode(&mut data)
106 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
107 if data.len() > MAX_MESSAGE_BYTES {
108 return Err(io::Error::new(
109 io::ErrorKind::InvalidData,
110 format!("agent message exceeded {MAX_MESSAGE_BYTES} bytes"),
111 ));
112 }
113
114 writer.write_all(&(data.len() as u32).to_be_bytes()).await?;
115 writer.write_all(&data).await?;
116 writer.flush().await
117}
118
119pub async fn read_message<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Option<Message>> {
120 let Some(header) = read_header(reader).await? else {
121 return Ok(None);
122 };
123 let size = u32::from_be_bytes(header) as usize;
124 if size > MAX_MESSAGE_BYTES {
125 return Err(io::Error::new(
126 io::ErrorKind::InvalidData,
127 format!("agent message exceeded {MAX_MESSAGE_BYTES} bytes"),
128 ));
129 }
130
131 let mut data = vec![0; size];
132 reader.read_exact(&mut data).await?;
133 let msg = Message::decode(&data[..])
134 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
135
136 if let Err(err) = prost_protovalidate::validate(&msg) {
137 return Err(io::Error::new(
138 io::ErrorKind::InvalidData,
139 format!("validate agent message: {err}"),
140 ));
141 }
142
143 Ok(Some(msg))
144}
145
146async fn read_header<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Option<[u8; 4]>> {
147 let mut header = [0; 4];
148 let mut read = 0;
149 while read < header.len() {
150 match reader.read(&mut header[read..]).await {
151 Ok(0) if read == 0 => return Ok(None),
152 Ok(0) => {
153 return Err(io::Error::new(
154 io::ErrorKind::UnexpectedEof,
155 "partial agent message header",
156 ));
157 }
158 Ok(n) => read += n,
159 Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
160 Err(err) => return Err(err),
161 }
162 }
163 Ok(Some(header))
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn round_trips_protobuf_message() {
        let msg = Message {
            id: "built-paths".to_owned(),
            built_paths: Some(v1::BuiltPaths {
                paths: vec!["/nix/store/abc-package".to_owned()],
                reason: "post_build_hook".to_owned(),
            }),
            ..Default::default()
        };

        let mut encoded: Vec<u8> = Vec::new();
        write_message(&mut encoded, &msg).await.unwrap();

        let decoded = read_message(&mut &encoded[..]).await.unwrap().unwrap();
        assert_eq!(decoded.id, "built-paths");
        assert_eq!(kind(&decoded), "built_paths");
        let paths = decoded
            .built_paths
            .expect("built_paths payload survives the round trip");
        assert_eq!(paths.paths, ["/nix/store/abc-package"]);
        assert_eq!(paths.reason, "post_build_hook");
    }

    #[tokio::test]
    async fn clean_eof_yields_none() {
        let mut reader: &[u8] = &[];
        assert!(read_message(&mut reader).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn partial_header_is_unexpected_eof() {
        let mut reader: &[u8] = &[0, 0];
        let err = read_message(&mut reader).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn oversized_frame_is_rejected() {
        let prefix = u32::try_from(MAX_MESSAGE_BYTES + 1).unwrap().to_be_bytes();
        let mut reader: &[u8] = &prefix;
        let err = read_message(&mut reader).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn write_rejects_message_without_payload() {
        let msg = Message {
            id: "empty".to_owned(),
            ..Default::default()
        };
        let mut out: Vec<u8> = Vec::new();
        let err = write_message(&mut out, &msg).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        // Nothing may reach the wire for a rejected message.
        assert!(out.is_empty());
    }

    #[test]
    fn error_or_empty_normalizes_missing_errors() {
        assert_eq!(error_or_empty(None), "");
        assert_eq!(error_or_empty(Some(String::new())), "");
        assert_eq!(error_or_empty(Some("boom".to_owned())), "boom");
    }

    #[test]
    fn validates_messages() {
        // 1. valid message (exactly one field set)
        let valid = Message {
            id: "test-1".to_owned(),
            hello: Some(v1::Hello::default()),
            ..Default::default()
        };
        assert!(prost_protovalidate::validate(&valid).is_ok());

        // 2. invalid message (zero fields set)
        let invalid_zero = Message {
            id: "test-2".to_owned(),
            ..Default::default()
        };
        assert!(prost_protovalidate::validate(&invalid_zero).is_err());

        // 3. invalid message (multiple fields set)
        let invalid_multi = Message {
            id: "test-3".to_owned(),
            hello: Some(v1::Hello::default()),
            init: Some(v1::Init::default()),
            ..Default::default()
        };
        assert!(prost_protovalidate::validate(&invalid_multi).is_err());
    }
}