Monorepo for Tangled
tangled.org
1use once_cell::sync::Lazy;
2use prost::Message as ProstMessage;
3use prost_reflect::DescriptorPool;
4use std::io;
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6
7pub mod v1 {
8 include!("gen/spindle/agent/v1/spindle.agent.v1.rs");
9}
10
11pub use v1::Message;
12
13pub static DESCRIPTOR_POOL: Lazy<DescriptorPool> = Lazy::new(|| {
14 let bytes = include_bytes!("gen/file_descriptor_set.bin");
15 DescriptorPool::decode(&bytes[..]).unwrap()
16});
17
18macro_rules! impl_reflect {
19 ($($t:ident),* $(,)?) => {
20 $(
21 impl prost_reflect::ReflectMessage for v1::$t {
22 fn descriptor(&self) -> prost_reflect::MessageDescriptor {
23 DESCRIPTOR_POOL
24 .get_message_by_name(concat!("spindle.agent.v1.", stringify!($t)))
25 .unwrap()
26 }
27 }
28 )*
29 };
30}
31
32impl_reflect!(
33 Hello,
34 Init,
35 ExecStart,
36 ExecStdout,
37 ExecStderr,
38 ExecExit,
39 ActivateConfig,
40 ActivateConfigResult,
41 BuiltPaths,
42 CacheDrain,
43 CacheDrainResult,
44 Poweroff,
45 PoweroffResult,
46 Message,
47);
48
/// Wire-protocol version of the spindle agent protocol.
///
/// NOTE(review): how this value is exchanged or checked is not visible in this
/// file. Confirm against the handshake code.
pub const PROTOCOL_VERSION: u32 = 1;

/// Default port for agent connections.
///
/// NOTE(review): typed `u32` rather than `u16`, which suggests a vsock port
/// rather than a TCP/UDP one. Confirm.
pub const DEFAULT_PORT: u32 = 10_240;

