forked from
willdot.net/cocoon
A fork of the Cocoon PDS but being made more distributed.
1package client
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "net/http"
11 "net/url"
12 "slices"
13 "strings"
14 "time"
15
16 cache "github.com/go-pkgz/expirable-cache/v3"
17 "github.com/haileyok/cocoon/internal/helpers"
18 "github.com/lestrrat-go/jwx/v2/jwk"
19)
20
21type Manager struct {
22 cli *http.Client
23 logger *slog.Logger
24 jwksCache cache.Cache[string, jwk.Key]
25 metadataCache cache.Cache[string, *Metadata]
26}
27
// ManagerArgs carries the optional dependencies for NewManager. Leaving a
// field nil makes NewManager substitute a default for it.
type ManagerArgs struct {
	// Cli is the HTTP client used for fetching client metadata and JWKS documents.
	Cli *http.Client
	// Logger is stored on the resulting Manager.
	Logger *slog.Logger
}
32
33func NewManager(args ManagerArgs) *Manager {
34 if args.Logger == nil {
35 args.Logger = slog.Default()
36 }
37
38 if args.Cli == nil {
39 args.Cli = http.DefaultClient
40 }
41
42 jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
43 metadataCache := cache.NewCache[string, *Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
44
45 return &Manager{
46 cli: args.Cli,
47 logger: args.Logger,
48 jwksCache: jwksCache,
49 metadataCache: metadataCache,
50 }
51}
52
53func (cm *Manager) GetClient(ctx context.Context, clientId string) (*Client, error) {
54 metadata, err := cm.getClientMetadata(ctx, clientId)
55 if err != nil {
56 return nil, err
57 }
58
59 var jwks jwk.Key
60 if metadata.TokenEndpointAuthMethod == "private_key_jwt" {
61 if metadata.JWKS != nil && len(metadata.JWKS.Keys) > 0 {
62 // TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to
63 // make sure we use the right one
64 b, err := json.Marshal(metadata.JWKS.Keys[0])
65 if err != nil {
66 return nil, err
67 }
68
69 k, err := helpers.ParseJWKFromBytes(b)
70 if err != nil {
71 return nil, err
72 }
73
74 jwks = k
75 } else if metadata.JWKSURI != nil {
76 maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI)
77 if err != nil {
78 return nil, err
79 }
80
81 jwks = maybeJwks
82 } else {
83 return nil, fmt.Errorf("no valid jwks found in oauth client metadata")
84 }
85 }
86
87 return &Client{
88 Metadata: metadata,
89 JWKS: jwks,
90 IsLocalhostClient: isLocalhostClientID(clientId),
91 }, nil
92}
93
94func (cm *Manager) getClientMetadata(ctx context.Context, clientId string) (*Metadata, error) {
95 if isLocalhostClientID(clientId) {
96 return buildLocalhostVirtualMetadata(clientId)
97 }
98
99 cached, ok := cm.metadataCache.Get(clientId)
100 if !ok {
101 req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil)
102 if err != nil {
103 return nil, err
104 }
105
106 resp, err := cm.cli.Do(req)
107 if err != nil {
108 return nil, err
109 }
110 defer resp.Body.Close()
111
112 if resp.StatusCode != http.StatusOK {
113 io.Copy(io.Discard, resp.Body)
114 return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode)
115 }
116
117 b, err := io.ReadAll(resp.Body)
118 if err != nil {
119 return nil, fmt.Errorf("error reading bytes from client response: %w", err)
120 }
121
122 validated, err := validateAndParseMetadata(clientId, b)
123 if err != nil {
124 return nil, err
125 }
126
127 cm.metadataCache.Set(clientId, validated, 10*time.Minute)
128
129 return validated, nil
130 } else {
131 return cached, nil
132 }
133}
134
135func (cm *Manager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) {
136 jwks, ok := cm.jwksCache.Get(clientId)
137 if !ok {
138 req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil)
139 if err != nil {
140 return nil, err
141 }
142
143 resp, err := cm.cli.Do(req)
144 if err != nil {
145 return nil, err
146 }
147 defer resp.Body.Close()
148
149 if resp.StatusCode != http.StatusOK {
150 io.Copy(io.Discard, resp.Body)
151 return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode)
152 }
153
154 type Keys struct {
155 Keys []map[string]any `json:"keys"`
156 }
157
158 var keys Keys
159 if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
160 return nil, fmt.Errorf("error unmarshaling keys response: %w", err)
161 }
162
163 if len(keys.Keys) == 0 {
164 return nil, errors.New("no keys in jwks response")
165 }
166
167 // TODO: this is again bad, we should be figuring out which one we need to use...
168 b, err := json.Marshal(keys.Keys[0])
169 if err != nil {
170 return nil, fmt.Errorf("could not marshal key: %w", err)
171 }
172
173 k, err := helpers.ParseJWKFromBytes(b)
174 if err != nil {
175 return nil, err
176 }
177
178 jwks = k
179 }
180
181 return jwks, nil
182}
183
184func validateAndParseMetadata(clientId string, b []byte) (*Metadata, error) {
185 var metadataMap map[string]any
186 if err := json.Unmarshal(b, &metadataMap); err != nil {
187 return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
188 }
189
190 _, jwksOk := metadataMap["jwks"].(string)
191 _, jwksUriOk := metadataMap["jwks_uri"].(string)
192 if jwksOk && jwksUriOk {
193 return nil, errors.New("jwks_uri and jwks are mutually exclusive")
194 }
195
196 for _, k := range []string{
197 "default_max_age",
198 "userinfo_signed_response_alg",
199 "id_token_signed_response_alg",
200 "userinfo_encryhpted_response_alg",
201 "authorization_encrypted_response_enc",
202 "authorization_encrypted_response_alg",
203 "tls_client_certificate_bound_access_tokens",
204 } {
205 _, kOk := metadataMap[k]
206 if kOk {
207 return nil, fmt.Errorf("unsupported `%s` parameter", k)
208 }
209 }
210
211 var metadata Metadata
212 if err := json.Unmarshal(b, &metadata); err != nil {
213 return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
214 }
215
216 if metadata.ClientURI == "" {
217 u, err := url.Parse(metadata.ClientID)
218 if err != nil {
219 return nil, fmt.Errorf("unable to parse client id: %w", err)
220 }
221 u.RawPath = ""
222 u.RawQuery = ""
223 metadata.ClientURI = u.String()
224 }
225
226 u, err := url.Parse(metadata.ClientURI)
227 if err != nil {
228 return nil, fmt.Errorf("unable to parse client uri: %w", err)
229 }
230
231 if metadata.ClientName == "" {
232 metadata.ClientName = metadata.ClientURI
233 }
234
235 if isLocalHostname(u.Hostname()) {
236 return nil, fmt.Errorf("`client_uri` hostname is invalid: %s", u.Hostname())
237 }
238
239 if metadata.Scope == "" {
240 return nil, errors.New("missing `scopes` scope")
241 }
242
243 scopes := strings.Split(metadata.Scope, " ")
244 if !slices.Contains(scopes, "atproto") {
245 return nil, errors.New("missing `atproto` scope")
246 }
247
248 scopesMap := map[string]bool{}
249 for _, scope := range scopes {
250 if scopesMap[scope] {
251 return nil, fmt.Errorf("duplicate scope `%s`", scope)
252 }
253
254 // TODO: check for unsupported scopes
255
256 scopesMap[scope] = true
257 }
258
259 grantTypesMap := map[string]bool{}
260 for _, gt := range metadata.GrantTypes {
261 if grantTypesMap[gt] {
262 return nil, fmt.Errorf("duplicate grant type `%s`", gt)
263 }
264
265 switch gt {
266 case "implicit":
267 return nil, errors.New("grantg type `implicit` is not allowed")
268 case "authorization_code", "refresh_token":
269 // TODO check if this grant type is supported
270 default:
271 return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt)
272 }
273
274 grantTypesMap[gt] = true
275 }
276
277 if metadata.ClientID != clientId {
278 return nil, errors.New("`client_id` does not match")
279 }
280
281 subjectType, subjectTypeOk := metadataMap["subject_type"].(string)
282 if subjectTypeOk && subjectType != "public" {
283 return nil, errors.New("only public `subject_type` is supported")
284 }
285
286 switch metadata.TokenEndpointAuthMethod {
287 case "none":
288 if metadata.TokenEndpointAuthSigningAlg != "" {
289 return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg")
290 }
291 case "private_key_jwt":
292 if metadata.JWKS == nil && metadata.JWKSURI == nil {
293 return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri")
294 }
295
296 if metadata.JWKS != nil && len(metadata.JWKS.Keys) == 0 {
297 return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks")
298 }
299
300 if metadata.TokenEndpointAuthSigningAlg == "" {
301 return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata")
302 }
303 default:
304 return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod)
305 }
306
307 if !metadata.DpopBoundAccessTokens {
308 return nil, errors.New("dpop_bound_access_tokens must be true")
309 }
310
311 if !slices.Contains(metadata.ResponseTypes, "code") {
312 return nil, errors.New("response_types must inclue `code`")
313 }
314
315 if !slices.Contains(metadata.GrantTypes, "authorization_code") {
316 return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`")
317 }
318
319 if len(metadata.RedirectURIs) == 0 {
320 return nil, errors.New("at least one `redirect_uri` is required")
321 }
322
323 if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod != "none" {
324 return nil, errors.New("native clients must authenticate using `none` method")
325 }
326
327 if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") {
328 for _, ruri := range metadata.RedirectURIs {
329 u, err := url.Parse(ruri)
330 if err != nil {
331 return nil, fmt.Errorf("error parsing redirect uri: %w", err)
332 }
333
334 if u.Scheme != "https" {
335 return nil, errors.New("web clients must use https redirect uris")
336 }
337
338 if u.Hostname() == "localhost" {
339 return nil, errors.New("web clients must not use localhost as the hostname")
340 }
341 }
342 }
343
344 for _, ruri := range metadata.RedirectURIs {
345 u, err := url.Parse(ruri)
346 if err != nil {
347 return nil, fmt.Errorf("error parsing redirect uri: %w", err)
348 }
349
350 if u.User != nil {
351 if u.User.Username() != "" {
352 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
353 }
354
355 if _, hasPass := u.User.Password(); hasPass {
356 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
357 }
358 }
359
360 switch true {
361 case u.Hostname() == "localhost":
362 return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)")
363 case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]":
364 if metadata.ApplicationType != "native" {
365 return nil, errors.New("loopback redirect uris are only allowed for native apps")
366 }
367
368 if u.Port() != "" {
369 // reference impl doesn't do anything with this?
370 }
371
372 if u.Scheme != "http" {
373 return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri)
374 }
375 case u.Scheme == "http":
376 return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme")
377 case u.Scheme == "https":
378 if isLocalHostname(u.Hostname()) {
379 return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri)
380 }
381 case strings.Contains(u.Scheme, "."):
382 if metadata.ApplicationType != "native" {
383 return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps")
384 }
385
386 revdomain := reverseDomain(u.Scheme)
387
388 if isLocalHostname(revdomain) {
389 return nil, errors.New("private use uri scheme redirect uris must not be local hostnames")
390 }
391
392 if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" {
393 return nil, fmt.Errorf("private use uri scheme must be in the form ")
394 }
395 default:
396 return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme)
397 }
398 }
399
400 return &metadata, nil
401}
402
// isLocalhostClientID reports whether clientId is a localhost development
// client ID. Such an ID uses plain http, the literal hostname "localhost", no
// explicit port, and an empty or "/" path. Query parameters are permitted.
// An ID that fails to parse is never treated as a localhost client.
func isLocalhostClientID(clientId string) bool {
	parsed, err := url.Parse(clientId)
	switch {
	case err != nil:
		return false
	case parsed.Scheme != "http", parsed.Hostname() != "localhost", parsed.Port() != "":
		return false
	}
	return parsed.Path == "" || parsed.Path == "/"
}
413
414func buildLocalhostVirtualMetadata(clientId string) (*Metadata, error) {
415 u, err := url.Parse(clientId)
416 if err != nil {
417 return nil, fmt.Errorf("error parsing localhost client_id: %w", err)
418 }
419
420 q := u.Query()
421
422 redirectURIs := q["redirect_uri"]
423 if len(redirectURIs) == 0 {
424 redirectURIs = []string{"http://127.0.0.1/", "http://[::1]/"}
425 }
426
427 for _, ruri := range redirectURIs {
428 ru, err := url.Parse(ruri)
429 if err != nil {
430 return nil, fmt.Errorf("invalid redirect_uri %q: %w", ruri, err)
431 }
432 if ru.Scheme != "http" || !isLoopbackHost(ru.Hostname()) {
433 return nil, fmt.Errorf("localhost client redirect_uri must use a loopback address, got %q", ruri)
434 }
435 }
436
437 scope := q.Get("scope")
438 if scope == "" {
439 scope = "atproto"
440 } else if !slices.Contains(strings.Split(scope, " "), "atproto") {
441 scope = "atproto " + scope
442 }
443
444 return &Metadata{
445 ClientID: clientId,
446 ClientName: "Development client",
447 ClientURI: "http://localhost",
448 RedirectURIs: redirectURIs,
449 GrantTypes: []string{"authorization_code", "refresh_token"},
450 ResponseTypes: []string{"code"},
451 ApplicationType: "native",
452 DpopBoundAccessTokens: true,
453 Scope: scope,
454 TokenEndpointAuthMethod: "none",
455 }, nil
456}
457
// isLoopbackHost reports whether hostname is one of the loopback names
// accepted for localhost development redirects: "localhost", "127.0.0.1" or
// the bracket-less IPv6 form "::1" (as returned by url.URL.Hostname).
func isLoopbackHost(hostname string) bool {
	switch hostname {
	case "localhost", "127.0.0.1", "::1":
		return true
	default:
		return false
	}
}
461
// isLocalHostname reports whether hostname should be treated as local or
// non-public. A single-label name (no dot at all, including "") counts as
// local, as does any name whose final label is, case-insensitively, one of
// test, local, localhost, invalid or example.
func isLocalHostname(hostname string) bool {
	lastDot := strings.LastIndex(hostname, ".")
	if lastDot < 0 {
		return true
	}

	switch strings.ToLower(hostname[lastDot+1:]) {
	case "test", "local", "localhost", "invalid", "example":
		return true
	}
	return false
}
471
// reverseDomain reverses the order of the dot-separated labels in domain,
// e.g. "com.example.app" becomes "app.example.com". Empty labels are kept in
// their mirrored positions.
func reverseDomain(domain string) string {
	labels := strings.Split(domain, ".")
	reversed := make([]string, 0, len(labels))
	for i := len(labels) - 1; i >= 0; i-- {
		reversed = append(reversed, labels[i])
	}
	return strings.Join(reversed, ".")
}
476}