Monorepo for Tangled
tangled.org
1package microvm
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "net"
9 "sync"
10 "time"
11
12 "github.com/miekg/dns"
13)
14
// Tunables for the guest-facing DNS proxy.
const (
	// dnsProxyResolvConfPath is the host resolver configuration whose
	// nameservers the proxy forwards queries to.
	dnsProxyResolvConfPath = "/etc/resolv.conf"

	// Per-connection server deadlines. The I/O timeout doubles as the
	// upstream exchange timeout when resolv.conf does not provide one.
	dnsProxyIOTimeout   = 10 * time.Second
	dnsProxyIdleTimeout = 30 * time.Second

	// dnsProxyShutdownTimeout bounds how long Close waits for the server to
	// shut down.
	dnsProxyShutdownTimeout = 10 * time.Second

	// Resource caps: simultaneously open guest connections, and the value
	// passed to dns.Server.MaxTCPQueries.
	dnsProxyMaxConnections = 64
	dnsProxyMaxTCPQueries  = 128
)
23
24type DNSProxy struct {
25 port uint32
26 srv *dns.Server
27
28 closeOnce sync.Once
29 closeErr error
30}
31
32func StartDNSProxy(ctx context.Context, cid uint32, logger *slog.Logger) (*DNSProxy, error) {
33 if ctx == nil {
34 ctx = context.Background()
35 }
36
37 if logger == nil {
38 logger = slog.Default()
39 }
40 logger = logger.With("where", "dns_proxy", "cid", cid)
41
42 ln, port, err := listenRandomVsockPort(ctx)
43 if err != nil {
44 return nil, fmt.Errorf("listen for dns proxy: %w", err)
45 }
46
47 resolver, err := newHostDNSResolver(dnsProxyResolvConfPath, logger)
48 if err != nil {
49 _ = ln.Close()
50 return nil, err
51 }
52
53 listener := newLimitedListener(
54 &cidFilteredVsockListener{
55 Listener: ln,
56 cid: cid,
57 logger: logger,
58 },
59 dnsProxyMaxConnections,
60 logger,
61 )
62
63 proxy := &DNSProxy{
64 port: port,
65 srv: &dns.Server{
66 Net: "tcp",
67 Listener: listener,
68 Handler: dns.HandlerFunc(resolver.ServeDNS),
69 ReadTimeout: dnsProxyIOTimeout,
70 WriteTimeout: dnsProxyIOTimeout,
71 IdleTimeout: func() time.Duration { return dnsProxyIdleTimeout },
72 MaxTCPQueries: dnsProxyMaxTCPQueries,
73 MsgInvalidFunc: func(_ []byte, err error) {
74 logger.Warn("dns proxy invalid message", "error", err)
75 },
76 },
77 }
78
79 go func() {
80 <-ctx.Done()
81 _ = proxy.Close()
82 }()
83
84 go func() {
85 if err := proxy.srv.ActivateAndServe(); err != nil && !errors.Is(err, net.ErrClosed) {
86 logger.Warn("dns proxy stopped", "error", err)
87 }
88 }()
89
90 logger.Info("started dns proxy", "port", port)
91 return proxy, nil
92}
93
94func (p *DNSProxy) Port() uint32 {
95 if p == nil {
96 return 0
97 }
98 return p.port
99}
100
101func (p *DNSProxy) Close() error {
102 if p == nil || p.srv == nil {
103 return nil
104 }
105
106 p.closeOnce.Do(func() {
107 shutdownCtx, cancel := context.WithTimeout(context.Background(), dnsProxyShutdownTimeout)
108 defer cancel()
109
110 p.closeErr = p.srv.ShutdownContext(shutdownCtx)
111 })
112 return p.closeErr
113}
114
// limitedListener caps how many accepted connections may be open at once.
// A connection that arrives while every slot is taken is closed immediately
// instead of being queued.
type limitedListener struct {
	net.Listener
	slots  chan struct{}
	logger *slog.Logger
}

