Monorepo for Tangled (tangled.org)
5

Configure Feed

Select the types of activity you want to include in your feed.

1use once_cell::sync::Lazy; 2use prost::Message as ProstMessage; 3use prost_reflect::DescriptorPool; 4use std::io; 5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; 6 7pub mod v1 { 8 include!("gen/spindle/agent/v1/spindle.agent.v1.rs"); 9} 10 11pub use v1::Message; 12 13pub static DESCRIPTOR_POOL: Lazy<DescriptorPool> = Lazy::new(|| { 14 let bytes = include_bytes!("gen/file_descriptor_set.bin"); 15 DescriptorPool::decode(&bytes[..]).unwrap() 16}); 17 18macro_rules! impl_reflect { 19 ($($t:ident),* $(,)?) => { 20 $( 21 impl prost_reflect::ReflectMessage for v1::$t { 22 fn descriptor(&self) -> prost_reflect::MessageDescriptor { 23 DESCRIPTOR_POOL 24 .get_message_by_name(concat!("spindle.agent.v1.", stringify!($t))) 25 .unwrap() 26 } 27 } 28 )* 29 }; 30} 31 32impl_reflect!( 33 Hello, 34 Init, 35 ExecStart, 36 ExecStdout, 37 ExecStderr, 38 ExecExit, 39 ActivateConfig, 40 ActivateConfigResult, 41 BuiltPaths, 42 CacheDrain, 43 CacheDrainResult, 44 Poweroff, 45 PoweroffResult, 46 Message, 47); 48 49pub const PROTOCOL_VERSION: u32 = 1; 50pub const DEFAULT_PORT: u32 = 10240; 51pub const MAX_MESSAGE_BYTES: usize = 1024 * 1024; 52 53#[macro_export] 54macro_rules! on_payload { 55 (ref $msg:expr, { $( $field:ident => $body:expr ),* $(,)? }) => { 56 #[allow(unused_variables)] 57 $(if let Some(ref $field) = $msg.$field { Some($body) } else)* { None } 58 }; 59 ($msg:expr, { $( $field:ident => $body:expr ),* $(,)? 
}) => { 60 $(if let Some($field) = $msg.$field { Some($body) } else )* { None } 61 }; 62} 63 64pub fn kind(msg: &Message) -> &'static str { 65 // todo(dawn): maybe eventually we should have a custom protoc plugin for 66 // generating an enum, right now not worth it, when we have more needs for 67 // it imo we can consider it again 68 on_payload!(ref msg, { 69 hello => "hello", 70 init => "init", 71 exec_start => "exec_start", 72 exec_stdout => "exec_stdout", 73 exec_stderr => "exec_stderr", 74 exec_exit => "exec_exit", 75 activate_config => "activate_config", 76 activate_config_result => "activate_config_result", 77 built_paths => "built_paths", 78 cache_drain => "cache_drain", 79 cache_drain_result => "cache_drain_result", 80 poweroff => "poweroff", 81 poweroff_result => "poweroff_result", 82 }) 83 .unwrap_or_else(|| unreachable!("validated message has no payload")) 84} 85 86pub fn error_or_empty(error: Option<String>) -> String { 87 error.filter(|error| !error.is_empty()).unwrap_or_default() 88} 89 90pub async fn write_message<W: AsyncWrite + Unpin>(writer: &mut W, msg: &Message) -> io::Result<()> { 91 if let Err(err) = prost_protovalidate::validate(msg) { 92 return Err(io::Error::new( 93 io::ErrorKind::InvalidData, 94 format!("validate agent message: {err}"), 95 )); 96 } 97 98 let mut data = Vec::new(); 99 msg.encode(&mut data) 100 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; 101 if data.len() > MAX_MESSAGE_BYTES { 102 return Err(io::Error::new( 103 io::ErrorKind::InvalidData, 104 format!("agent message exceeded {MAX_MESSAGE_BYTES} bytes"), 105 )); 106 } 107 108 writer.write_all(&(data.len() as u32).to_be_bytes()).await?; 109 writer.write_all(&data).await?; 110 writer.flush().await 111} 112 113pub async fn read_message<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Option<Message>> { 114 let Some(header) = read_header(reader).await? 
else { 115 return Ok(None); 116 }; 117 let size = u32::from_be_bytes(header) as usize; 118 if size > MAX_MESSAGE_BYTES { 119 return Err(io::Error::new( 120 io::ErrorKind::InvalidData, 121 format!("agent message exceeded {MAX_MESSAGE_BYTES} bytes"), 122 )); 123 } 124 125 let mut data = vec![0; size]; 126 reader.read_exact(&mut data).await?; 127 let msg = Message::decode(&data[..]) 128 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; 129 130 if let Err(err) = prost_protovalidate::validate(&msg) { 131 return Err(io::Error::new( 132 io::ErrorKind::InvalidData, 133 format!("validate agent message: {err}"), 134 )); 135 } 136 137 Ok(Some(msg)) 138} 139 140async fn read_header<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Option<[u8; 4]>> { 141 let mut header = [0; 4]; 142 let mut read = 0; 143 while read < header.len() { 144 match reader.read(&mut header[read..]).await { 145 Ok(0) if read == 0 => return Ok(None), 146 Ok(0) => { 147 return Err(io::Error::new( 148 io::ErrorKind::UnexpectedEof, 149 "partial agent message header", 150 )); 151 } 152 Ok(n) => read += n, 153 Err(err) if err.kind() == io::ErrorKind::Interrupted => {} 154 Err(err) => return Err(err), 155 } 156 } 157 Ok(Some(header)) 158} 159 160#[cfg(test)] 161mod tests { 162 use super::*; 163 164 #[tokio::test] 165 async fn round_trips_protobuf_message() { 166 let msg = Message { 167 id: "built-paths".to_owned(), 168 built_paths: Some(v1::BuiltPaths { 169 paths: vec!["/nix/store/abc-package".to_owned()], 170 reason: "post_build_hook".to_owned(), 171 }), 172 ..Default::default() 173 }; 174 175 let mut encoded = Vec::new(); 176 write_message(&mut encoded, &msg).await.unwrap(); 177 178 let decoded = read_message(&mut &encoded[..]).await.unwrap().unwrap(); 179 assert!(decoded.built_paths.is_some()); 180 if let Some(p) = decoded.built_paths { 181 assert_eq!(p.paths, ["/nix/store/abc-package"]); 182 assert_eq!(p.reason, "post_build_hook"); 183 } 184 } 185 186 #[test] 187 fn validates_messages() { 
188 // 1. valid message (exactly one field set) 189 let valid = Message { 190 id: "test-1".to_owned(), 191 hello: Some(v1::Hello::default()), 192 ..Default::default() 193 }; 194 assert!(prost_protovalidate::validate(&valid).is_ok()); 195 196 // 2. invalid message (zero fields set) 197 let invalid_zero = Message { 198 id: "test-2".to_owned(), 199 ..Default::default() 200 }; 201 assert!(prost_protovalidate::validate(&invalid_zero).is_err()); 202 203 // 3. invalid message (multiple fields set) 204 let invalid_multi = Message { 205 id: "test-3".to_owned(), 206 hello: Some(v1::Hello::default()), 207 init: Some(v1::Init::default()), 208 ..Default::default() 209 }; 210 assert!(prost_protovalidate::validate(&invalid_multi).is_err()); 211 } 212}