Forked from willdot.net/cocoon.
A fork of the Cocoon PDS that is being made more distributed.
1package dpop
2
3import (
4 "crypto"
5 "crypto/sha256"
6 "encoding/base64"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "log/slog"
11 "net/http"
12 "net/url"
13 "strings"
14 "time"
15
16 "github.com/golang-jwt/jwt/v4"
17 "github.com/haileyok/cocoon/internal/helpers"
18 "github.com/haileyok/cocoon/oauth/constants"
19 "github.com/lestrrat-go/jwx/v2/jwa"
20 "github.com/lestrrat-go/jwx/v2/jwk"
21)
22
23type Manager struct {
24 nonce *TotpNonce
25 jtiCache *jtiCache
26 logger *slog.Logger
27 hostname string
28}
29
// ManagerArgs configures NewManager. Zero values select defaults where noted.
type ManagerArgs struct {
	// NonceSecret seeds the nonce generator. A nil secret is accepted, but
	// NewManager logs a warning that existing sessions may break.
	NonceSecret []byte
	// NonceRotationInterval is currently ignored: NewManager always uses a
	// fixed two-minute rotation interval.
	NonceRotationInterval time.Duration
	// OnNonceSecretCreated is forwarded to the nonce generator; presumably it
	// is called when a new secret is generated so it can be persisted
	// (TODO confirm against TotpNonce).
	OnNonceSecretCreated func([]byte)
	// JTICacheSize sizes the replay cache; 0 means 100,000.
	JTICacheSize int
	// Logger defaults to slog.Default() when nil.
	Logger *slog.Logger
	// Hostname is used to turn request paths into absolute https URLs.
	Hostname string
}
38
// ErrUseDpopNonce is returned by CheckProof when a proof carries no nonce or
// carries one the manager does not accept. Its text matches the OAuth
// "use_dpop_nonce" error code.
var ErrUseDpopNonce = errors.New("use_dpop_nonce")
42
43func NewManager(args ManagerArgs) *Manager {
44 if args.Logger == nil {
45 args.Logger = slog.Default()
46 }
47
48 if args.JTICacheSize == 0 {
49 args.JTICacheSize = 100_000
50 }
51
52 if args.NonceSecret == nil {
53 args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.")
54 }
55
56 return &Manager{
57 nonce: NewTotpNonce(TotpNonceArgs{
58 // TODO: pass this in from the args
59 // timeRoundDuration: args.NonceRotationInterval,
60 timeRoundDuration: time.Minute * 2,
61 Secret: args.NonceSecret,
62 OnSecretCreated: args.OnNonceSecretCreated,
63 }),
64 jtiCache: newJTICache(args.JTICacheSize),
65 logger: args.Logger,
66 hostname: args.Hostname,
67 }
68}
69
70func (dm *Manager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*Proof, error) {
71 if reqMethod == "" {
72 return nil, errors.New("HTTP method is required")
73 }
74
75 if !strings.HasPrefix(reqUrl, "https://") {
76 reqUrl = "https://" + dm.hostname + reqUrl
77 }
78
79 proof := extractProof(headers)
80 if proof == "" {
81 return nil, nil
82 }
83
84 parser := jwt.NewParser(jwt.WithoutClaimsValidation())
85 var token *jwt.Token
86
87 token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{})
88 if err != nil {
89 return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err)
90 }
91
92 typ, _ := token.Header["typ"].(string)
93 if typ != "dpop+jwt" {
94 return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`)
95 }
96
97 dpopJwk, jwkOk := token.Header["jwk"].(map[string]any)
98 if !jwkOk {
99 return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`)
100 }
101
102 jwkb, err := json.Marshal(dpopJwk)
103 if err != nil {
104 return nil, fmt.Errorf("failed to marshal jwk: %w", err)
105 }
106
107 key, err := jwk.ParseKey(jwkb)
108 if err != nil {
109 return nil, fmt.Errorf("failed to parse jwk: %w", err)
110 }
111
112 var pubKey any
113 if err := key.Raw(&pubKey); err != nil {
114 return nil, fmt.Errorf("failed to get raw public key: %w", err)
115 }
116
117 token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) {
118 alg := t.Header["alg"].(string)
119
120 switch key.KeyType() {
121 case jwa.EC:
122 if !strings.HasPrefix(alg, "ES") {
123 return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg)
124 }
125 case jwa.RSA:
126 if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") {
127 return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg)
128 }
129 case jwa.OKP:
130 if alg != "EdDSA" {
131 return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg)
132 }
133 }
134
135 return pubKey, nil
136 }, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"}))
137 if err != nil {
138 return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err)
139 }
140
141 if !token.Valid {
142 return nil, errors.New("dpop proof jwt is invalid")
143 }
144
145 claims, ok := token.Claims.(jwt.MapClaims)
146 if !ok {
147 return nil, errors.New("no claims in dpop proof jwt")
148 }
149
150 iat, iatOk := claims["iat"].(float64)
151 if !iatOk {
152 return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`)
153 }
154
155 iatTime := time.Unix(int64(iat), 0)
156 now := time.Now()
157
158 if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance {
159 return nil, errors.New("dpop proof too old")
160 }
161
162 if iatTime.Sub(now) > constants.DpopCheckTolerance {
163 return nil, errors.New("dpop proof iat is in the future")
164 }
165
166 jti, _ := claims["jti"].(string)
167 if jti == "" {
168 return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`)
169 }
170
171 if dm.jtiCache.add(jti) {
172 return nil, errors.New("dpop proof replay detected")
173 }
174
175 htm, _ := claims["htm"].(string)
176 if htm == "" {
177 return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`)
178 }
179
180 if htm != reqMethod {
181 return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`)
182 }
183
184 htu, _ := claims["htu"].(string)
185 if htu == "" {
186 return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`)
187 }
188
189 parsedHtu, err := helpers.OauthParseHtu(htu)
190 if err != nil {
191 return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`)
192 }
193
194 u, _ := url.Parse(reqUrl)
195 if parsedHtu != helpers.OauthNormalizeHtu(u) {
196 return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u))
197 }
198
199 nonce, _ := claims["nonce"].(string)
200 if nonce == "" {
201 // reference impl checks if self.nonce is not null before returning an error, but we always have a
202 // nonce so we do not bother checking
203 return nil, ErrUseDpopNonce
204 }
205
206 if nonce != "" && !dm.nonce.Check(nonce) {
207 // dpop nonce mismatch
208 return nil, ErrUseDpopNonce
209 }
210
211 ath, _ := claims["ath"].(string)
212
213 if accessToken != nil && *accessToken != "" {
214 if ath == "" {
215 return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`)
216 }
217
218 hash := sha256.Sum256([]byte(*accessToken))
219 if ath != base64.RawURLEncoding.EncodeToString(hash[:]) {
220 return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`)
221 }
222 } else if ath != "" {
223 return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`)
224 }
225
226 thumbBytes, err := key.Thumbprint(crypto.SHA256)
227 if err != nil {
228 return nil, fmt.Errorf("failed to calculate thumbprint: %w", err)
229 }
230
231 thumb := base64.RawURLEncoding.EncodeToString(thumbBytes)
232
233 return &Proof{
234 JTI: jti,
235 JKT: thumb,
236 HTM: htm,
237 HTU: htu,
238 }, nil
239}
240
// extractProof returns the value of the request's DPoP header when exactly
// one such header is present. It returns "" when the header is absent or
// appears more than once.
func extractProof(headers http.Header) string {
	if values := headers.Values("dpop"); len(values) == 1 {
		return values[0]
	}
	return ""
}
252
253func (dm *Manager) NextNonce() string {
254 return dm.nonce.NextNonce()
255}