Monorepo for Tangled (tangled.org)
5

Configure Feed

Select the types of activity you want to include in your feed.

1package microvm 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "log/slog" 8 "net" 9 "sync" 10 "time" 11 12 "github.com/miekg/dns" 13) 14 15const ( 16 dnsProxyIOTimeout = 10 * time.Second 17 dnsProxyIdleTimeout = 30 * time.Second 18 dnsProxyShutdownTimeout = 10 * time.Second 19 dnsProxyMaxConnections = 64 20 dnsProxyMaxTCPQueries = 128 21 dnsProxyResolvConfPath = "/etc/resolv.conf" 22) 23 24type DNSProxy struct { 25 port uint32 26 srv *dns.Server 27 28 closeOnce sync.Once 29 closeErr error 30} 31 32func StartDNSProxy(ctx context.Context, cid uint32, logger *slog.Logger) (*DNSProxy, error) { 33 if ctx == nil { 34 ctx = context.Background() 35 } 36 37 if logger == nil { 38 logger = slog.Default() 39 } 40 logger = logger.With("where", "dns_proxy", "cid", cid) 41 42 ln, port, err := listenRandomVsockPort(ctx) 43 if err != nil { 44 return nil, fmt.Errorf("listen for dns proxy: %w", err) 45 } 46 47 resolver, err := newHostDNSResolver(dnsProxyResolvConfPath, logger) 48 if err != nil { 49 _ = ln.Close() 50 return nil, err 51 } 52 53 listener := newLimitedListener( 54 &cidFilteredVsockListener{ 55 Listener: ln, 56 cid: cid, 57 logger: logger, 58 }, 59 dnsProxyMaxConnections, 60 logger, 61 ) 62 63 proxy := &DNSProxy{ 64 port: port, 65 srv: &dns.Server{ 66 Net: "tcp", 67 Listener: listener, 68 Handler: dns.HandlerFunc(resolver.ServeDNS), 69 ReadTimeout: dnsProxyIOTimeout, 70 WriteTimeout: dnsProxyIOTimeout, 71 IdleTimeout: func() time.Duration { return dnsProxyIdleTimeout }, 72 MaxTCPQueries: dnsProxyMaxTCPQueries, 73 MsgInvalidFunc: func(_ []byte, err error) { 74 logger.Warn("dns proxy invalid message", "error", err) 75 }, 76 }, 77 } 78 79 go func() { 80 <-ctx.Done() 81 _ = proxy.Close() 82 }() 83 84 go func() { 85 if err := proxy.srv.ActivateAndServe(); err != nil && !errors.Is(err, net.ErrClosed) { 86 logger.Warn("dns proxy stopped", "error", err) 87 } 88 }() 89 90 logger.Info("started dns proxy", "port", port) 91 return proxy, nil 92} 93 94func (p *DNSProxy) Port() uint32 { 
95 if p == nil { 96 return 0 97 } 98 return p.port 99} 100 101func (p *DNSProxy) Close() error { 102 if p == nil || p.srv == nil { 103 return nil 104 } 105 106 p.closeOnce.Do(func() { 107 shutdownCtx, cancel := context.WithTimeout(context.Background(), dnsProxyShutdownTimeout) 108 defer cancel() 109 110 p.closeErr = p.srv.ShutdownContext(shutdownCtx) 111 }) 112 return p.closeErr 113} 114 115type limitedListener struct { 116 net.Listener 117 slots chan struct{} 118 logger *slog.Logger 119} 120 121func newLimitedListener(listener net.Listener, limit int, logger *slog.Logger) net.Listener { 122 if limit <= 0 { 123 return listener 124 } 125 return &limitedListener{ 126 Listener: listener, 127 slots: make(chan struct{}, limit), 128 logger: logger, 129 } 130} 131 132func (l *limitedListener) Accept() (net.Conn, error) { 133 for { 134 conn, err := l.Listener.Accept() 135 if err != nil { 136 return nil, err 137 } 138 139 select { 140 case l.slots <- struct{}{}: 141 return &limitedConn{ 142 Conn: conn, 143 release: func() { 144 <-l.slots 145 }, 146 }, nil 147 default: 148 l.logger.Warn("dns proxy dropped connection because workers are busy") 149 _ = conn.Close() 150 } 151 } 152} 153 154type limitedConn struct { 155 net.Conn 156 once sync.Once 157 release func() 158} 159 160func (c *limitedConn) Close() error { 161 err := c.Conn.Close() 162 c.once.Do(c.release) 163 return err 164} 165 166type hostDNSResolver struct { 167 upstreams []string 168 attempts int 169 timeout time.Duration 170 logger *slog.Logger 171} 172 173func newHostDNSResolver(path string, logger *slog.Logger) (*hostDNSResolver, error) { 174 config, err := dns.ClientConfigFromFile(path) 175 if err != nil { 176 return nil, fmt.Errorf("read host resolv.conf: %w", err) 177 } 178 if len(config.Servers) == 0 { 179 return nil, fmt.Errorf("host resolv.conf has no nameservers") 180 } 181 182 port := config.Port 183 if port == "" { 184 port = "53" 185 } 186 187 upstreams := make([]string, 0, len(config.Servers)) 188 for 
_, server := range config.Servers { 189 upstreams = append(upstreams, net.JoinHostPort(server, port)) 190 } 191 192 timeout := time.Duration(config.Timeout) * time.Second 193 if timeout <= 0 { 194 timeout = dnsProxyIOTimeout 195 } 196 197 return &hostDNSResolver{ 198 upstreams: upstreams, 199 attempts: max(config.Attempts, 1), 200 timeout: timeout, 201 logger: logger, 202 }, nil 203} 204 205func (r *hostDNSResolver) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { 206 resp, err := r.exchange(req) 207 if err != nil { 208 r.logger.Warn( 209 "dns upstream exchange failed", 210 "question", dnsQuestionLogValue(req), 211 "error", err, 212 ) 213 if err := w.WriteMsg(rcodeResponse(req, dns.RcodeServerFailure)); err != nil { 214 r.logger.Warn("dns proxy response write failed", "error", err) 215 } 216 return 217 } 218 219 filterDNSResponse(resp) 220 221 if err := w.WriteMsg(resp); err != nil { 222 r.logger.Warn("dns proxy response write failed", "error", err) 223 } 224} 225 226func (r *hostDNSResolver) exchange(req *dns.Msg) (*dns.Msg, error) { 227 var errs []error 228 229 for range r.attempts { 230 for _, upstream := range r.upstreams { 231 resp, err := exchangeDNSAt(req, upstream, r.timeout) 232 if err == nil { 233 return resp, nil 234 } 235 errs = append(errs, fmt.Errorf("%s: %w", upstream, err)) 236 } 237 } 238 239 return nil, errors.Join(errs...) 
240} 241 242func exchangeDNSAt(req *dns.Msg, addr string, timeout time.Duration) (*dns.Msg, error) { 243 resp, _, err := (&dns.Client{Net: "udp", Timeout: timeout}).Exchange(req, addr) 244 if err != nil { 245 return nil, err 246 } 247 if resp == nil { 248 return nil, fmt.Errorf("empty udp response") 249 } 250 if !resp.Truncated { 251 return resp, nil 252 } 253 254 resp, _, err = (&dns.Client{Net: "tcp", Timeout: timeout}).Exchange(req, addr) 255 if err != nil { 256 return nil, err 257 } 258 if resp == nil { 259 return nil, fmt.Errorf("empty tcp response") 260 } 261 return resp, nil 262} 263 264func filterDNSResponse(msg *dns.Msg) { 265 if msg == nil { 266 return 267 } 268 msg.Answer = filterDNSRRs(msg.Answer) 269 msg.Ns = filterDNSRRs(msg.Ns) 270 msg.Extra = filterDNSRRs(msg.Extra) 271} 272 273func filterDNSRRs(rrs []dns.RR) []dns.RR { 274 filtered := rrs[:0] 275 for _, rr := range rrs { 276 if rr := filterDNSRR(rr); rr != nil { 277 filtered = append(filtered, rr) 278 } 279 } 280 return filtered 281} 282 283func filterDNSRR(rr dns.RR) dns.RR { 284 switch rr := rr.(type) { 285 case *dns.A: 286 if isBlockedNamespaceIP(rr.A) { 287 return nil 288 } 289 case *dns.AAAA: 290 if isBlockedNamespaceIP(rr.AAAA) { 291 return nil 292 } 293 case *dns.SVCB: 294 filterSVCBValues(&rr.Value) 295 case *dns.HTTPS: 296 filterSVCBValues(&rr.Value) 297 } 298 return rr 299} 300 301// this removes any blocked namespaces in ipv4/v6 hints 302func filterSVCBValues(values *[]dns.SVCBKeyValue) { 303 filtered := (*values)[:0] 304 for _, value := range *values { 305 switch value := value.(type) { 306 case *dns.SVCBIPv4Hint: 307 value.Hint = filterDNSIPs(value.Hint) 308 if len(value.Hint) == 0 { 309 continue 310 } 311 case *dns.SVCBIPv6Hint: 312 value.Hint = filterDNSIPs(value.Hint) 313 if len(value.Hint) == 0 { 314 continue 315 } 316 } 317 filtered = append(filtered, value) 318 } 319 *values = filtered 320} 321 322func filterDNSIPs(ips []net.IP) []net.IP { 323 filtered := ips[:0] 324 for _, ip := 
range ips { 325 if !isBlockedNamespaceIP(ip) { 326 filtered = append(filtered, ip) 327 } 328 } 329 return filtered 330} 331 332func isBlockedNamespaceIP(ip net.IP) bool { 333 if ip == nil { 334 return true 335 } 336 if ip4 := ip.To4(); ip4 != nil { 337 return isBlockedByNamespaceNets(ip4, 32) 338 } 339 return isBlockedByNamespaceNets(ip, 128) 340} 341 342func isBlockedByNamespaceNets(ip net.IP, bits int) bool { 343 for _, blockedNet := range blockedNamespaceNets { 344 if blockedNet == nil { 345 continue 346 } 347 348 _, blockedBits := blockedNet.Mask.Size() 349 if blockedBits != bits { 350 continue 351 } 352 if blockedNet.Contains(ip) { 353 return true 354 } 355 } 356 return false 357} 358 359func rcodeResponse(req *dns.Msg, rcode int) *dns.Msg { 360 resp := new(dns.Msg) 361 if req == nil { 362 resp.Rcode = rcode 363 return resp 364 } 365 resp.SetRcode(req, rcode) 366 return resp 367} 368 369func dnsQuestionLogValue(msg *dns.Msg) string { 370 if msg == nil || len(msg.Question) == 0 { 371 return "" 372 } 373 374 q := msg.Question[0] 375 qtype := dns.TypeToString[q.Qtype] 376 if qtype == "" { 377 qtype = fmt.Sprintf("TYPE%d", q.Qtype) 378 } 379 return fmt.Sprintf("%s/%s", q.Name, qtype) 380}