Monorepo for Tangled
tangled.org
1use anyhow::{Context, Result};
2use std::io;
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use tokio::net::{TcpListener, TcpStream, UdpSocket};
8use tokio::task::{JoinError, JoinHandle, JoinSet};
9use tokio_vsock::{VsockAddr, VsockStream};
10use tracing::{info, warn};
11
/// Listen address used when `SHUTTLE_DNS_PROXY_ADDR` cannot be read from the environment.
const DEFAULT_DNS_PROXY_ADDR: &str = "127.0.0.1:53";
/// Environment variable that overrides the proxy's listen address.
const SHUTTLE_DNS_PROXY_ADDR_ENV: &str = "SHUTTLE_DNS_PROXY_ADDR";

/// Largest DNS message the proxy handles: the limit of the two-byte length prefix.
const MAX_DNS_MESSAGE_BYTES: usize = u16::MAX as usize;
/// Bound on dialing the host, and separately on each write-query/read-reply exchange with it.
const DNS_IO_TIMEOUT: Duration = Duration::from_secs(10);
/// How long a guest TCP connection may wait before sending its next query.
const DNS_TCP_IDLE_TIMEOUT: Duration = Duration::from_secs(120);
18
19// implements a proxy that sends dns requests to the spindle and
20// lets the spindle resolve any queries, and streams the response back.
21//
22// we use this because the way spindle isolates QEMU VMs is unreliable
23// to rely on. if we unblock the blackholed-routes for private nameservers
24// like slirp one, we risk leaking internal DNS zones, even if the guest
25// can't connect to them. there is also potentially DNS rebinding issues.
26// and other slirp4netns quirks...
27//
28// this way we also get to filter the DNS queries very easily, so we can
29// make sure we remove everything that would leak a host information.
30pub struct DnsProxy {
31 handles: Vec<JoinHandle<()>>,
32}
33
/// Where to reach the host's DNS resolver over vsock.
///
/// Every query opens a fresh vsock stream to `host_cid:host_port`.
#[derive(Clone)]
struct HostDnsClient {
    /// vsock context id of the host.
    host_cid: u32,
    /// vsock port the host's DNS proxy listens on.
    host_port: u32,
}
39
40impl DnsProxy {
41 pub async fn start(host_cid: u32, host_port: u32) -> Result<Option<Self>> {
42 if host_port == 0 {
43 return Ok(None);
44 }
45
46 let addr = std::env::var(SHUTTLE_DNS_PROXY_ADDR_ENV)
47 .unwrap_or_else(|_| DEFAULT_DNS_PROXY_ADDR.to_owned());
48
49 let udp = Arc::new(
50 UdpSocket::bind(&addr)
51 .await
52 .with_context(|| format!("bind dns udp listener {addr}"))?,
53 );
54
55 let tcp = TcpListener::bind(&addr)
56 .await
57 .with_context(|| format!("bind dns tcp listener {addr}"))?;
58
59 let host = HostDnsClient {
60 host_cid,
61 host_port,
62 };
63
64 let handles = vec![
65 tokio::spawn(udp_loop(udp, host.clone())),
66 tokio::spawn(tcp_loop(tcp, host)),
67 ];
68
69 info!(%addr, host_cid, host_port, "dns proxy ready");
70 Ok(Some(Self { handles }))
71 }
72}
73
74impl Drop for DnsProxy {
75 fn drop(&mut self) {
76 for handle in self.handles.drain(..) {
77 handle.abort();
78 }
79 }
80}
81
82impl HostDnsClient {
83 async fn query(&self, query: Vec<u8>) -> Result<Vec<u8>> {
84 match self.query_once(&query).await {
85 Ok(response) => Ok(response),
86 Err(first_error) => self.query_once(&query).await.with_context(|| {
87 format!("dns host query failed after retry; first error: {first_error:#}")
88 }),
89 }
90 }
91
92 async fn query_once(&self, query: &[u8]) -> Result<Vec<u8>> {
93 let addr = VsockAddr::new(self.host_cid, self.host_port);
94
95 let mut host = tokio::time::timeout(DNS_IO_TIMEOUT, VsockStream::connect(addr))
96 .await
97 .context("dns host connect timed out")?
98 .with_context(|| {
99 format!(
100 "dial host dns proxy cid={} port={}",
101 self.host_cid, self.host_port
102 )
103 })?;
104
105 tokio::time::timeout(DNS_IO_TIMEOUT, async {
106 write_dns_packet(&mut host, query)
107 .await
108 .context("write dns query to host")?;
109
110 read_dns_packet(&mut host)
111 .await
112 .context("read dns response from host")?
113 .context("host dns proxy closed without response")
114 })
115 .await
116 .context("dns host query timed out")?
117 }
118}
119
120async fn udp_loop(socket: Arc<UdpSocket>, host: HostDnsClient) {
121 let mut buf = vec![0; MAX_DNS_MESSAGE_BYTES];
122 let mut tasks = JoinSet::new();
123
124 loop {
125 tokio::select! {
126 received = socket.recv_from(&mut buf) => match received {
127 Ok((len, peer)) => {
128 let query = buf[..len].to_vec();
129 let socket = socket.clone();
130 let host = host.clone();
131
132 tasks.spawn(async move {
133 if let Err(error) = handle_udp_query(socket, peer, query, host).await {
134 warn!(%peer, %error, "dns udp query failed");
135 }
136 });
137 }
138 Err(error) => warn!(%error, "dns udp recv failed"),
139 },
140
141 Some(result) = tasks.join_next(), if !tasks.is_empty() => {
142 log_dns_task_result(result);
143 }
144 }
145 }
146}
147
148async fn handle_udp_query(
149 socket: Arc<UdpSocket>,
150 peer: SocketAddr,
151 query: Vec<u8>,
152 host: HostDnsClient,
153) -> Result<()> {
154 let response = host.query(query).await?;
155
156 socket
157 .send_to(&response, peer)
158 .await
159 .context("send dns udp response")?;
160
161 Ok(())
162}
163
164async fn tcp_loop(listener: TcpListener, host: HostDnsClient) {
165 let mut tasks = JoinSet::new();
166
167 loop {
168 tokio::select! {
169 accepted = listener.accept() => match accepted {
170 Ok((conn, peer)) => {
171 let host = host.clone();
172
173 tasks.spawn(async move {
174 if let Err(error) = handle_tcp_conn(conn, host).await {
175 warn!(%peer, %error, "dns tcp connection failed");
176 }
177 });
178 }
179 Err(error) => warn!(%error, "dns tcp accept failed"),
180 },
181
182 Some(result) = tasks.join_next(), if !tasks.is_empty() => {
183 log_dns_task_result(result);
184 }
185 }
186 }
187}
188
189async fn handle_tcp_conn(mut tcp: TcpStream, host: HostDnsClient) -> Result<()> {
190 loop {
191 let query = tokio::time::timeout(DNS_TCP_IDLE_TIMEOUT, read_dns_packet(&mut tcp))
192 .await
193 .context("dns tcp idle timeout")?
194 .context("read dns tcp query")?;
195
196 let Some(query) = query else {
197 return Ok(());
198 };
199
200 let response = host.query(query).await?;
201
202 write_dns_packet(&mut tcp, &response)
203 .await
204 .context("write dns tcp response")?;
205 }
206}
207
208async fn read_dns_packet<R>(reader: &mut R) -> io::Result<Option<Vec<u8>>>
209where
210 R: AsyncRead + Unpin,
211{
212 let mut len_buf = [0; 2];
213
214 match reader.read_exact(&mut len_buf).await {
215 Ok(_) => {}
216 Err(error) if error.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
217 Err(error) => return Err(error),
218 }
219
220 let len = u16::from_be_bytes(len_buf) as usize;
221 if len == 0 {
222 return Err(io::Error::new(
223 io::ErrorKind::InvalidData,
224 "empty dns packet",
225 ));
226 }
227
228 let mut packet = vec![0; len];
229 reader.read_exact(&mut packet).await?;
230 Ok(Some(packet))
231}
232
233async fn write_dns_packet<W>(writer: &mut W, packet: &[u8]) -> io::Result<()>
234where
235 W: AsyncWrite + Unpin,
236{
237 if packet.is_empty() || packet.len() > MAX_DNS_MESSAGE_BYTES {
238 return Err(io::Error::new(
239 io::ErrorKind::InvalidData,
240 format!("invalid dns packet size {}", packet.len()),
241 ));
242 }
243
244 writer
245 .write_all(&(packet.len() as u16).to_be_bytes())
246 .await?;
247 writer.write_all(packet).await?;
248 writer.flush().await
249}
250
251fn log_dns_task_result(result: Result<(), JoinError>) {
252 if let Err(error) = result {
253 warn!(%error, "dns proxy task failed");
254 }
255}