Monorepo for Tangled
tangled.org
1package eventstream
2
3import (
4 "database/sql"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "net/http"
11 "net/http/httptest"
12 "strconv"
13 "strings"
14 "sync"
15 "testing"
16 "time"
17
18 "github.com/gorilla/websocket"
19 _ "github.com/mattn/go-sqlite3"
20 "tangled.org/core/notifier"
21)
22
23type memSource struct {
24 mu sync.Mutex
25 events []Event
26}
27
28func (s *memSource) add(ev Event) {
29 s.mu.Lock()
30 defer s.mu.Unlock()
31 s.events = append(s.events, ev)
32}
33
34func (s *memSource) GetEvents(cursor int64, limit int) ([]Event, error) {
35 s.mu.Lock()
36 defer s.mu.Unlock()
37 out := []Event{}
38 for _, ev := range s.events {
39 if ev.Created > cursor {
40 out = append(out, ev)
41 if len(out) == limit {
42 break
43 }
44 }
45 }
46 return out, nil
47}
48
49func mkEvent(i int) Event {
50 return Event{
51 Rkey: fmt.Sprintf("rk-%04d", i),
52 Nsid: "sh.tangled.test",
53 EventJson: json.RawMessage(fmt.Sprintf(`{"i":%d}`, i)),
54 Created: int64(i + 1),
55 }
56}
57
58func startServer(t *testing.T, src Backend, cfg StreamConfig) (string, *notifier.Notifier, <-chan error) {
59 t.Helper()
60 n := notifier.New()
61 cfg.Backend = src
62 cfg.Notifier = &n
63 cfg.Logger = slog.New(slog.NewTextHandler(io.Discard, nil))
64
65 errCh := make(chan error, 1)
66 mux := http.NewServeMux()
67 mux.HandleFunc("/events", func(w http.ResponseWriter, r *http.Request) {
68 errCh <- Stream(w, r, cfg)
69 })
70 srv := httptest.NewServer(mux)
71 t.Cleanup(srv.Close)
72 wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/events"
73 return wsURL, &n, errCh
74}
75
76func dial(t *testing.T, wsURL string, cursor int64) *websocket.Conn {
77 t.Helper()
78 if cursor != 0 {
79 wsURL += "?cursor=" + strconv.FormatInt(cursor, 10)
80 }
81 c, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
82 if err != nil {
83 t.Fatalf("dial: %v", err)
84 }
85 t.Cleanup(func() { c.Close() })
86 return c
87}
88
89func readN(t *testing.T, c *websocket.Conn, n int) []Event {
90 t.Helper()
91 c.SetReadDeadline(time.Now().Add(2 * time.Second))
92 out := make([]Event, 0, n)
93 for range n {
94 _, msg, err := c.ReadMessage()
95 if err != nil {
96 t.Fatalf("read message at %d/%d: %v", len(out), n, err)
97 }
98 var ev Event
99 if err := json.Unmarshal(msg, &ev); err != nil {
100 t.Fatalf("unmarshal: %v", err)
101 }
102 out = append(out, ev)
103 }
104 return out
105}
106
107func TestStream_DrainStopsOnShortBatch(t *testing.T) {
108 src := &memSource{}
109 for i := range 7 {
110 src.add(mkEvent(i))
111 }
112
113 wsURL, _, errCh := startServer(t, src, StreamConfig{
114 BatchSize: 3,
115 MaxBatchesPerDrain: 10,
116 })
117 c := dial(t, wsURL, 0)
118
119 got := readN(t, c, 7)
120 for i, ev := range got {
121 if ev.Created != int64(i+1) {
122 t.Fatalf("event %d: got created=%d", i, ev.Created)
123 }
124 }
125
126 c.Close()
127 select {
128 case err := <-errCh:
129 if err != nil && !isCloseErr(err) {
130 t.Fatalf("server error: %v", err)
131 }
132 case <-time.After(2 * time.Second):
133 t.Fatal("server did not exit")
134 }
135}
136
137func TestStream_DrainHitsCap_ReturnsErrDrainCap(t *testing.T) {
138 src := &memSource{}
139 for i := range 5 {
140 src.add(mkEvent(i))
141 }
142
143 wsURL, _, errCh := startServer(t, src, StreamConfig{
144 BatchSize: 2,
145 MaxBatchesPerDrain: 2,
146 })
147 c := dial(t, wsURL, 0)
148
149 got := readN(t, c, 4)
150 if len(got) != 4 {
151 t.Fatalf("want 4 events before cap, got %d", len(got))
152 }
153 if got[3].Created != 4 {
154 t.Fatalf("last delivered created = %d, want 4", got[3].Created)
155 }
156
157 select {
158 case err := <-errCh:
159 if !errors.Is(err, ErrDrainCap) {
160 t.Fatalf("want ErrDrainCap, got %v", err)
161 }
162 case <-time.After(2 * time.Second):
163 t.Fatal("server did not return cap error")
164 }
165}
166
167func TestStream_CursorResume(t *testing.T) {
168 src := &memSource{}
169 for i := range 5 {
170 src.add(mkEvent(i))
171 }
172
173 wsURL, _, errCh := startServer(t, src, StreamConfig{
174 BatchSize: 10,
175 MaxBatchesPerDrain: 10,
176 })
177 c := dial(t, wsURL, 3)
178
179 got := readN(t, c, 2)
180 if got[0].Created != 4 || got[1].Created != 5 {
181 t.Fatalf("resume from cursor: got %d,%d want 4,5", got[0].Created, got[1].Created)
182 }
183
184 c.Close()
185 <-errCh
186}
187
188func TestStream_LiveDelivery(t *testing.T) {
189 src := &memSource{}
190
191 wsURL, n, errCh := startServer(t, src, StreamConfig{
192 BatchSize: 10,
193 MaxBatchesPerDrain: 10,
194 })
195 c := dial(t, wsURL, 0)
196
197 src.add(mkEvent(42))
198 n.NotifyAll()
199
200 got := readN(t, c, 1)
201 if got[0].Created != 43 {
202 t.Fatalf("live event created = %d, want 43", got[0].Created)
203 }
204
205 c.Close()
206 <-errCh
207}
208
209func TestStream_LiveBurstExceedsBatchSize_DrainsAll(t *testing.T) {
210 src := &memSource{}
211
212 wsURL, n, errCh := startServer(t, src, StreamConfig{
213 BatchSize: 5,
214 MaxBatchesPerDrain: 100,
215 })
216 c := dial(t, wsURL, 0)
217
218 const burst = 17
219 for i := range burst {
220 src.add(mkEvent(i))
221 }
222 n.NotifyAll()
223
224 got := readN(t, c, burst)
225 if len(got) != burst {
226 t.Fatalf("got %d events, want %d", len(got), burst)
227 }
228 for i, ev := range got {
229 if ev.Created != int64(i+1) {
230 t.Fatalf("event %d: got created=%d want %d", i, ev.Created, i+1)
231 }
232 }
233
234 c.Close()
235 <-errCh
236}
237
238func TestInsert_MonotonicCreatedUnderConcurrency(t *testing.T) {
239 db, err := sql.Open("sqlite3", t.TempDir()+"/events.db")
240 if err != nil {
241 t.Fatalf("open: %v", err)
242 }
243 t.Cleanup(func() { db.Close() })
244 if _, err := db.Exec(`create table events (
245 rkey text not null,
246 nsid text not null,
247 event text not null,
248 created integer not null,
249 primary key (rkey, nsid)
250 )`); err != nil {
251 t.Fatalf("schema: %v", err)
252 }
253
254 n := notifier.New()
255 const total = 300
256 var wg sync.WaitGroup
257 for i := range total {
258 wg.Add(1)
259 go func(i int) {
260 defer wg.Done()
261 if err := Insert(db, Event{
262 Rkey: fmt.Sprintf("rk-%d", i),
263 Nsid: "sh.tangled.test",
264 EventJson: json.RawMessage("{}"),
265 }, &n); err != nil {
266 t.Errorf("insert %d: %v", i, err)
267 }
268 }(i)
269 }
270 wg.Wait()
271
272 rows, err := db.Query(`select created from events order by created asc`)
273 if err != nil {
274 t.Fatalf("read: %v", err)
275 }
276 defer rows.Close()
277
278 var prev int64
279 count := 0
280 for rows.Next() {
281 var c int64
282 if err := rows.Scan(&c); err != nil {
283 t.Fatalf("scan: %v", err)
284 }
285 if count > 0 && c <= prev {
286 t.Fatalf("created not strictly increasing: %d <= %d", c, prev)
287 }
288 prev = c
289 count++
290 }
291 if count != total {
292 t.Fatalf("got %d rows, want %d", count, total)
293 }
294}
295
296func isCloseErr(err error) bool {
297 if err == nil {
298 return false
299 }
300 if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
301 return true
302 }
303 return strings.Contains(err.Error(), "use of closed network connection") ||
304 strings.Contains(err.Error(), "websocket: close") ||
305 strings.Contains(err.Error(), "broken pipe")
306}