/// Largest protobuf body, in bytes (1 MiB), that [`write_message`] will send
/// and [`read_message`] will accept. The 4-byte length prefix is not counted.
pub const MAX_MESSAGE_BYTES: usize = 1 << 20;
52
/// Evaluates the arm for the first `Some` field of `$msg`, yielding
/// `Some(body)`, or `None` if none of the listed fields are set.
///
/// Fields are probed in the order written. Each `$field` name is both the
/// struct field read from `$msg` and the binding visible inside its `$body`.
///
/// - `on_payload!(ref msg, { field => body, .. })` binds each field by
///   reference (`Some(ref field)`), leaving `msg` untouched.
/// - `on_payload!(msg, { field => body, .. })` binds by value, which may move
///   non-`Copy` fields out of `msg`.
///
/// Bodies often ignore their binding (e.g. `hello => "hello"`), so
/// `unused_variables` is allowed. The expansion is a block whose `let`
/// statement carries the lint attribute. Outer attributes directly on an
/// expression are only accepted on stable Rust in statement-like positions,
/// and this macro is routinely used as a method receiver or initializer.
#[macro_export]
macro_rules! on_payload {
    (ref $msg:expr, { $( $field:ident => $body:expr ),* $(,)? }) => {{
        #[allow(unused_variables)]
        let payload = $(if let Some(ref $field) = $msg.$field { Some($body) } else)* { None };
        payload
    }};
    ($msg:expr, { $( $field:ident => $body:expr ),* $(,)? }) => {{
        #[allow(unused_variables)]
        let payload = $(if let Some($field) = $msg.$field { Some($body) } else)* { None };
        payload
    }};
}
63
64pub fn kind(msg: &Message) -> &'static str {
65 // todo(dawn): maybe eventually we should have a custom protoc plugin for
66 // generating an enum, right now not worth it, when we have more needs for
67 // it imo we can consider it again
68 on_payload!(ref msg, {
69 hello => "hello",
70 init => "init",
71 exec_start => "exec_start",
72 exec_stdout => "exec_stdout",
73 exec_stderr => "exec_stderr",
74 exec_exit => "exec_exit",
75 activate_config => "activate_config",
76 activate_config_result => "activate_config_result",
77 built_paths => "built_paths",
78 cache_drain => "cache_drain",
79 cache_drain_result => "cache_drain_result",
80 poweroff => "poweroff",
81 poweroff_result => "poweroff_result",
82 })
83 .unwrap_or_else(|| unreachable!("validated message has no payload"))
84}
85
/// Converts an optional error message into a plain `String`, using `""` when
/// there is no error.
///
/// Both `None` and `Some("")` produce `""`; any other message is returned
/// unchanged.
pub fn error_or_empty(error: Option<String>) -> String {
    // `unwrap_or_default` already maps `None` to "", so filtering out empty
    // strings beforehand can never change the result.
    error.unwrap_or_default()
}
89
90pub async fn write_message<W: AsyncWrite + Unpin>(writer: &mut W, msg: &Message) -> io::Result<()> {
91 if let Err(err) = prost_protovalidate::validate(msg) {
92 return Err(io::Error::new(
93 io::ErrorKind::InvalidData,
94 format!("validate agent message: {err}"),
95 ));
96 }
97
98 let mut data = Vec::new();
99 msg.encode(&mut data)
100 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
101 if data.len() > MAX_MESSAGE_BYTES {
102 return Err(io::Error::new(
103 io::ErrorKind::InvalidData,
104 format!("agent message exceeded {MAX_MESSAGE_BYTES} bytes"),
105 ));
106 }
107
108 writer.write_all(&(data.len() as u32).to_be_bytes()).await?;
109 writer.write_all(&data).await?;
110 writer.flush().await
111}
112
113pub async fn read_message<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Option<Message>> {
114 let Some(header) = read_header(reader).await? else {
115 return Ok(None);
116 };
117 let size = u32::from_be_bytes(header) as usize;
118 if size > MAX_MESSAGE_BYTES {
119 return Err(io::Error::new(
120 io::ErrorKind::InvalidData,
121 format!("agent message exceeded {MAX_MESSAGE_BYTES} bytes"),
122 ));
123 }
124
125 let mut data = vec![0; size];
126 reader.read_exact(&mut data).await?;
127 let msg = Message::decode(&data[..])
128 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
129
130 if let Err(err) = prost_protovalidate::validate(&msg) {
131 return Err(io::Error::new(
132 io::ErrorKind::InvalidData,
133 format!("validate agent message: {err}"),
134 ));
135 }
136
137 Ok(Some(msg))
138}
139
140async fn read_header<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Option<[u8; 4]>> {
141 let mut header = [0; 4];
142 let mut read = 0;
143 while read < header.len() {
144 match reader.read(&mut header[read..]).await {
145 Ok(0) if read == 0 => return Ok(None),
146 Ok(0) => {
147 return Err(io::Error::new(
148 io::ErrorKind::UnexpectedEof,
149 "partial agent message header",
150 ));
151 }
152 Ok(n) => read += n,
153 Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
154 Err(err) => return Err(err),
155 }
156 }
157 Ok(Some(header))
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid message carrying a `BuiltPaths` payload.
    fn built_paths_message() -> Message {
        Message {
            id: "built-paths".to_owned(),
            built_paths: Some(v1::BuiltPaths {
                paths: vec!["/nix/store/abc-package".to_owned()],
                reason: "post_build_hook".to_owned(),
            }),
            ..Default::default()
        }
    }

    /// A valid message carrying an empty `Hello` payload.
    fn hello_message(id: &str) -> Message {
        Message {
            id: id.to_owned(),
            hello: Some(v1::Hello::default()),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn round_trips_protobuf_message() {
        let msg = built_paths_message();

        let mut encoded = Vec::new();
        write_message(&mut encoded, &msg).await.unwrap();

        let decoded = read_message(&mut &encoded[..]).await.unwrap().unwrap();
        assert_eq!(decoded.id, "built-paths");
        let Some(paths) = decoded.built_paths else {
            panic!("built_paths payload lost in round trip");
        };
        assert_eq!(paths.paths, ["/nix/store/abc-package"]);
        assert_eq!(paths.reason, "post_build_hook");
    }

    #[tokio::test]
    async fn reads_back_to_back_frames_then_clean_eof() {
        let mut stream = Vec::new();
        write_message(&mut stream, &built_paths_message()).await.unwrap();
        write_message(&mut stream, &hello_message("hello-1")).await.unwrap();

        let mut reader = &stream[..];
        let first = read_message(&mut reader).await.unwrap().expect("first frame");
        let second = read_message(&mut reader).await.unwrap().expect("second frame");
        assert_eq!(kind(&first), "built_paths");
        assert_eq!(kind(&second), "hello");
        assert_eq!(second.id, "hello-1");
        assert!(read_message(&mut reader).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn rejects_truncated_header() {
        let mut reader: &[u8] = &[0, 0];
        let Err(err) = read_message(&mut reader).await else {
            panic!("truncated header was accepted");
        };
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn rejects_oversized_length_prefix() {
        // MAX_MESSAGE_BYTES (1 MiB) fits in u32, so the cast is lossless.
        let header = (MAX_MESSAGE_BYTES as u32 + 1).to_be_bytes();
        let mut reader = &header[..];
        let Err(err) = read_message(&mut reader).await else {
            panic!("oversized frame was accepted");
        };
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn write_rejects_invalid_message_without_writing() {
        let mut out = Vec::new();
        let Err(err) = write_message(&mut out, &Message::default()).await else {
            panic!("message without payload was written");
        };
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(out.is_empty());
    }

    #[test]
    fn validates_messages() {
        // 1. valid message (exactly one field set)
        let valid = hello_message("test-1");
        assert!(prost_protovalidate::validate(&valid).is_ok());
        assert_eq!(kind(&valid), "hello");

        // 2. invalid message (zero fields set)
        let invalid_zero = Message {
            id: "test-2".to_owned(),
            ..Default::default()
        };
        assert!(prost_protovalidate::validate(&invalid_zero).is_err());

        // 3. invalid message (multiple fields set)
        let invalid_multi = Message {
            id: "test-3".to_owned(),
            hello: Some(v1::Hello::default()),
            init: Some(v1::Init::default()),
            ..Default::default()
        };
        assert!(prost_protovalidate::validate(&invalid_multi).is_err());
    }

    #[test]
    fn error_or_empty_defaults_to_empty_string() {
        assert_eq!(error_or_empty(None), "");
        assert_eq!(error_or_empty(Some(String::new())), "");
        assert_eq!(error_or_empty(Some("boom".to_owned())), "boom");
    }
}