// newLimitedListener wraps listener so that at most limit accepted
// connections are open at a time. A non-positive limit disables the cap and
// returns listener unchanged. A nil logger falls back to slog.Default().
func newLimitedListener(listener net.Listener, limit int, logger *slog.Logger) net.Listener {
	if limit <= 0 {
		return listener
	}
	if logger == nil {
		// Accept logs every dropped connection; a nil logger would panic there.
		logger = slog.Default()
	}
	return &limitedListener{
		Listener: listener,
		slots:    make(chan struct{}, limit),
		logger:   logger,
	}
}

// Accept returns the next connection for which a slot is free. Connections
// arriving while all slots are in use are closed, and Accept keeps waiting.
// Errors from the underlying listener are returned unchanged.
func (l *limitedListener) Accept() (net.Conn, error) {
	for {
		conn, err := l.Listener.Accept()
		if err != nil {
			return nil, err
		}

		select {
		case l.slots <- struct{}{}:
			return &limitedConn{Conn: conn, release: l.release}, nil
		default:
			l.logger.Warn("dns proxy dropped connection because workers are busy")
			// Best effort: the peer simply observes a closed connection.
			_ = conn.Close()
		}
	}
}

// release frees one slot previously taken by Accept.
func (l *limitedListener) release() {
	<-l.slots
}

// limitedConn gives its listener slot back when it is closed.
type limitedConn struct {
	net.Conn
	once    sync.Once
	release func()
}

