Monorepo for Tangled
tangled.org
1package microvm
2
3import (
4 "context"
5 "crypto/rand"
6 "encoding/binary"
7 "errors"
8 "fmt"
9 "io"
10 "log/slog"
11 "net"
12 "net/http"
13 "net/http/httputil"
14 "net/url"
15 "strings"
16 "sync"
17 "syscall"
18 "time"
19
20 "github.com/mdlayher/vsock"
21)
22
// Bounds of the vsock port range the read-cache proxy picks its random
// listening port from.
const (
	// readCacheProxyPortMin is the lowest port that can be picked.
	readCacheProxyPortMin = 20000
	// readCacheProxyPortMax is the upper bound of the range. randomVsockPort
	// reduces modulo (max - min), so this value itself is never picked.
	readCacheProxyPortMax = 60000
)
27
28type ReadCacheProxy struct {
29 port uint32
30
31 ln *vsock.Listener
32 server *http.Server
33}
34
35func StartReadCacheProxy(ctx context.Context, cid uint32, upstreams []CacheUpstream, logger *slog.Logger) (*ReadCacheProxy, error) {
36 if logger == nil {
37 logger = slog.Default()
38 }
39 logger = logger.With("where", "read_cache", "cid", cid)
40
41 if len(upstreams) == 0 {
42 return nil, nil
43 }
44
45 ln, port, err := listenRandomVsockPort(ctx)
46 if err != nil {
47 return nil, err
48 }
49
50 proxy := &ReadCacheProxy{
51 port: port,
52 ln: ln,
53 }
54 proxy.server = &http.Server{
55 Handler: cacheProxyHandler(upstreams, logger),
56 Protocols: cacheProxyProtocols(),
57 ReadHeaderTimeout: 10 * time.Second,
58 }
59
60 filtered := &cidFilteredVsockListener{
61 Listener: ln,
62 cid: cid,
63 logger: logger,
64 }
65 go func() {
66 if err := proxy.server.Serve(filtered); err != nil && !errors.Is(err, http.ErrServerClosed) && !errors.Is(err, net.ErrClosed) {
67 logger.Warn("proxy stopped", "cid", cid, "port", port, "error", err)
68 }
69 }()
70
71 logger.Info("started proxy", "cid", cid, "port", port, "upstreams", len(upstreams))
72 return proxy, nil
73}
74
75func (p *ReadCacheProxy) Port() uint32 {
76 if p == nil {
77 return 0
78 }
79 return p.port
80}
81
82func (p *ReadCacheProxy) Close() error {
83 if p == nil {
84 return nil
85 }
86
87 var closeErr error
88 if p.server != nil {
89 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
90 closeErr = errors.Join(closeErr, p.server.Shutdown(ctx))
91 cancel()
92 p.server = nil
93 }
94 if p.ln != nil {
95 closeErr = errors.Join(closeErr, p.ln.Close())
96 p.ln = nil
97 }
98 return closeErr
99}
100
101type cidFilteredVsockListener struct {
102 *vsock.Listener
103 cid uint32
104 logger *slog.Logger
105}
106
107func (l *cidFilteredVsockListener) Accept() (net.Conn, error) {
108 for {
109 conn, err := l.Listener.Accept()
110 if err != nil {
111 return nil, err
112 }
113
114 addr, ok := conn.RemoteAddr().(*vsock.Addr)
115 if ok && addr.ContextID == l.cid {
116 return conn, nil
117 }
118
119 l.logger.Warn("dropping proxy connection from unexpected cid", "remote", conn.RemoteAddr(), "expectedCID", l.cid)
120 _ = conn.Close()
121 }
122}
123
// parseCacheUpstreams turns raw upstream strings into URLs, keeping input
// order. Surrounding whitespace is trimmed, blank entries are skipped and
// exact duplicates (compared after trimming) are dropped.
//
// Every remaining entry must be an http or https URL that names a host. The
// first entry that is not aborts the whole call with an error naming it.
func parseCacheUpstreams(raw []string) ([]*url.URL, error) {
	upstreams := make([]*url.URL, 0, len(raw))
	seen := make(map[string]struct{}, len(raw))
	for _, value := range raw {
		value = strings.TrimSpace(value)
		if value == "" {
			continue
		}
		if _, dup := seen[value]; dup {
			continue
		}
		seen[value] = struct{}{}

		parsed, err := url.Parse(value)
		if err != nil {
			return nil, fmt.Errorf("parse URL %q: %w", value, err)
		}
		if parsed.Scheme != "http" && parsed.Scheme != "https" {
			return nil, fmt.Errorf("URL %q uses unsupported scheme %q", value, parsed.Scheme)
		}
		// Hostname rather than Host: a port-only authority such as
		// "http://:8080" has a non-empty Host (":8080") but names no host
		if parsed.Hostname() == "" {
			return nil, fmt.Errorf("URL %q is missing host", value)
		}
		upstreams = append(upstreams, parsed)
	}
	return upstreams, nil
}
151
// CacheUpstream is one upstream cache the proxy can fetch from.
type CacheUpstream struct {
	// url is the upstream's base URL.
	url *url.URL
	// guarded marks an upstream that comes from the workflow file. Requests
	// to guarded upstreams are refused for special-purpose address ranges.
	guarded bool
}
158
159func BuildCacheUpstreams(rawTrusted, rawGuarded []string) ([]CacheUpstream, error) {
160 trusted, err := parseCacheUpstreams(rawTrusted)
161 if err != nil {
162 return nil, err
163 }
164 guarded, err := parseCacheUpstreams(rawGuarded)
165 if err != nil {
166 return nil, err
167 }
168 return mergeCacheUpstreams(trusted, guarded), nil
169}
170
171func mergeCacheUpstreams(trusted, guarded []*url.URL) []CacheUpstream {
172 merged := make([]CacheUpstream, 0, len(trusted)+len(guarded))
173 seen := make(map[string]struct{}, len(trusted)+len(guarded))
174 for _, u := range trusted {
175 if _, ok := seen[u.String()]; ok {
176 continue
177 }
178 seen[u.String()] = struct{}{}
179 merged = append(merged, CacheUpstream{url: u})
180 }
181 for _, u := range guarded {
182 if _, ok := seen[u.String()]; ok {
183 continue
184 }
185 seen[u.String()] = struct{}{}
186 merged = append(merged, CacheUpstream{url: u, guarded: true})
187 }
188 return merged
189}
190
191func listenRandomVsockPort(ctx context.Context) (*vsock.Listener, uint32, error) {
192 var lastErr error
193 for range 32 {
194 port, err := randomVsockPort()
195 if err != nil {
196 return nil, 0, err
197 }
198 ln, err := vsock.Listen(port, nil)
199 if err == nil {
200 return ln, port, nil
201 }
202 lastErr = err
203
204 select {
205 case <-ctx.Done():
206 return nil, 0, ctx.Err()
207 default:
208 }
209 }
210 return nil, 0, fmt.Errorf("listen on random vsock port: %w", lastErr)
211}
212
213func randomVsockPort() (uint32, error) {
214 var data [4]byte
215 if _, err := rand.Read(data[:]); err != nil {
216 return 0, fmt.Errorf("allocate read vsock port: %w", err)
217 }
218 span := uint32(readCacheProxyPortMax - readCacheProxyPortMin)
219 return readCacheProxyPortMin + binary.BigEndian.Uint32(data[:])%span, nil
220}
221
// proxyTransport is the transport for upstreams that are not guarded. It
// honours the proxy settings from the environment and uses the default dialer.
var proxyTransport = &http.Transport{
	ForceAttemptHTTP2:     true,
	Proxy:                 http.ProxyFromEnvironment,
	TLSHandshakeTimeout:   10 * time.Second,
	ExpectContinueTimeout: time.Second,
	IdleConnTimeout:       90 * time.Second,
	MaxIdleConns:          100,
}
230
231// for guarded upstreams, this will refuse requests made to blocked addresses
232var guardedProxyTransport = &http.Transport{
233 DialContext: (&net.Dialer{
234 Timeout: 30 * time.Second,
235 KeepAlive: 30 * time.Second,
236 Control: refuseSpecialPurposeAddrs,
237 }).DialContext,
238 ForceAttemptHTTP2: true,
239 MaxIdleConns: 100,
240 IdleConnTimeout: 90 * time.Second,
241 TLSHandshakeTimeout: 10 * time.Second,
242 ExpectContinueTimeout: 1 * time.Second,
243}
244
245// this should run after dns resolution, so it should cover any rebinding tricks
246func refuseSpecialPurposeAddrs(network, address string, _ syscall.RawConn) error {
247 host, _, err := net.SplitHostPort(address)
248 if err != nil {
249 return fmt.Errorf("split dial address %q: %w", address, err)
250 }
251 ip := net.ParseIP(host)
252 if ip == nil {
253 return fmt.Errorf("refusing to dial non-IP address %q", host)
254 }
255 for _, ipnet := range blockedNamespaceNets {
256 if ipnet.Contains(ip) {
257 return fmt.Errorf("refusing to dial %s: %s is blocked for workflow caches", ip, ipnet)
258 }
259 }
260 return nil
261}
262
// nixCacheInfo is the fixed /nix-cache-info document served to the guest.
//
// the proxy is the cache as far as the guest is concerned, so we answer
// /nix-cache-info ourselves instead of racing the upstreams for it. merging
// the upstreams' answers would not make sense either: none of the fields can
// be combined in a meaningful way
const nixCacheInfo = "StoreDir: /nix/store\n" +
	"WantMassQuery: 1\n" +
	"Priority: 40\n"
