Monorepo for Tangled
tangled.org
1package eventconsumer
2
3import (
4 "context"
5 "encoding/json"
6 "log/slog"
7 "net/http"
8 "net/url"
9 "sync"
10 "time"
11
12 "tangled.org/core/eventconsumer/cursor"
13 "tangled.org/core/eventstream"
14 "tangled.org/core/log"
15
16 "github.com/avast/retry-go/v4"
17 "github.com/gorilla/websocket"
18)
19
20type ProcessFunc func(ctx context.Context, source Source, event eventstream.Event) error
21
22type ConsumerConfig struct {
23 Sources map[Source]struct{}
24 ProcessFunc ProcessFunc
25 RetryInterval time.Duration
26 MaxRetryInterval time.Duration
27 ConnectionTimeout time.Duration
28 WorkerCount int
29 QueueSize int
30 Logger *slog.Logger
31 CursorStore cursor.Store
32 URLFunc func(Source, int64) (*url.URL, error)
33
34 Dialer *websocket.Dialer
35 RequestHeader http.Header
36 MaxRetryAttempts uint
37 OnConnectExceeded func(Source, error)
38}
39
40func NewConsumerConfig() *ConsumerConfig {
41 return &ConsumerConfig{
42 Sources: make(map[Source]struct{}),
43 }
44}
45
46type Consumer struct {
47 sourceWg sync.WaitGroup
48 workerWg sync.WaitGroup
49 dialer *websocket.Dialer
50 jobQueue chan job
51 logger *slog.Logger
52
53 // sourcesMu guards sources. It must only be held for short, non-blocking
54 // map operations; never across a blocking call (dial, read, close).
55 sourcesMu sync.Mutex
56 sources map[Source]*sourceState
57
58 cfg ConsumerConfig
59}
60
61type sourceState struct {
62 cancel context.CancelFunc
63 conn *websocket.Conn
64
65 cursorMu sync.Mutex
66 cursorMax int64
67}
68
69type job struct {
70 source Source
71 message []byte
72}
73
74func NewConsumer(cfg ConsumerConfig) *Consumer {
75 if cfg.RetryInterval == 0 {
76 cfg.RetryInterval = 15 * time.Minute
77 }
78 if cfg.ConnectionTimeout == 0 {
79 cfg.ConnectionTimeout = 10 * time.Second
80 }
81 if cfg.WorkerCount <= 0 {
82 cfg.WorkerCount = 5
83 }
84 if cfg.MaxRetryInterval == 0 {
85 cfg.MaxRetryInterval = 1 * time.Hour
86 }
87 if cfg.Logger == nil {
88 cfg.Logger = log.New("consumer")
89 }
90 if cfg.QueueSize == 0 {
91 cfg.QueueSize = 100
92 }
93 if cfg.CursorStore == nil {
94 cfg.CursorStore = &cursor.MemoryStore{}
95 }
96 if cfg.URLFunc == nil {
97 cfg.URLFunc = DefaultURL(false)
98 }
99 dialer := cfg.Dialer
100 if dialer == nil {
101 dialer = websocket.DefaultDialer
102 }
103 return &Consumer{
104 cfg: cfg,
105 dialer: dialer,
106 jobQueue: make(chan job, cfg.QueueSize),
107 logger: cfg.Logger,
108 sources: make(map[Source]*sourceState),
109 }
110}
111
112func (c *Consumer) Start(ctx context.Context) {
113 c.cfg.Logger.Info("starting consumer", "config", c.cfg)
114
115 for range c.cfg.WorkerCount {
116 c.workerWg.Add(1)
117 go c.worker(ctx)
118 }
119
120 for source := range c.cfg.Sources {
121 c.AddSource(ctx, source)
122 }
123}
124
125func (c *Consumer) Stop() {
126 // snapshot cancels and conns under lock so we don't hold sourcesMu across Close
127 c.sourcesMu.Lock()
128 cancels := make([]context.CancelFunc, 0, len(c.sources))
129 conns := make([]*websocket.Conn, 0, len(c.sources))
130 for _, st := range c.sources {
131 if st.cancel != nil {
132 cancels = append(cancels, st.cancel)
133 }
134 if st.conn != nil {
135 conns = append(conns, st.conn)
136 }
137 }
138 c.sourcesMu.Unlock()
139
140 for _, cancel := range cancels {
141 cancel()
142 }
143 for _, conn := range conns {
144 conn.Close()
145 }
146
147 c.sourceWg.Wait()
148 close(c.jobQueue)
149 c.workerWg.Wait()
150}
151
152func (c *Consumer) AddSource(ctx context.Context, s Source) {
153 c.sourcesMu.Lock()
154 if _, ok := c.sources[s]; ok {
155 c.sourcesMu.Unlock()
156 c.logger.Info("source already present", "source", s)
157 return
158 }
159 srcCtx, cancel := context.WithCancel(ctx)
160 c.sources[s] = &sourceState{cancel: cancel}
161 c.sourcesMu.Unlock()
162
163 c.sourceWg.Add(1)
164 go c.startConnectionLoop(srcCtx, s)
165}
166
167func (c *Consumer) RemoveSource(s Source) {
168 c.sourcesMu.Lock()
169 st, ok := c.sources[s]
170 if !ok {
171 c.sourcesMu.Unlock()
172 c.logger.Info("source not present", "source", s)
173 return
174 }
175 delete(c.sources, s)
176 cancel := st.cancel
177 conn := st.conn
178 c.sourcesMu.Unlock()
179
180 // release lock before any potentially blocking call
181 if cancel != nil {
182 cancel()
183 }
184 if conn != nil {
185 conn.Close()
186 }
187}
188
189func (c *Consumer) worker(ctx context.Context) {
190 defer c.workerWg.Done()
191 for {
192 select {
193 case <-ctx.Done():
194 return
195 case j, ok := <-c.jobQueue:
196 if !ok {
197 return
198 }
199
200 var ev eventstream.Event
201 err := json.Unmarshal(j.message, &ev)
202 if err != nil {
203 c.logger.Error("error deserializing message", "source", j.source.Key(), "err", err)
204 continue
205 }
206
207 if err := c.cfg.ProcessFunc(ctx, j.source, ev); err != nil {
208 c.logger.Error("error processing message", "source", j.source, "err", err)
209 }
210
211 c.advanceCursor(j.source, ev.Created)
212 }
213 }
214}
215
216func (c *Consumer) advanceCursor(s Source, newCursor int64) {
217 if newCursor == 0 {
218 return
219 }
220 c.sourcesMu.Lock()
221 st, ok := c.sources[s]
222 c.sourcesMu.Unlock()
223 if !ok {
224 return
225 }
226
227 st.cursorMu.Lock()
228 defer st.cursorMu.Unlock()
229 if newCursor <= st.cursorMax {
230 return
231 }
232 st.cursorMax = newCursor
233 c.cfg.CursorStore.Set(s.Key(), newCursor)
234}
235
236func (c *Consumer) startConnectionLoop(ctx context.Context, source Source) {
237 defer c.sourceWg.Done()
238
239 // attempt connection initially
240 err := c.runConnection(ctx, source)
241 if err != nil {
242 c.logger.Error("failed to run connection", "err", err)
243 }
244
245 timer := time.NewTimer(1 * time.Minute)
246 defer timer.Stop()
247
248 // every subsequent attempt is delayed by 1 minute
249 for {
250 select {
251 case <-ctx.Done():
252 return
253 case <-timer.C:
254 err := c.runConnection(ctx, source)
255 if err != nil {
256 c.logger.Error("failed to run connection", "err", err)
257 }
258 timer.Reset(1 * time.Minute)
259 }
260 }
261}
262
263func (c *Consumer) runConnection(ctx context.Context, source Source) error {
264 cursor := c.cfg.CursorStore.Get(source.Key())
265
266 u, err := c.cfg.URLFunc(source, cursor)
267 if err != nil {
268 return err
269 }
270
271 c.logger.Info("connecting", "url", u.String())
272
273 retryOpts := []retry.Option{
274 retry.Attempts(c.cfg.MaxRetryAttempts),
275 retry.DelayType(retry.BackOffDelay),
276 retry.Delay(c.cfg.RetryInterval),
277 retry.MaxDelay(c.cfg.MaxRetryInterval),
278 retry.MaxJitter(c.cfg.RetryInterval / 5),
279 retry.OnRetry(func(n uint, err error) {
280 c.logger.Info("retrying connection",
281 "source", source,
282 "url", u.String(),
283 "attempt", n+1,
284 "err", err,
285 )
286 }),
287 retry.Context(ctx),
288 }
289
290 var conn *websocket.Conn
291
292 err = retry.Do(func() error {
293 connCtx, cancel := context.WithTimeout(ctx, c.cfg.ConnectionTimeout)
294 defer cancel()
295 conn, _, err = c.dialer.DialContext(connCtx, u.String(), c.cfg.RequestHeader)
296 return err
297 }, retryOpts...)
298 if err != nil {
299 if c.cfg.OnConnectExceeded != nil {
300 c.cfg.OnConnectExceeded(source, err)
301 }
302 return err
303 }
304
305 // Register the conn. If the source was removed (or our ctx cancelled)
306 // while we were dialing, drop this conn instead of installing it.
307 c.sourcesMu.Lock()
308 st, ok := c.sources[source]
309 if !ok || ctx.Err() != nil {
310 c.sourcesMu.Unlock()
311 conn.Close()
312 if ctx.Err() != nil {
313 return ctx.Err()
314 }
315 return nil
316 }
317 st.conn = conn
318 c.sourcesMu.Unlock()
319
320 defer func() {
321 // Clear the conn from state, but only if it's still our conn (a
322 // concurrent RemoveSource may have already done it).
323 c.sourcesMu.Lock()
324 if st, ok := c.sources[source]; ok && st.conn == conn {
325 st.conn = nil
326 }
327 c.sourcesMu.Unlock()
328 conn.Close()
329 }()
330
331 c.logger.Info("connected", "source", source)
332
333 for {
334 select {
335 case <-ctx.Done():
336 return nil
337 default:
338 msgType, msg, err := conn.ReadMessage()
339 if err != nil {
340 return err
341 }
342 if msgType != websocket.TextMessage {
343 continue
344 }
345 select {
346 case c.jobQueue <- job{source: source, message: msg}:
347 case <-ctx.Done():
348 return nil
349 }
350 }
351 }
352}