// Close closes the underlying connection and frees the slot exactly once, so
// repeated Close calls never release more slots than were taken.
func (c *limitedConn) Close() error {
	err := c.Conn.Close()
	c.once.Do(c.release)
	return err
}
165
166type hostDNSResolver struct {
167 upstreams []string
168 attempts int
169 timeout time.Duration
170 logger *slog.Logger
171}
172
173func newHostDNSResolver(path string, logger *slog.Logger) (*hostDNSResolver, error) {
174 config, err := dns.ClientConfigFromFile(path)
175 if err != nil {
176 return nil, fmt.Errorf("read host resolv.conf: %w", err)
177 }
178 if len(config.Servers) == 0 {
179 return nil, fmt.Errorf("host resolv.conf has no nameservers")
180 }
181
182 port := config.Port
183 if port == "" {
184 port = "53"
185 }
186
187 upstreams := make([]string, 0, len(config.Servers))
188 for _, server := range config.Servers {
189 upstreams = append(upstreams, net.JoinHostPort(server, port))
190 }
191
192 timeout := time.Duration(config.Timeout) * time.Second
193 if timeout <= 0 {
194 timeout = dnsProxyIOTimeout
195 }
196
197 return &hostDNSResolver{
198 upstreams: upstreams,
199 attempts: max(config.Attempts, 1),
200 timeout: timeout,
201 logger: logger,
202 }, nil
203}
204
205func (r *hostDNSResolver) ServeDNS(w dns.ResponseWriter, req *dns.Msg) {
206 resp, err := r.exchange(req)
207 if err != nil {
208 r.logger.Warn(
209 "dns upstream exchange failed",
210 "question", dnsQuestionLogValue(req),
211 "error", err,
212 )
213 if err := w.WriteMsg(rcodeResponse(req, dns.RcodeServerFailure)); err != nil {
214 r.logger.Warn("dns proxy response write failed", "error", err)
215 }
216 return
217 }
218
219 filterDNSResponse(resp)
220
221 if err := w.WriteMsg(resp); err != nil {
222 r.logger.Warn("dns proxy response write failed", "error", err)
223 }
224}
225
226func (r *hostDNSResolver) exchange(req *dns.Msg) (*dns.Msg, error) {
227 var errs []error
228
229 for range r.attempts {
230 for _, upstream := range r.upstreams {
231 resp, err := exchangeDNSAt(req, upstream, r.timeout)
232 if err == nil {
233 return resp, nil
234 }
235 errs = append(errs, fmt.Errorf("%s: %w", upstream, err))
236 }
237 }
238
239 return nil, errors.Join(errs...)
240}
241
242func exchangeDNSAt(req *dns.Msg, addr string, timeout time.Duration) (*dns.Msg, error) {
243 resp, _, err := (&dns.Client{Net: "udp", Timeout: timeout}).Exchange(req, addr)
244 if err != nil {
245 return nil, err
246 }
247 if resp == nil {
248 return nil, fmt.Errorf("empty udp response")
249 }
250 if !resp.Truncated {
251 return resp, nil
252 }
253
254 resp, _, err = (&dns.Client{Net: "tcp", Timeout: timeout}).Exchange(req, addr)
255 if err != nil {
256 return nil, err
257 }
258 if resp == nil {
259 return nil, fmt.Errorf("empty tcp response")
260 }
261 return resp, nil
262}
263
264func filterDNSResponse(msg *dns.Msg) {
265 if msg == nil {
266 return
267 }
268 msg.Answer = filterDNSRRs(msg.Answer)
269 msg.Ns = filterDNSRRs(msg.Ns)
270 msg.Extra = filterDNSRRs(msg.Extra)
271}
272
273func filterDNSRRs(rrs []dns.RR) []dns.RR {
274 filtered := rrs[:0]
275 for _, rr := range rrs {
276 if rr := filterDNSRR(rr); rr != nil {
277 filtered = append(filtered, rr)
278 }
279 }
280 return filtered
281}
282
283func filterDNSRR(rr dns.RR) dns.RR {
284 switch rr := rr.(type) {
285 case *dns.A:
286 if isBlockedNamespaceIP(rr.A) {
287 return nil
288 }
289 case *dns.AAAA:
290 if isBlockedNamespaceIP(rr.AAAA) {
291 return nil
292 }
293 case *dns.SVCB:
294 filterSVCBValues(&rr.Value)
295 case *dns.HTTPS:
296 filterSVCBValues(&rr.Value)
297 }
298 return rr
299}
300
301// this removes any blocked namespaces in ipv4/v6 hints
302func filterSVCBValues(values *[]dns.SVCBKeyValue) {
303 filtered := (*values)[:0]
304 for _, value := range *values {
305 switch value := value.(type) {
306 case *dns.SVCBIPv4Hint:
307 value.Hint = filterDNSIPs(value.Hint)
308 if len(value.Hint) == 0 {
309 continue
310 }
311 case *dns.SVCBIPv6Hint:
312 value.Hint = filterDNSIPs(value.Hint)
313 if len(value.Hint) == 0 {
314 continue
315 }
316 }
317 filtered = append(filtered, value)
318 }
319 *values = filtered
320}
321
322func filterDNSIPs(ips []net.IP) []net.IP {
323 filtered := ips[:0]
324 for _, ip := range ips {
325 if !isBlockedNamespaceIP(ip) {
326 filtered = append(filtered, ip)
327 }
328 }
329 return filtered
330}
331
332func isBlockedNamespaceIP(ip net.IP) bool {
333 if ip == nil {
334 return true
335 }
336 if ip4 := ip.To4(); ip4 != nil {
337 return isBlockedByNamespaceNets(ip4, 32)
338 }
339 return isBlockedByNamespaceNets(ip, 128)
340}
341
342func isBlockedByNamespaceNets(ip net.IP, bits int) bool {
343 for _, blockedNet := range blockedNamespaceNets {
344 if blockedNet == nil {
345 continue
346 }
347
348 _, blockedBits := blockedNet.Mask.Size()
349 if blockedBits != bits {
350 continue
351 }
352 if blockedNet.Contains(ip) {
353 return true
354 }
355 }
356 return false
357}
358
359func rcodeResponse(req *dns.Msg, rcode int) *dns.Msg {
360 resp := new(dns.Msg)
361 if req == nil {
362 resp.Rcode = rcode
363 return resp
364 }
365 resp.SetRcode(req, rcode)
366 return resp
367}
368
369func dnsQuestionLogValue(msg *dns.Msg) string {
370 if msg == nil || len(msg.Question) == 0 {
371 return ""
372 }
373
374 q := msg.Question[0]
375 qtype := dns.TypeToString[q.Qtype]
376 if qtype == "" {
377 qtype = fmt.Sprintf("TYPE%d", q.Qtype)
378 }
379 return fmt.Sprintf("%s/%s", q.Name, qtype)
380}