Monorepo for Tangled (tangled.org)
8

Configure Feed

Select the types of activity you want to include in your feed.

1use once_cell::sync::Lazy; 2use prost::Message as ProstMessage; 3use prost_reflect::DescriptorPool; 4use std::io; 5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; 6 7pub mod v1 { 8 include!("gen/spindle/agent/v1/spindle.agent.v1.rs"); 9} 10 11pub use v1::Message; 12 13pub static DESCRIPTOR_POOL: Lazy<DescriptorPool> = Lazy::new(|| { 14 let bytes = include_bytes!("gen/file_descriptor_set.bin"); 15 DescriptorPool::decode(&bytes[..]).unwrap() 16}); 17 18macro_rules! impl_reflect { 19 ($($t:ident),* $(,)?) => { 20 $( 21 impl prost_reflect::ReflectMessage for v1::$t { 22 fn descriptor(&self) -> prost_reflect::MessageDescriptor { 23 DESCRIPTOR_POOL 24 .get_message_by_name(concat!("spindle.agent.v1.", stringify!($t))) 25 .unwrap() 26 } 27 } 28 )* 29 }; 30} 31 32impl_reflect!( 33 Hello, 34 Init, 35 ExecStart, 36 ExecStdout, 37 ExecStderr, 38 ExecExit, 39 ActivateConfig, 40 ActivateConfigResult, 41 BuiltPaths, 42 CacheDrain, 43 CacheDrainResult, 44 Poweroff, 45 PoweroffResult, 46 OpenDebugShell, 47 PtyData, 48 PtyResize, 49 Message, 50); 51 52pub const PROTOCOL_VERSION: u32 = 1; 53pub const DEFAULT_PORT: u32 = 10240; 54pub const MAX_MESSAGE_BYTES: usize = 1024 * 1024; 55 56#[macro_export] 57macro_rules! on_payload { 58 (ref $msg:expr, { $( $field:ident => $body:expr ),* $(,)? }) => { 59 #[allow(unused_variables)] 60 $(if let Some(ref $field) = $msg.$field { Some($body) } else)* { None } 61 }; 62 ($msg:expr, { $( $field:ident => $body:expr ),* $(,)? 
}) => { 63 $(if let Some($field) = $msg.$field { Some($body) } else )* { None } 64 }; 65} 66 67pub fn kind(msg: &Message) -> &'static str { 68 // todo(dawn): maybe eventually we should have a custom protoc plugin for 69 // generating an enum, right now not worth it, when we have more needs for 70 // it imo we can consider it again 71 on_payload!(ref msg, { 72 hello => "hello", 73 init => "init", 74 exec_start => "exec_start", 75 exec_stdout => "exec_stdout", 76 exec_stderr => "exec_stderr", 77 exec_exit => "exec_exit", 78 activate_config => "activate_config", 79 activate_config_result => "activate_config_result", 80 built_paths => "built_paths", 81 cache_drain => "cache_drain", 82 cache_drain_result => "cache_drain_result", 83 poweroff => "poweroff", 84 poweroff_result => "poweroff_result", 85 open_debug_shell => "open_debug_shell", 86 pty_data => "pty_data", 87 pty_resize => "pty_resize", 88 }) 89 .unwrap_or_else(|| unreachable!("validated message has no payload")) 90} 91 92pub fn error_or_empty(error: Option<String>) -> String { 93 error.filter(|error| !error.is_empty()).unwrap_or_default() 94} 95 96pub async fn write_message<W: AsyncWrite + Unpin>(writer: &mut W, msg: &Message) -> io::Result<()> { 97 if let Err(err) = prost_protovalidate::validate(msg) { 98 return Err(io::Error::new( 99 io::ErrorKind::InvalidData, 100 format!("validate agent message: {err}"), 101 )); 102 } 103 104 let mut data = Vec::new(); 105 msg.encode(&mut data) 106 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; 107 if data.len() > MAX_MESSAGE_BYTES { 108 return Err(io::Error::new( 109 io::ErrorKind::InvalidData, 110 format!("agent message exceeded {MAX_MESSAGE_BYTES} bytes"), 111 )); 112 } 113 114 writer.write_all(&(data.len() as u32).to_be_bytes()).await?; 115 writer.write_all(&data).await?; 116 writer.flush().await 117} 118 119pub async fn read_message<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Option<Message>> { 120 let Some(header) = 
read_header(reader).await? else { 121 return Ok(None); 122 }; 123 let size = u32::from_be_bytes(header) as usize; 124 if size > MAX_MESSAGE_BYTES { 125 return Err(io::Error::new( 126 io::ErrorKind::InvalidData, 127 format!("agent message exceeded {MAX_MESSAGE_BYTES} bytes"), 128 )); 129 } 130 131 let mut data = vec![0; size]; 132 reader.read_exact(&mut data).await?; 133 let msg = Message::decode(&data[..]) 134 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; 135 136 if let Err(err) = prost_protovalidate::validate(&msg) { 137 return Err(io::Error::new( 138 io::ErrorKind::InvalidData, 139 format!("validate agent message: {err}"), 140 )); 141 } 142 143 Ok(Some(msg)) 144} 145 146async fn read_header<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Option<[u8; 4]>> { 147 let mut header = [0; 4]; 148 let mut read = 0; 149 while read < header.len() { 150 match reader.read(&mut header[read..]).await { 151 Ok(0) if read == 0 => return Ok(None), 152 Ok(0) => { 153 return Err(io::Error::new( 154 io::ErrorKind::UnexpectedEof, 155 "partial agent message header", 156 )); 157 } 158 Ok(n) => read += n, 159 Err(err) if err.kind() == io::ErrorKind::Interrupted => {} 160 Err(err) => return Err(err), 161 } 162 } 163 Ok(Some(header)) 164} 165 166#[cfg(test)] 167mod tests { 168 use super::*; 169 170 #[tokio::test] 171 async fn round_trips_protobuf_message() { 172 let msg = Message { 173 id: "built-paths".to_owned(), 174 built_paths: Some(v1::BuiltPaths { 175 paths: vec!["/nix/store/abc-package".to_owned()], 176 reason: "post_build_hook".to_owned(), 177 }), 178 ..Default::default() 179 }; 180 181 let mut encoded = Vec::new(); 182 write_message(&mut encoded, &msg).await.unwrap(); 183 184 let decoded = read_message(&mut &encoded[..]).await.unwrap().unwrap(); 185 assert!(decoded.built_paths.is_some()); 186 if let Some(p) = decoded.built_paths { 187 assert_eq!(p.paths, ["/nix/store/abc-package"]); 188 assert_eq!(p.reason, "post_build_hook"); 189 } 190 } 191 192 #[test] 
193 fn validates_messages() { 194 // 1. valid message (exactly one field set) 195 let valid = Message { 196 id: "test-1".to_owned(), 197 hello: Some(v1::Hello::default()), 198 ..Default::default() 199 }; 200 assert!(prost_protovalidate::validate(&valid).is_ok()); 201 202 // 2. invalid message (zero fields set) 203 let invalid_zero = Message { 204 id: "test-2".to_owned(), 205 ..Default::default() 206 }; 207 assert!(prost_protovalidate::validate(&invalid_zero).is_err()); 208 209 // 3. invalid message (multiple fields set) 210 let invalid_multi = Message { 211 id: "test-3".to_owned(), 212 hello: Some(v1::Hello::default()), 213 init: Some(v1::Init::default()), 214 ..Default::default() 215 }; 216 assert!(prost_protovalidate::validate(&invalid_multi).is_err()); 217 } 218}