A fork of the Cocoon PDS that is being made more distributed.
0

Configure Feed

Select the types of activity you want to include in your feed.

1package client 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "log/slog" 10 "net/http" 11 "net/url" 12 "slices" 13 "strings" 14 "time" 15 16 cache "github.com/go-pkgz/expirable-cache/v3" 17 "github.com/haileyok/cocoon/internal/helpers" 18 "github.com/lestrrat-go/jwx/v2/jwk" 19) 20 21type Manager struct { 22 cli *http.Client 23 logger *slog.Logger 24 jwksCache cache.Cache[string, jwk.Key] 25 metadataCache cache.Cache[string, *Metadata] 26} 27 28type ManagerArgs struct { 29 Cli *http.Client 30 Logger *slog.Logger 31} 32 33func NewManager(args ManagerArgs) *Manager { 34 if args.Logger == nil { 35 args.Logger = slog.Default() 36 } 37 38 if args.Cli == nil { 39 args.Cli = http.DefaultClient 40 } 41 42 jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 43 metadataCache := cache.NewCache[string, *Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 44 45 return &Manager{ 46 cli: args.Cli, 47 logger: args.Logger, 48 jwksCache: jwksCache, 49 metadataCache: metadataCache, 50 } 51} 52 53func (cm *Manager) GetClient(ctx context.Context, clientId string) (*Client, error) { 54 metadata, err := cm.getClientMetadata(ctx, clientId) 55 if err != nil { 56 return nil, err 57 } 58 59 var jwks jwk.Key 60 if metadata.TokenEndpointAuthMethod == "private_key_jwt" { 61 if metadata.JWKS != nil && len(metadata.JWKS.Keys) > 0 { 62 // TODO: this is kinda bad but whatever for now. 
there could obviously be more than one jwk, and we need to 63 // make sure we use the right one 64 b, err := json.Marshal(metadata.JWKS.Keys[0]) 65 if err != nil { 66 return nil, err 67 } 68 69 k, err := helpers.ParseJWKFromBytes(b) 70 if err != nil { 71 return nil, err 72 } 73 74 jwks = k 75 } else if metadata.JWKSURI != nil { 76 maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI) 77 if err != nil { 78 return nil, err 79 } 80 81 jwks = maybeJwks 82 } else { 83 return nil, fmt.Errorf("no valid jwks found in oauth client metadata") 84 } 85 } 86 87 return &Client{ 88 Metadata: metadata, 89 JWKS: jwks, 90 IsLocalhostClient: isLocalhostClientID(clientId), 91 }, nil 92} 93 94func (cm *Manager) getClientMetadata(ctx context.Context, clientId string) (*Metadata, error) { 95 if isLocalhostClientID(clientId) { 96 return buildLocalhostVirtualMetadata(clientId) 97 } 98 99 cached, ok := cm.metadataCache.Get(clientId) 100 if !ok { 101 req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil) 102 if err != nil { 103 return nil, err 104 } 105 106 resp, err := cm.cli.Do(req) 107 if err != nil { 108 return nil, err 109 } 110 defer resp.Body.Close() 111 112 if resp.StatusCode != http.StatusOK { 113 io.Copy(io.Discard, resp.Body) 114 return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode) 115 } 116 117 b, err := io.ReadAll(resp.Body) 118 if err != nil { 119 return nil, fmt.Errorf("error reading bytes from client response: %w", err) 120 } 121 122 validated, err := validateAndParseMetadata(clientId, b) 123 if err != nil { 124 return nil, err 125 } 126 127 cm.metadataCache.Set(clientId, validated, 10*time.Minute) 128 129 return validated, nil 130 } else { 131 return cached, nil 132 } 133} 134 135func (cm *Manager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) { 136 jwks, ok := cm.jwksCache.Get(clientId) 137 if !ok { 138 req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil) 139 if 
err != nil { 140 return nil, err 141 } 142 143 resp, err := cm.cli.Do(req) 144 if err != nil { 145 return nil, err 146 } 147 defer resp.Body.Close() 148 149 if resp.StatusCode != http.StatusOK { 150 io.Copy(io.Discard, resp.Body) 151 return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode) 152 } 153 154 type Keys struct { 155 Keys []map[string]any `json:"keys"` 156 } 157 158 var keys Keys 159 if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil { 160 return nil, fmt.Errorf("error unmarshaling keys response: %w", err) 161 } 162 163 if len(keys.Keys) == 0 { 164 return nil, errors.New("no keys in jwks response") 165 } 166 167 // TODO: this is again bad, we should be figuring out which one we need to use... 168 b, err := json.Marshal(keys.Keys[0]) 169 if err != nil { 170 return nil, fmt.Errorf("could not marshal key: %w", err) 171 } 172 173 k, err := helpers.ParseJWKFromBytes(b) 174 if err != nil { 175 return nil, err 176 } 177 178 jwks = k 179 } 180 181 return jwks, nil 182} 183 184func validateAndParseMetadata(clientId string, b []byte) (*Metadata, error) { 185 var metadataMap map[string]any 186 if err := json.Unmarshal(b, &metadataMap); err != nil { 187 return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 188 } 189 190 _, jwksOk := metadataMap["jwks"].(string) 191 _, jwksUriOk := metadataMap["jwks_uri"].(string) 192 if jwksOk && jwksUriOk { 193 return nil, errors.New("jwks_uri and jwks are mutually exclusive") 194 } 195 196 for _, k := range []string{ 197 "default_max_age", 198 "userinfo_signed_response_alg", 199 "id_token_signed_response_alg", 200 "userinfo_encryhpted_response_alg", 201 "authorization_encrypted_response_enc", 202 "authorization_encrypted_response_alg", 203 "tls_client_certificate_bound_access_tokens", 204 } { 205 _, kOk := metadataMap[k] 206 if kOk { 207 return nil, fmt.Errorf("unsupported `%s` parameter", k) 208 } 209 } 210 211 var metadata Metadata 212 if err := json.Unmarshal(b, &metadata); 
err != nil { 213 return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 214 } 215 216 if metadata.ClientURI == "" { 217 u, err := url.Parse(metadata.ClientID) 218 if err != nil { 219 return nil, fmt.Errorf("unable to parse client id: %w", err) 220 } 221 u.RawPath = "" 222 u.RawQuery = "" 223 metadata.ClientURI = u.String() 224 } 225 226 u, err := url.Parse(metadata.ClientURI) 227 if err != nil { 228 return nil, fmt.Errorf("unable to parse client uri: %w", err) 229 } 230 231 if metadata.ClientName == "" { 232 metadata.ClientName = metadata.ClientURI 233 } 234 235 if isLocalHostname(u.Hostname()) { 236 return nil, fmt.Errorf("`client_uri` hostname is invalid: %s", u.Hostname()) 237 } 238 239 if metadata.Scope == "" { 240 return nil, errors.New("missing `scopes` scope") 241 } 242 243 scopes := strings.Split(metadata.Scope, " ") 244 if !slices.Contains(scopes, "atproto") { 245 return nil, errors.New("missing `atproto` scope") 246 } 247 248 scopesMap := map[string]bool{} 249 for _, scope := range scopes { 250 if scopesMap[scope] { 251 return nil, fmt.Errorf("duplicate scope `%s`", scope) 252 } 253 254 // TODO: check for unsupported scopes 255 256 scopesMap[scope] = true 257 } 258 259 grantTypesMap := map[string]bool{} 260 for _, gt := range metadata.GrantTypes { 261 if grantTypesMap[gt] { 262 return nil, fmt.Errorf("duplicate grant type `%s`", gt) 263 } 264 265 switch gt { 266 case "implicit": 267 return nil, errors.New("grantg type `implicit` is not allowed") 268 case "authorization_code", "refresh_token": 269 // TODO check if this grant type is supported 270 default: 271 return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt) 272 } 273 274 grantTypesMap[gt] = true 275 } 276 277 if metadata.ClientID != clientId { 278 return nil, errors.New("`client_id` does not match") 279 } 280 281 subjectType, subjectTypeOk := metadataMap["subject_type"].(string) 282 if subjectTypeOk && subjectType != "public" { 283 return nil, errors.New("only public `subject_type` 
is supported") 284 } 285 286 switch metadata.TokenEndpointAuthMethod { 287 case "none": 288 if metadata.TokenEndpointAuthSigningAlg != "" { 289 return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg") 290 } 291 case "private_key_jwt": 292 if metadata.JWKS == nil && metadata.JWKSURI == nil { 293 return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri") 294 } 295 296 if metadata.JWKS != nil && len(metadata.JWKS.Keys) == 0 { 297 return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks") 298 } 299 300 if metadata.TokenEndpointAuthSigningAlg == "" { 301 return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata") 302 } 303 default: 304 return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod) 305 } 306 307 if !metadata.DpopBoundAccessTokens { 308 return nil, errors.New("dpop_bound_access_tokens must be true") 309 } 310 311 if !slices.Contains(metadata.ResponseTypes, "code") { 312 return nil, errors.New("response_types must inclue `code`") 313 } 314 315 if !slices.Contains(metadata.GrantTypes, "authorization_code") { 316 return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`") 317 } 318 319 if len(metadata.RedirectURIs) == 0 { 320 return nil, errors.New("at least one `redirect_uri` is required") 321 } 322 323 if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod != "none" { 324 return nil, errors.New("native clients must authenticate using `none` method") 325 } 326 327 if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") { 328 for _, ruri := range metadata.RedirectURIs { 329 u, err := url.Parse(ruri) 330 if err != nil { 331 return nil, fmt.Errorf("error parsing redirect uri: %w", err) 332 } 333 334 if u.Scheme != "https" { 335 return nil, errors.New("web clients must use https redirect uris") 
336 } 337 338 if u.Hostname() == "localhost" { 339 return nil, errors.New("web clients must not use localhost as the hostname") 340 } 341 } 342 } 343 344 for _, ruri := range metadata.RedirectURIs { 345 u, err := url.Parse(ruri) 346 if err != nil { 347 return nil, fmt.Errorf("error parsing redirect uri: %w", err) 348 } 349 350 if u.User != nil { 351 if u.User.Username() != "" { 352 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 353 } 354 355 if _, hasPass := u.User.Password(); hasPass { 356 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 357 } 358 } 359 360 switch true { 361 case u.Hostname() == "localhost": 362 return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)") 363 case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]": 364 if metadata.ApplicationType != "native" { 365 return nil, errors.New("loopback redirect uris are only allowed for native apps") 366 } 367 368 if u.Port() != "" { 369 // reference impl doesn't do anything with this? 
370 } 371 372 if u.Scheme != "http" { 373 return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri) 374 } 375 case u.Scheme == "http": 376 return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme") 377 case u.Scheme == "https": 378 if isLocalHostname(u.Hostname()) { 379 return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri) 380 } 381 case strings.Contains(u.Scheme, "."): 382 if metadata.ApplicationType != "native" { 383 return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps") 384 } 385 386 revdomain := reverseDomain(u.Scheme) 387 388 if isLocalHostname(revdomain) { 389 return nil, errors.New("private use uri scheme redirect uris must not be local hostnames") 390 } 391 392 if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" { 393 return nil, fmt.Errorf("private use uri scheme must be in the form ") 394 } 395 default: 396 return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme) 397 } 398 } 399 400 return &metadata, nil 401} 402 403func isLocalhostClientID(clientId string) bool { 404 u, err := url.Parse(clientId) 405 if err != nil { 406 return false 407 } 408 return u.Scheme == "http" && 409 u.Hostname() == "localhost" && 410 u.Port() == "" && 411 (u.Path == "" || u.Path == "/") 412} 413 414func buildLocalhostVirtualMetadata(clientId string) (*Metadata, error) { 415 u, err := url.Parse(clientId) 416 if err != nil { 417 return nil, fmt.Errorf("error parsing localhost client_id: %w", err) 418 } 419 420 q := u.Query() 421 422 redirectURIs := q["redirect_uri"] 423 if len(redirectURIs) == 0 { 424 redirectURIs = []string{"http://127.0.0.1/", "http://[::1]/"} 425 } 426 427 for _, ruri := range redirectURIs { 428 ru, err := url.Parse(ruri) 429 if err != nil { 430 return nil, fmt.Errorf("invalid redirect_uri %q: %w", ruri, err) 431 } 432 if ru.Scheme != "http" || !isLoopbackHost(ru.Hostname()) { 433 
return nil, fmt.Errorf("localhost client redirect_uri must use a loopback address, got %q", ruri) 434 } 435 } 436 437 scope := q.Get("scope") 438 if scope == "" { 439 scope = "atproto" 440 } else if !slices.Contains(strings.Split(scope, " "), "atproto") { 441 scope = "atproto " + scope 442 } 443 444 return &Metadata{ 445 ClientID: clientId, 446 ClientName: "Development client", 447 ClientURI: "http://localhost", 448 RedirectURIs: redirectURIs, 449 GrantTypes: []string{"authorization_code", "refresh_token"}, 450 ResponseTypes: []string{"code"}, 451 ApplicationType: "native", 452 DpopBoundAccessTokens: true, 453 Scope: scope, 454 TokenEndpointAuthMethod: "none", 455 }, nil 456} 457 458func isLoopbackHost(hostname string) bool { 459 return hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" 460} 461 462func isLocalHostname(hostname string) bool { 463 pts := strings.Split(hostname, ".") 464 if len(pts) < 2 { 465 return true 466 } 467 468 tld := strings.ToLower(pts[len(pts)-1]) 469 return tld == "test" || tld == "local" || tld == "localhost" || tld == "invalid" || tld == "example" 470} 471 472func reverseDomain(domain string) string { 473 pts := strings.Split(domain, ".") 474 slices.Reverse(pts) 475 return strings.Join(pts, ".") 476}