268
269func cacheProxyHandler(upstreams []CacheUpstream, logger *slog.Logger) http.Handler {
270 proxy := &httputil.ReverseProxy{
271 // nothing to do here: the racing transport builds the full URL per
272 // upstream, it just needs the guest's path/query left intact
273 Rewrite: func(*httputil.ProxyRequest) {},
274 ErrorLog: slog.NewLogLogger(logger.Handler(), slog.LevelError),
275 Transport: ¶llelRacingTransport{
276 upstreams: upstreams,
277 underlying: proxyTransport,
278 guardedUnderlying: guardedProxyTransport,
279 logger: logger,
280 },
281 }
282
283 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
284 if r.URL.Path == "/nix-cache-info" {
285 w.Header().Set("Content-Type", "text/x-nix-cache-info")
286 _, _ = io.WriteString(w, nixCacheInfo)
287 return
288 }
289 proxy.ServeHTTP(w, r)
290 })
291}
292
// cacheProxyProtocols returns the protocol set used by the proxy's HTTP
// server: HTTP/1 and unencrypted HTTP/2 are switched on, nothing else is.
func cacheProxyProtocols() *http.Protocols {
	var protocols http.Protocols
	protocols.SetUnencryptedHTTP2(true)
	protocols.SetHTTP1(true)
	return &protocols
}
299
// mergeQuery joins two raw query strings with "&". If either one is empty the
// other is returned unchanged, so two empty inputs yield "".
func mergeQuery(base, extra string) string {
	if base == "" {
		return extra
	}
	if extra == "" {
		return base
	}
	return base + "&" + extra
}
310
311type parallelRacingTransport struct {
312 upstreams []CacheUpstream
313 underlying http.RoundTripper
314 guardedUnderlying http.RoundTripper
315 logger *slog.Logger
316}
317
318func (t *parallelRacingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
319 type result struct {
320 resp *http.Response
321 err error
322 is404 bool
323 idx int
324 }
325
326 resCh := make(chan result, len(t.upstreams))
327 cancels := make([]context.CancelFunc, len(t.upstreams))
328 var wg sync.WaitGroup
329
330 for i, upstream := range t.upstreams {
331 wg.Add(1)
332 ctx, cancel := context.WithCancel(req.Context())
333 cancels[i] = cancel
334
335 go func(idx int, target CacheUpstream, uCtx context.Context) {
336 defer wg.Done()
337
338 raceReq := req.Clone(uCtx)
339 // rewrite to the target, joining the upstream's base path/query
340 // with what the guest asked for
341 raceReq.URL.Scheme = target.url.Scheme
342 raceReq.URL.Host = target.url.Host
343 raceReq.URL.Path = strings.TrimSuffix(target.url.Path, "/") + req.URL.Path
344 raceReq.URL.RawQuery = mergeQuery(target.url.RawQuery, req.URL.RawQuery)
345 // Host wins over URL.Host for the outgoing Host header, and the
346 // reverse proxy preserves the guest's (127.0.0.1:<port>), which
347 // host-routed upstreams like fastly reject with a 421
348 raceReq.Host = target.url.Host
349 // the transport doesn't turn URL userinfo into basic auth, only
350 // http.Client does, so do it ourselves
351 if user := target.url.User; user != nil {
352 password, _ := user.Password()
353 raceReq.SetBasicAuth(user.Username(), password)
354 }
355
356 rt := t.underlying
357 if target.guarded {
358 rt = t.guardedUnderlying
359 }
360 resp, err := rt.RoundTrip(raceReq)
361 if err != nil {
362 resCh <- result{err: err, idx: idx}
363 return
364 }
365 if resp.StatusCode == http.StatusNotFound {
366 _ = resp.Body.Close() // don't care about the body of a 404
367 resCh <- result{is404: true, idx: idx}
368 return
369 }
370 if resp.StatusCode >= 400 {
371 // an erroring upstream must not win over a healthy one
372 _ = resp.Body.Close()
373 resCh <- result{err: fmt.Errorf("upstream returned status %d", resp.StatusCode), idx: idx}
374 return
375 }
376 // yay, ok
377 resCh <- result{resp: resp, idx: idx}
378 }(i, upstream, ctx)
379 }
380
381 go func() {
382 wg.Wait()
383 close(resCh)
384 }()
385
386 var total404s int
387 for res := range resCh {
388 if res.is404 {
389 total404s++
390 if total404s == len(t.upstreams) {
391 for _, cancel := range cancels {
392 cancel()
393 }
394 return &http.Response{
395 StatusCode: http.StatusNotFound,
396 Body: io.NopCloser(strings.NewReader("404 nix path not found")),
397 Header: make(http.Header),
398 Request: req,
399 }, nil
400 }
401 continue
402 }
403
404 if res.err != nil {
405 if !errors.Is(res.err, context.Canceled) {
406 t.logger.Warn("upstream failed",
407 "path", req.URL.Path,
408 "error", res.err,
409 )
410 }
411 continue
412 }
413
414 // cancel other requests
415 for i, cancel := range cancels {
416 if i != res.idx {
417 cancel()
418 }
419 }
420 return res.resp, nil
421 }
422
423 for _, cancel := range cancels {
424 cancel()
425 }
426 return nil, errors.New("all upstreams failed or timed out")
427}