Forked from
willdot.net/cocoon
A fork of the Cocoon PDS that is being made more distributed.
1package dpop
2
3import (
4 "crypto/hmac"
5 "crypto/sha256"
6 "encoding/base64"
7 "encoding/binary"
8 "sync"
9 "time"
10
11 "github.com/haileyok/cocoon/internal/helpers"
12 "github.com/haileyok/cocoon/oauth/constants"
13)
14
// TotpNonce issues and validates time-derived DPoP nonces.
//
// Time is split into periods of timeRoundDuration. The nonce for a period is
// an HMAC-SHA256 (keyed by secret) of the period's start time. Instances that
// share the same secret and period length therefore derive identical nonces.
// The nonces for the previous, current and next period are cached and
// refreshed lazily whenever NextNonce or Check observes that time has moved on.
type TotpNonce struct {
	// secret is the HMAC key used to derive every nonce.
	secret []byte

	// mu is taken by NextNonce and Check before they rotate or read the
	// cached window below.
	mu sync.RWMutex

	// currentTimePeriodStart is the start of the period the cached nonces
	// were last computed for.
	currentTimePeriodStart time.Time
	// timeRoundDuration is the length of one nonce period.
	timeRoundDuration time.Duration

	// Cached nonces for the periods before, at, and after
	// currentTimePeriodStart, in that order.
	prev, curr, next string
}
27
// TotpNonceArgs holds the options accepted by NewTotpNonce.
type TotpNonceArgs struct {
	// timeRoundDuration is the requested nonce period length. It is
	// unexported, so only code inside package dpop can set it.
	timeRoundDuration time.Duration
	// Secret is the HMAC key. When nil, NewTotpNonce generates a random one.
	Secret []byte
	// OnSecretCreated is called with the secret NewTotpNonce generates when
	// Secret is nil (for example so the caller can persist or share it).
	OnSecretCreated func([]byte)
}
33
34func NewTotpNonce(args TotpNonceArgs) *TotpNonce {
35 if args.timeRoundDuration == 0 {
36 args.timeRoundDuration = time.Minute * 2
37 }
38
39 if args.timeRoundDuration > constants.DpopNonceMaxAge {
40 args.timeRoundDuration = constants.DpopNonceMaxAge
41 }
42
43 if args.Secret == nil {
44 args.Secret = helpers.RandomBytes(constants.NonceSecretByteLength)
45 args.OnSecretCreated(args.Secret)
46 }
47
48 n := &TotpNonce{
49 secret: args.Secret,
50 mu: sync.RWMutex{},
51 timeRoundDuration: time.Minute * 15,
52 }
53
54 n.currentTimePeriodStart = time.Now().Truncate(n.timeRoundDuration)
55 n.prev = n.compute(n.currentTimePeriodStart.Add(-n.timeRoundDuration))
56 n.curr = n.compute(n.currentTimePeriodStart)
57 n.next = n.compute(n.currentTimePeriodStart.Add(n.timeRoundDuration))
58
59 return n
60}
61
62func (n *TotpNonce) currentTruncatedTime(now time.Time) time.Time {
63 return now.Truncate(n.timeRoundDuration)
64}
65
66func (n *TotpNonce) compute(ti time.Time) string {
67 h := hmac.New(sha256.New, n.secret)
68 unixBytes := make([]byte, 8)
69 binary.BigEndian.PutUint64(unixBytes, uint64(ti.UnixNano()))
70 h.Write(unixBytes)
71 return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
72}
73
74func (n *TotpNonce) rotate(now time.Time) {
75 currentTruncated := n.currentTruncatedTime(now)
76
77 if currentTruncated == n.currentTimePeriodStart {
78 return
79 }
80
81 n.currentTimePeriodStart = currentTruncated
82 n.prev = n.curr
83 n.curr = n.next
84 n.next = n.compute(currentTruncated.Add(n.timeRoundDuration))
85}
86
87func (n *TotpNonce) NextNonce() string {
88 n.mu.Lock()
89 defer n.mu.Unlock()
90 n.rotate(time.Now())
91 return n.next
92}
93
94func (n *TotpNonce) Check(nonce string) bool {
95 n.mu.Lock()
96 defer n.mu.Unlock()
97 n.rotate(time.Now())
98 return nonce == n.prev || nonce == n.curr || nonce == n.next
99}