Monorepo for Tangled
tangled.org
1package microvm
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "log/slog"
9 "net"
10 "sync"
11 "time"
12
13 "github.com/mdlayher/vsock"
14
15 "tangled.org/core/spindle/agentproto"
16 agentv1 "tangled.org/core/spindle/agentproto/gen"
17)
18
// guestWorkflowUser is the guest account name "spindle-workflow".
// NOTE(review): not referenced in this part of the file; presumably used
// elsewhere in the package when running workflow steps in the guest — confirm.
const guestWorkflowUser = "spindle-workflow"

// errGuestTimedOut is what AgentSession.Exec returns (alongside the exit code)
// when the agent's exec exit message reports that the step timed out.
var errGuestTimedOut error = errors.New("guest reported step timed out")
22
23type agentHub struct {
24 l *slog.Logger
25 ln *vsock.Listener
26 pending map[uint32]chan net.Conn
27 mu sync.Mutex
28}
29
30func newAgentHub(port uint32, l *slog.Logger) (*agentHub, error) {
31 ln, err := vsock.Listen(port, nil)
32 if err != nil {
33 return nil, fmt.Errorf("listen for agent on vsock port %d: %w", port, err)
34 }
35 h := &agentHub{
36 l: l,
37 ln: ln,
38 pending: make(map[uint32]chan net.Conn),
39 }
40 go h.acceptLoop()
41 return h, nil
42}
43
44func (h *agentHub) expect(cid uint32) (<-chan net.Conn, func(), error) {
45 h.mu.Lock()
46 defer h.mu.Unlock()
47 if _, exists := h.pending[cid]; exists {
48 return nil, nil, fmt.Errorf("already waiting for agent cid %d", cid)
49 }
50 ch := make(chan net.Conn, 1)
51 h.pending[cid] = ch
52 unregister := func() {
53 h.mu.Lock()
54 delete(h.pending, cid)
55 h.mu.Unlock()
56 close(ch)
57 for conn := range ch {
58 if conn != nil {
59 _ = conn.Close()
60 }
61 }
62 }
63 return ch, unregister, nil
64}
65
66func (h *agentHub) acceptLoop() {
67 for {
68 conn, err := h.ln.Accept()
69 if err != nil {
70 h.l.Error("agent vsock accept failed", "error", err)
71 return
72 }
73
74 addr, ok := conn.RemoteAddr().(*vsock.Addr)
75 if !ok {
76 h.l.Warn("agent connection has unexpected remote address", "remote", conn.RemoteAddr())
77 _ = conn.Close()
78 continue
79 }
80
81 h.mu.Lock()
82 ch, ok := h.pending[addr.ContextID]
83 if ok {
84 delete(h.pending, addr.ContextID)
85 }
86 h.mu.Unlock()
87
88 // todo: if / when we add agent recovery (reconnect) we should add a
89 // boot-initialized session credential to prevent random connections...
90 // checking cid here works to ensure for now since we dont attempt to
91 // reconnect, so we block anything else thats not expected (and agent
92 // runs first in the boot sequence always).
93 if !ok {
94 h.l.Warn("dropping agent connection for unknown cid", "cid", addr.ContextID)
95 _ = conn.Close()
96 continue
97 }
98
99 select {
100 case ch <- conn:
101 default:
102 _ = conn.Close()
103 }
104 }
105}
106
107type AgentExec struct {
108 *agentv1.ExecStart
109 ID string
110 Stdout io.Writer
111 Stderr io.Writer
112}
113
114type AgentSession struct {
115 conn net.Conn
116 enc *agentproto.Encoder
117 dec *agentproto.Decoder
118 l *slog.Logger
119 mu sync.Mutex
120}
121
122func NewAgentSession(conn net.Conn, l *slog.Logger) *AgentSession {
123 return &AgentSession{
124 conn: conn,
125 enc: agentproto.NewEncoder(conn),
126 dec: agentproto.NewDecoder(conn),
127 l: l,
128 }
129}
130
131func (s *AgentSession) Init(ctx context.Context, init *agentv1.Init) error {
132 s.mu.Lock()
133 defer s.mu.Unlock()
134
135 hello, err := s.decode(ctx)
136 if err != nil {
137 return fmt.Errorf("read agent hello: %w", err)
138 }
139 helloPayload := hello.Hello
140 if helloPayload == nil {
141 return fmt.Errorf("expected agent hello, got nil")
142 }
143 s.l.Info("agent connected", "protocol", helloPayload.ProtocolVersion, "version", helloPayload.AgentVersion, "boot", helloPayload.BootId, "nix", helloPayload.NixVersion)
144
145 if err := s.enc.Encode(&agentproto.Message{
146 Id: "init",
147 Init: init,
148 }); err != nil {
149 return fmt.Errorf("send agent init: %w", err)
150 }
151 return nil
152}
153
154func (s *AgentSession) Exec(ctx context.Context, exec AgentExec) (int, error) {
155 s.mu.Lock()
156 defer s.mu.Unlock()
157
158 if exec.ID == "" {
159 return 0, fmt.Errorf("empty ID passed to Exec")
160 }
161
162 if exec.ExecStart.TimeoutSeconds == 0 {
163 exec.ExecStart.TimeoutSeconds = timeoutSeconds(ctx, guestTimeoutGrace)
164 }
165
166 if err := s.enc.Encode(&agentproto.Message{
167 Id: exec.ID,
168 ExecStart: exec.ExecStart,
169 }); err != nil {
170 return 0, fmt.Errorf("send exec_start: %w", err)
171 }
172
173 for {
174 msg, err := s.decode(ctx)
175 if err != nil {
176 return 0, err
177 }
178 if msg.BuiltPaths == nil && msg.Id != exec.ID {
179 continue
180 }
181
182 if p := msg.ExecStdout; p != nil {
183 _, _ = io.WriteString(exec.Stdout, p.Data)
184 } else if p := msg.ExecStderr; p != nil {
185 _, _ = io.WriteString(exec.Stderr, p.Data)
186 } else if p := msg.BuiltPaths; p != nil {
187 // s.l.Debug("guest built paths", "reason", p.Reason, "count", len(p.Paths))
188 } else if p := msg.ExecExit; p != nil {
189 if p.Error != "" {
190 s.l.Warn("guest exec error", "id", msg.Id, "error", p.Error)
191 }
192 if p.TimedOut {
193 return int(p.ExitCode), errGuestTimedOut
194 }
195 return int(p.ExitCode), nil
196 }
197 }
198}
199
200func (s *AgentSession) ActivateConfig(ctx context.Context, id string, req *agentv1.ActivateConfig, out io.Writer) (*agentv1.ActivateConfigResult, error) {
201 s.mu.Lock()
202 defer s.mu.Unlock()
203
204 if id == "" {
205 return nil, fmt.Errorf("empty ID passed to ActivateConfig")
206 }
207 if req.TimeoutSeconds == 0 {
208 req.TimeoutSeconds = timeoutSeconds(ctx, guestTimeoutGrace)
209 }
210 if err := s.enc.Encode(&agentproto.Message{
211 Id: id,
212 ActivateConfig: req,
213 }); err != nil {
214 return nil, fmt.Errorf("send activate_config: %w", err)
215 }
216
217 for {
218 msg, err := s.decode(ctx)
219 if err != nil {
220 return nil, err
221 }
222 if msg.BuiltPaths == nil && msg.Id != id {
223 continue
224 }
225
226 if p := msg.BuiltPaths; p != nil {
227 // s.l.Debug("guest built paths", "reason", p.Reason, "count", len(p.Paths))
228 } else if p := msg.ExecStderr; p != nil {
229 if out != nil {
230 _, _ = io.WriteString(out, p.Data)
231 }
232 } else if p := msg.ExecStdout; p != nil {
233 if out != nil {
234 _, _ = io.WriteString(out, p.Data)
235 }
236 } else if p := msg.ActivateConfigResult; p != nil {
237 if p.Error != "" {
238 return nil, fmt.Errorf("activate config failed: %s", p.Error)
239 }
240 if p.Toplevel == "" {
241 return nil, fmt.Errorf("activate config returned empty toplevel")
242 }
243 return p, nil
244 }
245 }
246}
247
248func (s *AgentSession) Poweroff(ctx context.Context) error {
249 s.mu.Lock()
250 defer s.mu.Unlock()
251
252 id := "poweroff"
253 if err := s.enc.Encode(&agentproto.Message{
254 Id: id,
255 Poweroff: &agentv1.Poweroff{},
256 }); err != nil {
257 return fmt.Errorf("send poweroff: %w", err)
258 }
259
260 for {
261 msg, err := s.decode(ctx)
262 if err != nil {
263 return err
264 }
265 if msg.Id != id {
266 continue
267 }
268 p := msg.PoweroffResult
269 if p == nil {
270 continue
271 }
272 if p.Error != "" {
273 return fmt.Errorf("guest poweroff failed: %s", p.Error)
274 }
275 return nil
276 }
277}
278
279func (s *AgentSession) Drain(ctx context.Context) (uint32, error) {
280 s.mu.Lock()
281 defer s.mu.Unlock()
282
283 drainID := "cache-drain"
284 if err := s.enc.Encode(&agentproto.Message{
285 Id: drainID,
286 CacheDrain: &agentv1.CacheDrain{
287 TimeoutSeconds: timeoutSeconds(ctx, 0),
288 },
289 }); err != nil {
290 return 0, fmt.Errorf("send cache_drain: %w", err)
291 }
292
293 for {
294 msg, err := s.decode(ctx)
295 if err != nil {
296 return 0, err
297 }
298 if msg.Id != drainID {
299 continue
300 }
301 p := msg.CacheDrainResult
302 if p == nil {
303 continue
304 }
305 s.l.Info("cache drain complete", "uploaded", p.CacheUploaded, "failed", p.CacheFailed, "queued", p.CacheQueued, "active", p.CacheActive)
306 if p.Error != "" {
307 return 0, fmt.Errorf("cache drain failed: %s", p.Error)
308 }
309 if p.CacheFailed > 0 {
310 return 0, fmt.Errorf("cache drain failed for %d paths", p.CacheFailed)
311 }
312 if p.CacheQueued > 0 || p.CacheActive > 0 {
313 return 0, fmt.Errorf("cache drain incomplete: queued=%d active=%d", p.CacheQueued, p.CacheActive)
314 }
315 return p.CacheUploaded, nil
316 }
317}
318
319func (s *AgentSession) decode(ctx context.Context) (*agentproto.Message, error) {
320 if err := ctx.Err(); err != nil {
321 return nil, err
322 }
323
324 if deadline, ok := ctx.Deadline(); ok {
325 _ = s.conn.SetReadDeadline(deadline)
326 } else {
327 _ = s.conn.SetReadDeadline(time.Time{})
328 }
329
330 // a blocked vsock read wont wake up just from the ctx being cancelled,
331 // only a deadline will wake it up, so if the VM crashes mid-step the read would
332 // hang until workflow timeout. so we will set a deadline in the past to cancel it.
333 //
334 // we set a deadline here instead of closing the connection, this is the long-lived
335 // connection that everything reuses, so we only really want to interrupt it for this
336 // current read. this also lands as a timeout error which the netErr.Timeout() check
337 // below maps to ctx.Err() correctly
338 stop := context.AfterFunc(ctx, func() {
339 _ = s.conn.SetReadDeadline(time.Now())
340 })
341 defer stop()
342
343 msg, err := s.dec.Decode()
344 if err != nil {
345 var netErr net.Error
346 if errors.As(err, &netErr) && netErr.Timeout() && ctx.Err() != nil {
347 return nil, ctx.Err()
348 }
349 return nil, fmt.Errorf("read agent message: %w", err)
350 }
351 return msg, nil
352}
353
354func (s *AgentSession) Close() error {
355 if s == nil || s.conn == nil {
356 return nil
357 }
358 return s.conn.Close()
359}
360
// timeoutSeconds converts ctx's deadline into the whole-second timeout the
// agent protocol expects, reserving lead for the host side.
//
// It returns 0 when ctx has no deadline. Otherwise the remaining time minus
// lead is truncated (not rounded) to whole seconds, so the guest is never given
// more time than the host will actually wait. The result is clamped to at least
// 1, so an almost-expired deadline is not sent as 0 (which callers in this file
// treat as "unset"), and to at most the uint32 maximum.
func timeoutSeconds(ctx context.Context, lead time.Duration) uint32 {
	deadline, ok := ctx.Deadline()
	if !ok {
		return 0
	}
	// Duration division truncates toward zero.
	seconds := int64((time.Until(deadline) - lead) / time.Second)
	switch {
	case seconds < 1:
		return 1
	case seconds > int64(^uint32(0)):
		return ^uint32(0)
	}
	return uint32(seconds)
}