Monorepo for Tangled
tangled.org
1package eventconsumer
2
3import (
4 "context"
5 "encoding/json"
6 "log/slog"
7 "net/http"
8 "sync"
9 "time"
10
11 "tangled.org/core/eventconsumer/cursor"
12 "tangled.org/core/eventstream"
13 "tangled.org/core/log"
14
15 "github.com/avast/retry-go/v4"
16 "github.com/gorilla/websocket"
17)
18
// ProcessFunc is the callback invoked by a worker for every event that was
// successfully decoded from a source's stream. A non-nil error is logged by
// the worker; it does not stop the worker.
type ProcessFunc func(ctx context.Context, source Source, event eventstream.Event) error
20
// ConsumerConfig holds the settings for a Consumer. Several fields are given
// defaults by NewConsumer when left unset.
type ConsumerConfig struct {
	// Sources is the initial set of sources; Start adds each of them.
	Sources map[Source]struct{}
	// ProcessFunc is called by workers for every decoded event.
	ProcessFunc ProcessFunc
	// RetryInterval is the base delay between dial retries (default 15m).
	RetryInterval time.Duration
	// MaxRetryInterval caps the backoff delay between dial retries (default 1h).
	MaxRetryInterval time.Duration
	// ConnectionTimeout bounds each individual dial attempt (default 10s).
	ConnectionTimeout time.Duration
	// WorkerCount is the number of worker goroutines started by Start (default 5).
	WorkerCount int
	// QueueSize is the buffer size of the internal job queue (default 100).
	QueueSize int
	// Logger receives the consumer's log output (default log.New("consumer")).
	Logger *slog.Logger
	// CursorStore persists per-source cursors (default &cursor.MemoryStore{}).
	CursorStore cursor.Store

	// Dialer is used to open websocket connections (default websocket.DefaultDialer).
	Dialer *websocket.Dialer
	// RequestHeader is passed along with every websocket dial.
	RequestHeader http.Header
	// MaxRetryAttempts is handed to retry.Attempts for each connection attempt.
	MaxRetryAttempts uint
	// OnConnectExceeded, if set, is called when dialing a source ultimately
	// fails, with the source and the resulting error.
	OnConnectExceeded func(Source, error)
}
37
38func NewConsumerConfig() *ConsumerConfig {
39 return &ConsumerConfig{
40 Sources: make(map[Source]struct{}),
41 }
42}
43
// Consumer reads events from a set of websocket sources and hands them to a
// pool of workers through an internal job queue.
type Consumer struct {
	// sourceWg tracks the per-source connection-loop goroutines.
	sourceWg sync.WaitGroup
	// workerWg tracks the worker goroutines.
	workerWg sync.WaitGroup
	// dialer opens the websocket connections.
	dialer *websocket.Dialer
	// jobQueue carries raw messages from the connection loops to the workers.
	jobQueue chan job
	logger   *slog.Logger

	// sourcesMu guards sources. It must only be held for short, non-blocking
	// map operations; never across a blocking call (dial, read, close).
	sourcesMu sync.Mutex
	// sources maps every registered source to its runtime state.
	sources map[Source]*sourceState

	// cfg is the configuration with defaults applied by NewConsumer.
	cfg ConsumerConfig
}
58
// sourceState is the per-source runtime state kept in Consumer.sources.
type sourceState struct {
	// cancel cancels the context the source's connection loop runs under.
	cancel context.CancelFunc
	// conn is the currently registered websocket connection, or nil when the
	// source is not connected.
	conn *websocket.Conn

	// cursorMu guards cursorMax.
	cursorMu sync.Mutex
	// cursorMax is the highest cursor value written to the cursor store for
	// this source so far.
	cursorMax int64
}
66
// job is one unit of work on the queue: a raw text message and the source it
// was read from.
type job struct {
	source  Source
	message []byte
}
71
72func NewConsumer(cfg ConsumerConfig) *Consumer {
73 if cfg.RetryInterval == 0 {
74 cfg.RetryInterval = 15 * time.Minute
75 }
76 if cfg.ConnectionTimeout == 0 {
77 cfg.ConnectionTimeout = 10 * time.Second
78 }
79 if cfg.WorkerCount <= 0 {
80 cfg.WorkerCount = 5
81 }
82 if cfg.MaxRetryInterval == 0 {
83 cfg.MaxRetryInterval = 1 * time.Hour
84 }
85 if cfg.Logger == nil {
86 cfg.Logger = log.New("consumer")
87 }
88 if cfg.QueueSize == 0 {
89 cfg.QueueSize = 100
90 }
91 if cfg.CursorStore == nil {
92 cfg.CursorStore = &cursor.MemoryStore{}
93 }
94 dialer := cfg.Dialer
95 if dialer == nil {
96 dialer = websocket.DefaultDialer
97 }
98 return &Consumer{
99 cfg: cfg,
100 dialer: dialer,
101 jobQueue: make(chan job, cfg.QueueSize),
102 logger: cfg.Logger,
103 sources: make(map[Source]*sourceState),
104 }
105}
106
107func (c *Consumer) Start(ctx context.Context) {
108 c.cfg.Logger.Info("starting consumer", "config", c.cfg)
109
110 for range c.cfg.WorkerCount {
111 c.workerWg.Add(1)
112 go c.worker(ctx)
113 }
114
115 for source := range c.cfg.Sources {
116 c.AddSource(ctx, source)
117 }
118}
119
120func (c *Consumer) Stop() {
121 // snapshot cancels and conns under lock so we don't hold sourcesMu across Close
122 c.sourcesMu.Lock()
123 cancels := make([]context.CancelFunc, 0, len(c.sources))
124 conns := make([]*websocket.Conn, 0, len(c.sources))
125 for _, st := range c.sources {
126 if st.cancel != nil {
127 cancels = append(cancels, st.cancel)
128 }
129 if st.conn != nil {
130 conns = append(conns, st.conn)
131 }
132 }
133 c.sourcesMu.Unlock()
134
135 for _, cancel := range cancels {
136 cancel()
137 }
138 for _, conn := range conns {
139 conn.Close()
140 }
141
142 c.sourceWg.Wait()
143 close(c.jobQueue)
144 c.workerWg.Wait()
145}
146
147func (c *Consumer) AddSource(ctx context.Context, s Source) {
148 c.sourcesMu.Lock()
149 if _, ok := c.sources[s]; ok {
150 c.sourcesMu.Unlock()
151 c.logger.Info("source already present", "source", s)
152 return
153 }
154 srcCtx, cancel := context.WithCancel(ctx)
155 c.sources[s] = &sourceState{cancel: cancel}
156 c.sourcesMu.Unlock()
157
158 c.sourceWg.Add(1)
159 go c.startConnectionLoop(srcCtx, s)
160}
161
162func (c *Consumer) RemoveSource(s Source) {
163 c.sourcesMu.Lock()
164 st, ok := c.sources[s]
165 if !ok {
166 c.sourcesMu.Unlock()
167 c.logger.Info("source not present", "source", s)
168 return
169 }
170 delete(c.sources, s)
171 cancel := st.cancel
172 conn := st.conn
173 c.sourcesMu.Unlock()
174
175 // release lock before any potentially blocking call
176 if cancel != nil {
177 cancel()
178 }
179 if conn != nil {
180 conn.Close()
181 }
182}
183
184func (c *Consumer) worker(ctx context.Context) {
185 defer c.workerWg.Done()
186 for {
187 select {
188 case <-ctx.Done():
189 return
190 case j, ok := <-c.jobQueue:
191 if !ok {
192 return
193 }
194
195 var ev eventstream.Event
196 err := json.Unmarshal(j.message, &ev)
197 if err != nil {
198 c.logger.Error("error deserializing message", "source", j.source.Key(), "err", err)
199 continue
200 }
201
202 if err := c.cfg.ProcessFunc(ctx, j.source, ev); err != nil {
203 c.logger.Error("error processing message", "source", j.source, "err", err)
204 }
205
206 c.advanceCursor(j.source, ev.Created)
207 }
208 }
209}
210
211func (c *Consumer) advanceCursor(s Source, newCursor int64) {
212 if newCursor == 0 {
213 return
214 }
215 c.sourcesMu.Lock()
216 st, ok := c.sources[s]
217 c.sourcesMu.Unlock()
218 if !ok {
219 return
220 }
221
222 st.cursorMu.Lock()
223 defer st.cursorMu.Unlock()
224 if newCursor <= st.cursorMax {
225 return
226 }
227 st.cursorMax = newCursor
228 c.cfg.CursorStore.Set(s.Key(), newCursor)
229}
230
231func (c *Consumer) startConnectionLoop(ctx context.Context, source Source) {
232 defer c.sourceWg.Done()
233
234 // attempt connection initially
235 err := c.runConnection(ctx, source)
236 if err != nil {
237 c.logger.Error("failed to run connection", "err", err)
238 }
239
240 timer := time.NewTimer(1 * time.Minute)
241 defer timer.Stop()
242
243 // every subsequent attempt is delayed by 1 minute
244 for {
245 select {
246 case <-ctx.Done():
247 return
248 case <-timer.C:
249 err := c.runConnection(ctx, source)
250 if err != nil {
251 c.logger.Error("failed to run connection", "err", err)
252 }
253 timer.Reset(1 * time.Minute)
254 }
255 }
256}
257
258func (c *Consumer) runConnection(ctx context.Context, source Source) error {
259 cursor := c.cfg.CursorStore.Get(source.Key())
260
261 u, err := source.URL(cursor)
262 if err != nil {
263 return err
264 }
265
266 c.logger.Info("connecting", "url", u.String())
267
268 retryOpts := []retry.Option{
269 retry.Attempts(c.cfg.MaxRetryAttempts),
270 retry.DelayType(retry.BackOffDelay),
271 retry.Delay(c.cfg.RetryInterval),
272 retry.MaxDelay(c.cfg.MaxRetryInterval),
273 retry.MaxJitter(c.cfg.RetryInterval / 5),
274 retry.OnRetry(func(n uint, err error) {
275 c.logger.Info("retrying connection",
276 "source", source,
277 "url", u.String(),
278 "attempt", n+1,
279 "err", err,
280 )
281 }),
282 retry.Context(ctx),
283 }
284
285 var conn *websocket.Conn
286
287 err = retry.Do(func() error {
288 connCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout)
289 defer cancel()
290 conn, _, err = c.dialer.DialContext(connCtx, u.String(), c.cfg.RequestHeader)
291 return err
292 }, retryOpts...)
293 if err != nil {
294 if c.cfg.OnConnectExceeded != nil {
295 c.cfg.OnConnectExceeded(source, err)
296 }
297 return err
298 }
299
300 // Register the conn. If the source was removed (or our ctx cancelled)
301 // while we were dialing, drop this conn instead of installing it.
302 c.sourcesMu.Lock()
303 st, ok := c.sources[source]
304 if !ok || ctx.Err() != nil {
305 c.sourcesMu.Unlock()
306 conn.Close()
307 if ctx.Err() != nil {
308 return ctx.Err()
309 }
310 return nil
311 }
312 st.conn = conn
313 c.sourcesMu.Unlock()
314
315 defer func() {
316 // Clear the conn from state, but only if it's still our conn (a
317 // concurrent RemoveSource may have already done it).
318 c.sourcesMu.Lock()
319 if st, ok := c.sources[source]; ok && st.conn == conn {
320 st.conn = nil
321 }
322 c.sourcesMu.Unlock()
323 conn.Close()
324 }()
325
326 c.logger.Info("connected", "source", source)
327
328 for {
329 select {
330 case <-ctx.Done():
331 return nil
332 default:
333 msgType, msg, err := conn.ReadMessage()
334 if err != nil {
335 return err
336 }
337 if msgType != websocket.TextMessage {
338 continue
339 }
340 select {
341 case c.jobQueue <- job{source: source, message: msg}:
342 case <-ctx.Done():
343 return nil
344 }
345 }
346 }
347}