Monorepo for Tangled (tangled.org)
5

Configure Feed

Select the types of activity you want to include in your feed.

1package microvm 2 3import ( 4 "context" 5 "crypto/rand" 6 "encoding/binary" 7 "errors" 8 "fmt" 9 "io" 10 "log/slog" 11 "net" 12 "net/http" 13 "net/http/httputil" 14 "net/url" 15 "strings" 16 "sync" 17 "syscall" 18 "time" 19 20 "github.com/mdlayher/vsock" 21) 22 23const ( 24 readCacheProxyPortMin = 20000 25 readCacheProxyPortMax = 60000 26) 27 28type ReadCacheProxy struct { 29 port uint32 30 31 ln *vsock.Listener 32 server *http.Server 33} 34 35func StartReadCacheProxy(ctx context.Context, cid uint32, upstreams []CacheUpstream, logger *slog.Logger) (*ReadCacheProxy, error) { 36 if logger == nil { 37 logger = slog.Default() 38 } 39 logger = logger.With("where", "read_cache", "cid", cid) 40 41 if len(upstreams) == 0 { 42 return nil, nil 43 } 44 45 ln, port, err := listenRandomVsockPort(ctx) 46 if err != nil { 47 return nil, err 48 } 49 50 proxy := &ReadCacheProxy{ 51 port: port, 52 ln: ln, 53 } 54 proxy.server = &http.Server{ 55 Handler: cacheProxyHandler(upstreams, logger), 56 Protocols: cacheProxyProtocols(), 57 ReadHeaderTimeout: 10 * time.Second, 58 } 59 60 filtered := &cidFilteredVsockListener{ 61 Listener: ln, 62 cid: cid, 63 logger: logger, 64 } 65 go func() { 66 if err := proxy.server.Serve(filtered); err != nil && !errors.Is(err, http.ErrServerClosed) && !errors.Is(err, net.ErrClosed) { 67 logger.Warn("proxy stopped", "cid", cid, "port", port, "error", err) 68 } 69 }() 70 71 logger.Info("started proxy", "cid", cid, "port", port, "upstreams", len(upstreams)) 72 return proxy, nil 73} 74 75func (p *ReadCacheProxy) Port() uint32 { 76 if p == nil { 77 return 0 78 } 79 return p.port 80} 81 82func (p *ReadCacheProxy) Close() error { 83 if p == nil { 84 return nil 85 } 86 87 var closeErr error 88 if p.server != nil { 89 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 90 closeErr = errors.Join(closeErr, p.server.Shutdown(ctx)) 91 cancel() 92 p.server = nil 93 } 94 if p.ln != nil { 95 closeErr = errors.Join(closeErr, p.ln.Close()) 96 p.ln = nil 
97 } 98 return closeErr 99} 100 101type cidFilteredVsockListener struct { 102 *vsock.Listener 103 cid uint32 104 logger *slog.Logger 105} 106 107func (l *cidFilteredVsockListener) Accept() (net.Conn, error) { 108 for { 109 conn, err := l.Listener.Accept() 110 if err != nil { 111 return nil, err 112 } 113 114 addr, ok := conn.RemoteAddr().(*vsock.Addr) 115 if ok && addr.ContextID == l.cid { 116 return conn, nil 117 } 118 119 l.logger.Warn("dropping proxy connection from unexpected cid", "remote", conn.RemoteAddr(), "expectedCID", l.cid) 120 _ = conn.Close() 121 } 122} 123 124func parseCacheUpstreams(raw []string) ([]*url.URL, error) { 125 upstreams := make([]*url.URL, 0, len(raw)) 126 seen := make(map[string]struct{}, len(raw)) 127 for _, value := range raw { 128 value = strings.TrimSpace(value) 129 if value == "" { 130 continue 131 } 132 if _, ok := seen[value]; ok { 133 continue 134 } 135 seen[value] = struct{}{} 136 137 parsed, err := url.Parse(value) 138 if err != nil { 139 return nil, fmt.Errorf("parse URL %q: %w", value, err) 140 } 141 if parsed.Scheme != "http" && parsed.Scheme != "https" { 142 return nil, fmt.Errorf("URL %q uses unsupported scheme %q", value, parsed.Scheme) 143 } 144 if parsed.Host == "" { 145 return nil, fmt.Errorf("URL %q is missing host", value) 146 } 147 upstreams = append(upstreams, parsed) 148 } 149 return upstreams, nil 150} 151 152type CacheUpstream struct { 153 url *url.URL 154 // guarded upstreams come from the workflow file 155 // requests to them are refused for special-purpose address ranges 156 guarded bool 157} 158 159func BuildCacheUpstreams(rawTrusted, rawGuarded []string) ([]CacheUpstream, error) { 160 trusted, err := parseCacheUpstreams(rawTrusted) 161 if err != nil { 162 return nil, err 163 } 164 guarded, err := parseCacheUpstreams(rawGuarded) 165 if err != nil { 166 return nil, err 167 } 168 return mergeCacheUpstreams(trusted, guarded), nil 169} 170 171func mergeCacheUpstreams(trusted, guarded []*url.URL) []CacheUpstream 
{ 172 merged := make([]CacheUpstream, 0, len(trusted)+len(guarded)) 173 seen := make(map[string]struct{}, len(trusted)+len(guarded)) 174 for _, u := range trusted { 175 if _, ok := seen[u.String()]; ok { 176 continue 177 } 178 seen[u.String()] = struct{}{} 179 merged = append(merged, CacheUpstream{url: u}) 180 } 181 for _, u := range guarded { 182 if _, ok := seen[u.String()]; ok { 183 continue 184 } 185 seen[u.String()] = struct{}{} 186 merged = append(merged, CacheUpstream{url: u, guarded: true}) 187 } 188 return merged 189} 190 191func listenRandomVsockPort(ctx context.Context) (*vsock.Listener, uint32, error) { 192 var lastErr error 193 for range 32 { 194 port, err := randomVsockPort() 195 if err != nil { 196 return nil, 0, err 197 } 198 ln, err := vsock.Listen(port, nil) 199 if err == nil { 200 return ln, port, nil 201 } 202 lastErr = err 203 204 select { 205 case <-ctx.Done(): 206 return nil, 0, ctx.Err() 207 default: 208 } 209 } 210 return nil, 0, fmt.Errorf("listen on random vsock port: %w", lastErr) 211} 212 213func randomVsockPort() (uint32, error) { 214 var data [4]byte 215 if _, err := rand.Read(data[:]); err != nil { 216 return 0, fmt.Errorf("allocate read vsock port: %w", err) 217 } 218 span := uint32(readCacheProxyPortMax - readCacheProxyPortMin) 219 return readCacheProxyPortMin + binary.BigEndian.Uint32(data[:])%span, nil 220} 221 222var proxyTransport = &http.Transport{ 223 Proxy: http.ProxyFromEnvironment, 224 ForceAttemptHTTP2: true, 225 MaxIdleConns: 100, 226 IdleConnTimeout: 90 * time.Second, 227 TLSHandshakeTimeout: 10 * time.Second, 228 ExpectContinueTimeout: 1 * time.Second, 229} 230 231// for guarded upstreams, this will refuse requests made to blocked addresses 232var guardedProxyTransport = &http.Transport{ 233 DialContext: (&net.Dialer{ 234 Timeout: 30 * time.Second, 235 KeepAlive: 30 * time.Second, 236 Control: refuseSpecialPurposeAddrs, 237 }).DialContext, 238 ForceAttemptHTTP2: true, 239 MaxIdleConns: 100, 240 IdleConnTimeout: 90 * 
time.Second, 241 TLSHandshakeTimeout: 10 * time.Second, 242 ExpectContinueTimeout: 1 * time.Second, 243} 244 245// this should run after dns resolution, so it should cover any rebinding tricks 246func refuseSpecialPurposeAddrs(network, address string, _ syscall.RawConn) error { 247 host, _, err := net.SplitHostPort(address) 248 if err != nil { 249 return fmt.Errorf("split dial address %q: %w", address, err) 250 } 251 ip := net.ParseIP(host) 252 if ip == nil { 253 return fmt.Errorf("refusing to dial non-IP address %q", host) 254 } 255 for _, ipnet := range blockedNamespaceNets { 256 if ipnet.Contains(ip) { 257 return fmt.Errorf("refusing to dial %s: %s is blocked for workflow caches", ip, ipnet) 258 } 259 } 260 return nil 261} 262 263// the proxy is the cache as far as the guest is concerned, so we answer 264// /nix-cache-info ourselves instead of racing the upstreams for it. merging 265// those also doesn't make any sense (none of the options make sense for 266// merging) 267const nixCacheInfo = "StoreDir: /nix/store\nWantMassQuery: 1\nPriority: 40\n" 268 269func cacheProxyHandler(upstreams []CacheUpstream, logger *slog.Logger) http.Handler { 270 proxy := &httputil.ReverseProxy{ 271 // nothing to do here: the racing transport builds the full URL per 272 // upstream, it just needs the guest's path/query left intact 273 Rewrite: func(*httputil.ProxyRequest) {}, 274 ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError), 275 Transport: &parallelRacingTransport{ 276 upstreams: upstreams, 277 underlying: proxyTransport, 278 guardedUnderlying: guardedProxyTransport, 279 logger: logger, 280 }, 281 } 282 283 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 284 if r.URL.Path == "/nix-cache-info" { 285 w.Header().Set("Content-Type", "text/x-nix-cache-info") 286 _, _ = io.WriteString(w, nixCacheInfo) 287 return 288 } 289 proxy.ServeHTTP(w, r) 290 }) 291} 292 293func cacheProxyProtocols() *http.Protocols { 294 protocols := new(http.Protocols) 295 
protocols.SetHTTP1(true) 296 protocols.SetUnencryptedHTTP2(true) 297 return protocols 298} 299 300func mergeQuery(base, extra string) string { 301 switch { 302 case base == "": 303 return extra 304 case extra == "": 305 return base 306 default: 307 return base + "&" + extra 308 } 309} 310 311type parallelRacingTransport struct { 312 upstreams []CacheUpstream 313 underlying http.RoundTripper 314 guardedUnderlying http.RoundTripper 315 logger *slog.Logger 316} 317 318func (t *parallelRacingTransport) RoundTrip(req *http.Request) (*http.Response, error) { 319 type result struct { 320 resp *http.Response 321 err error 322 is404 bool 323 idx int 324 } 325 326 resCh := make(chan result, len(t.upstreams)) 327 cancels := make([]context.CancelFunc, len(t.upstreams)) 328 var wg sync.WaitGroup 329 330 for i, upstream := range t.upstreams { 331 wg.Add(1) 332 ctx, cancel := context.WithCancel(req.Context()) 333 cancels[i] = cancel 334 335 go func(idx int, target CacheUpstream, uCtx context.Context) { 336 defer wg.Done() 337 338 raceReq := req.Clone(uCtx) 339 // rewrite to the target, joining the upstream's base path/query 340 // with what the guest asked for 341 raceReq.URL.Scheme = target.url.Scheme 342 raceReq.URL.Host = target.url.Host 343 raceReq.URL.Path = strings.TrimSuffix(target.url.Path, "/") + req.URL.Path 344 raceReq.URL.RawQuery = mergeQuery(target.url.RawQuery, req.URL.RawQuery) 345 // Host wins over URL.Host for the outgoing Host header, and the 346 // reverse proxy preserves the guest's (127.0.0.1:<port>), which 347 // host-routed upstreams like fastly reject with a 421 348 raceReq.Host = target.url.Host 349 // the transport doesn't turn URL userinfo into basic auth, only 350 // http.Client does, so do it ourselves 351 if user := target.url.User; user != nil { 352 password, _ := user.Password() 353 raceReq.SetBasicAuth(user.Username(), password) 354 } 355 356 rt := t.underlying 357 if target.guarded { 358 rt = t.guardedUnderlying 359 } 360 resp, err := 
rt.RoundTrip(raceReq) 361 if err != nil { 362 resCh <- result{err: err, idx: idx} 363 return 364 } 365 if resp.StatusCode == http.StatusNotFound { 366 _ = resp.Body.Close() // don't care about the body of a 404 367 resCh <- result{is404: true, idx: idx} 368 return 369 } 370 if resp.StatusCode >= 400 { 371 // an erroring upstream must not win over a healthy one 372 _ = resp.Body.Close() 373 resCh <- result{err: fmt.Errorf("upstream returned status %d", resp.StatusCode), idx: idx} 374 return 375 } 376 // yay, ok 377 resCh <- result{resp: resp, idx: idx} 378 }(i, upstream, ctx) 379 } 380 381 go func() { 382 wg.Wait() 383 close(resCh) 384 }() 385 386 var total404s int 387 for res := range resCh { 388 if res.is404 { 389 total404s++ 390 if total404s == len(t.upstreams) { 391 for _, cancel := range cancels { 392 cancel() 393 } 394 return &http.Response{ 395 StatusCode: http.StatusNotFound, 396 Body: io.NopCloser(strings.NewReader("404 nix path not found")), 397 Header: make(http.Header), 398 Request: req, 399 }, nil 400 } 401 continue 402 } 403 404 if res.err != nil { 405 if !errors.Is(res.err, context.Canceled) { 406 t.logger.Warn("upstream failed", 407 "path", req.URL.Path, 408 "error", res.err, 409 ) 410 } 411 continue 412 } 413 414 // cancel other requests 415 for i, cancel := range cancels { 416 if i != res.idx { 417 cancel() 418 } 419 } 420 return res.resp, nil 421 } 422 423 for _, cancel := range cancels { 424 cancel() 425 } 426 return nil, errors.New("all upstreams failed or timed out") 427}