Monorepo for Tangled
tangled.org
1package lexutil
2
3import (
4 "cmp"
5 "context"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "net/url"
10 "time"
11
12 indigoxrpc "github.com/bluesky-social/indigo/xrpc"
13 "github.com/carlmjohnson/versioninfo"
14 "github.com/gorilla/websocket"
15 cbg "github.com/whyrusleeping/cbor-gen"
16)
17
18type Client struct {
19 indigoxrpc.Client
20 Dialer websocket.Dialer
21 Logger *slog.Logger
22}
23
24var _ LexClient = (*Client)(nil)
25
// makeParams converts an XRPC parameter map into URL query values.
//
// A []string value is expanded into one repeated key per element (an empty
// slice contributes nothing). Any other value is rendered with fmt.Sprint.
func makeParams(p map[string]any) url.Values {
	values := make(url.Values, len(p))
	for key, raw := range p {
		list, isList := raw.([]string)
		if !isList {
			values.Add(key, fmt.Sprint(raw))
			continue
		}
		for _, item := range list {
			values.Add(key, item)
		}
	}
	return values
}
39
40type processFn func(ctx context.Context, cr *cbg.CborReader) error
41
42func (c *Client) LexDo(ctx context.Context, method string, inputEncoding string, endpoint string, params map[string]any, bodyData any, out any) error {
43 switch method {
44 case Subscription:
45 if process, ok := out.(processFn); ok {
46 return c.LexSubscribe(ctx, endpoint, params, process)
47 } else if redialer, ok := out.(Redialer); ok {
48 return c.LexSubscribeWithRedialer(ctx, endpoint, params, redialer)
49 } else {
50 return fmt.Errorf("unknown output type: %T", out)
51 }
52 default:
53 return c.Client.LexDo(ctx, method, inputEncoding, endpoint, params, bodyData, out)
54 }
55}
56
57func (c *Client) getHeader() http.Header {
58 header := http.Header{}
59 if c.UserAgent != nil {
60 header.Set("User-Agent", *c.UserAgent)
61 } else {
62 header.Set("User-Agent", "extlexutil/"+versioninfo.Short())
63 }
64 if c.Headers != nil {
65 for k, v := range c.Headers {
66 header.Set(k, v)
67 }
68 }
69 return header
70}
71
72func (c *Client) LexSubscribe(ctx context.Context, endpoint string, params map[string]any, process func(ctx context.Context, cr *cbg.CborReader) error) error {
73 logger := cmp.Or(c.Logger, slog.Default().With("system", "events"))
74 rurl, err := url.Parse(c.Host)
75 if err != nil {
76 return err
77 }
78 if rurl.Scheme == "http" {
79 rurl.Scheme = "ws"
80 } else {
81 rurl.Scheme = "wss"
82 }
83 surl := rurl.JoinPath("/xrpc", endpoint)
84 surl.RawQuery = makeParams(params).Encode()
85
86 header := c.getHeader()
87
88 u := surl.String()
89 conn, resp, err := c.Dialer.DialContext(ctx, u, header)
90 if err != nil {
91 return fmt.Errorf("%w: %w", ErrDialFailure, err)
92 }
93
94 logger.Debug("event subscription response", "code", resp.StatusCode, "url", u)
95
96 return c.handleConn(ctx, conn, process)
97}
98
99func (c *Client) LexSubscribeWithRedialer(ctx context.Context, endpoint string, params map[string]any, redialer Redialer) error {
100 logger := cmp.Or(c.Logger, slog.Default().With("system", "events"))
101 rurl, err := url.Parse(c.Host)
102 if err != nil {
103 return err
104 }
105 if rurl.Scheme == "http" {
106 rurl.Scheme = "ws"
107 } else {
108 rurl.Scheme = "wss"
109 }
110 surl := rurl.JoinPath("/xrpc", endpoint)
111
112 header := c.getHeader()
113
114 var backoff int
115 for {
116 select {
117 case <-ctx.Done():
118 return ctx.Err()
119 default:
120 }
121
122 surl.RawQuery = makeParams(params).Encode()
123
124 u := surl.String()
125 conn, resp, err := c.Dialer.DialContext(ctx, u, header)
126 if err != nil {
127 logger.Warn("dialing failed", "err", err, "backoff", backoff)
128 time.Sleep(time.Duration(5+backoff) * time.Second)
129 backoff++
130
131 if backoff > 15 {
132 return fmt.Errorf("%w: %w", ErrDialFailure, err)
133 }
134
135 continue
136 }
137
138 logger.Debug("event subscription response", "code", resp.StatusCode, "url", u)
139
140 if err := c.handleConn(ctx, conn, redialer.Process); err != nil {
141 logger.Warn("host connection failed", "err", err, "backoff", backoff)
142 }
143
144 // updates cursor & backoff
145 updated := redialer.UpdateParams(ctx, params)
146 if updated {
147 backoff = 0
148 }
149 }
150}
151
152func (c *Client) handleConn(ctx context.Context, conn *websocket.Conn, process func(ctx context.Context, cr *cbg.CborReader) error) error {
153 logger := cmp.Or(c.Logger, slog.Default().With("system", "events"))
154 ctx, cancel := context.WithCancel(ctx)
155 defer cancel()
156
157 go func() {
158 t := time.NewTicker(time.Second * 30)
159 defer t.Stop()
160 failcount := 0
161
162 for {
163
164 select {
165 case <-t.C:
166 if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(time.Second*10)); err != nil {
167 logger.Warn("failed to ping", "err", err)
168 failcount++
169 if failcount >= 4 {
170 logger.Error("too many ping fails", "count", failcount)
171 conn.Close()
172 return
173 }
174 } else {
175 failcount = 0 // ok ping
176 }
177 case <-ctx.Done():
178 conn.Close()
179 return
180 }
181 }
182 }()
183
184 conn.SetPingHandler(func(message string) error {
185 err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second*60))
186 if err == websocket.ErrCloseSent {
187 return nil
188 }
189 return err
190 })
191
192 conn.SetPongHandler(func(_ string) error {
193 if err := conn.SetReadDeadline(time.Now().Add(time.Minute)); err != nil {
194 logger.Error("failed to set read deadline", "err", err)
195 }
196
197 return nil
198 })
199
200 cr := new(cbg.CborReader)
201
202 for {
203 select {
204 case <-ctx.Done():
205 return ctx.Err()
206 default:
207 }
208
209 mt, rawReader, err := conn.NextReader()
210 if err != nil {
211 return fmt.Errorf("conn err at read: %w", err)
212 }
213
214 if mt != websocket.BinaryMessage {
215 return fmt.Errorf("expected binary message from subscription endpoint")
216 }
217
218 cr.SetReader(rawReader)
219
220 if err := process(ctx, cr); err != nil {
221 return err
222 }
223 }
224}