A fork of the Cocoon PDS that is being made more distributed.
0

Configure Feed

Select the types of activity you want to include in your feed.

at main 6.4 kB View raw
1package dpop 2 3import ( 4 "crypto" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "log/slog" 11 "net/http" 12 "net/url" 13 "strings" 14 "time" 15 16 "github.com/golang-jwt/jwt/v4" 17 "github.com/haileyok/cocoon/internal/helpers" 18 "github.com/haileyok/cocoon/oauth/constants" 19 "github.com/lestrrat-go/jwx/v2/jwa" 20 "github.com/lestrrat-go/jwx/v2/jwk" 21) 22 23type Manager struct { 24 nonce *TotpNonce 25 jtiCache *jtiCache 26 logger *slog.Logger 27 hostname string 28} 29 30type ManagerArgs struct { 31 NonceSecret []byte 32 NonceRotationInterval time.Duration 33 OnNonceSecretCreated func([]byte) 34 JTICacheSize int 35 Logger *slog.Logger 36 Hostname string 37} 38 39var ( 40 ErrUseDpopNonce = errors.New("use_dpop_nonce") 41) 42 43func NewManager(args ManagerArgs) *Manager { 44 if args.Logger == nil { 45 args.Logger = slog.Default() 46 } 47 48 if args.JTICacheSize == 0 { 49 args.JTICacheSize = 100_000 50 } 51 52 if args.NonceSecret == nil { 53 args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. 
consider saving and restoring your nonce.") 54 } 55 56 return &Manager{ 57 nonce: NewTotpNonce(TotpNonceArgs{ 58 // TODO: pass this in from the args 59 // timeRoundDuration: args.NonceRotationInterval, 60 timeRoundDuration: time.Minute * 2, 61 Secret: args.NonceSecret, 62 OnSecretCreated: args.OnNonceSecretCreated, 63 }), 64 jtiCache: newJTICache(args.JTICacheSize), 65 logger: args.Logger, 66 hostname: args.Hostname, 67 } 68} 69 70func (dm *Manager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*Proof, error) { 71 if reqMethod == "" { 72 return nil, errors.New("HTTP method is required") 73 } 74 75 if !strings.HasPrefix(reqUrl, "https://") { 76 reqUrl = "https://" + dm.hostname + reqUrl 77 } 78 79 proof := extractProof(headers) 80 if proof == "" { 81 return nil, nil 82 } 83 84 parser := jwt.NewParser(jwt.WithoutClaimsValidation()) 85 var token *jwt.Token 86 87 token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{}) 88 if err != nil { 89 return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err) 90 } 91 92 typ, _ := token.Header["typ"].(string) 93 if typ != "dpop+jwt" { 94 return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`) 95 } 96 97 dpopJwk, jwkOk := token.Header["jwk"].(map[string]any) 98 if !jwkOk { 99 return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`) 100 } 101 102 jwkb, err := json.Marshal(dpopJwk) 103 if err != nil { 104 return nil, fmt.Errorf("failed to marshal jwk: %w", err) 105 } 106 107 key, err := jwk.ParseKey(jwkb) 108 if err != nil { 109 return nil, fmt.Errorf("failed to parse jwk: %w", err) 110 } 111 112 var pubKey any 113 if err := key.Raw(&pubKey); err != nil { 114 return nil, fmt.Errorf("failed to get raw public key: %w", err) 115 } 116 117 token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) { 118 alg := t.Header["alg"].(string) 119 120 switch key.KeyType() { 121 case jwa.EC: 122 if !strings.HasPrefix(alg, "ES") { 123 return nil, 
fmt.Errorf("algorithm %s doesn't match EC key type", alg) 124 } 125 case jwa.RSA: 126 if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") { 127 return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg) 128 } 129 case jwa.OKP: 130 if alg != "EdDSA" { 131 return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg) 132 } 133 } 134 135 return pubKey, nil 136 }, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"})) 137 if err != nil { 138 return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err) 139 } 140 141 if !token.Valid { 142 return nil, errors.New("dpop proof jwt is invalid") 143 } 144 145 claims, ok := token.Claims.(jwt.MapClaims) 146 if !ok { 147 return nil, errors.New("no claims in dpop proof jwt") 148 } 149 150 iat, iatOk := claims["iat"].(float64) 151 if !iatOk { 152 return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`) 153 } 154 155 iatTime := time.Unix(int64(iat), 0) 156 now := time.Now() 157 158 if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance { 159 return nil, errors.New("dpop proof too old") 160 } 161 162 if iatTime.Sub(now) > constants.DpopCheckTolerance { 163 return nil, errors.New("dpop proof iat is in the future") 164 } 165 166 jti, _ := claims["jti"].(string) 167 if jti == "" { 168 return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`) 169 } 170 171 if dm.jtiCache.add(jti) { 172 return nil, errors.New("dpop proof replay detected") 173 } 174 175 htm, _ := claims["htm"].(string) 176 if htm == "" { 177 return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`) 178 } 179 180 if htm != reqMethod { 181 return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`) 182 } 183 184 htu, _ := claims["htu"].(string) 185 if htu == "" { 186 return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`) 187 } 188 189 parsedHtu, err := helpers.OauthParseHtu(htu) 190 if err != nil { 191 
return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`) 192 } 193 194 u, _ := url.Parse(reqUrl) 195 if parsedHtu != helpers.OauthNormalizeHtu(u) { 196 return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u)) 197 } 198 199 nonce, _ := claims["nonce"].(string) 200 if nonce == "" { 201 // reference impl checks if self.nonce is not null before returning an error, but we always have a 202 // nonce so we do not bother checking 203 return nil, ErrUseDpopNonce 204 } 205 206 if nonce != "" && !dm.nonce.Check(nonce) { 207 // dpop nonce mismatch 208 return nil, ErrUseDpopNonce 209 } 210 211 ath, _ := claims["ath"].(string) 212 213 if accessToken != nil && *accessToken != "" { 214 if ath == "" { 215 return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`) 216 } 217 218 hash := sha256.Sum256([]byte(*accessToken)) 219 if ath != base64.RawURLEncoding.EncodeToString(hash[:]) { 220 return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`) 221 } 222 } else if ath != "" { 223 return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`) 224 } 225 226 thumbBytes, err := key.Thumbprint(crypto.SHA256) 227 if err != nil { 228 return nil, fmt.Errorf("failed to calculate thumbprint: %w", err) 229 } 230 231 thumb := base64.RawURLEncoding.EncodeToString(thumbBytes) 232 233 return &Proof{ 234 JTI: jti, 235 JKT: thumb, 236 HTM: htm, 237 HTU: htu, 238 }, nil 239} 240 241func extractProof(headers http.Header) string { 242 dpopHeaders := headers.Values("dpop") 243 switch len(dpopHeaders) { 244 case 0: 245 return "" 246 case 1: 247 return dpopHeaders[0] 248 default: 249 return "" 250 } 251} 252 253func (dm *Manager) NextNonce() string { 254 return dm.nonce.NextNonce() 255}