Monorepo for Tangled
tangled.org
1package db
2
3import (
4 "database/sql"
5 "log/slog"
6 "strconv"
7 "strings"
8 "time"
9
10 "github.com/bluesky-social/indigo/atproto/syntax"
11 "golang.org/x/crypto/ssh"
12 "tangled.org/core/api/tangled"
13)
14
// PublicKey is a row of the public_keys table.
//
// Did is the owning account and Rkey is the record key the row was stored
// under; Rkey may be empty, in which case rows cannot be addressed by it
// (UpsertPublicKey and DeletePublicKeyByRkey both skip the rkey-based delete).
// The embedded tangled.PublicKey carries the record fields; this file reads
// and writes its Key and CreatedAt fields.
type PublicKey struct {
	Did  syntax.DID
	Rkey syntax.RecordKey
	tangled.PublicKey
}
20
21func (d *DB) UpsertPublicKey(pk PublicKey) error {
22 tx, err := d.db.Begin()
23 if err != nil {
24 return err
25 }
26 defer tx.Rollback()
27
28 if pk.Rkey != "" {
29 if _, err := tx.Exec(`delete from public_keys where did = ? and rkey = ?`, pk.Did, pk.Rkey); err != nil {
30 return err
31 }
32 }
33
34 if err := insertPublicKey(tx, d.logger, pk); err != nil {
35 return err
36 }
37
38 return tx.Commit()
39}
40
41func insertPublicKey(tx *sql.Tx, logger *slog.Logger, pk PublicKey) error {
42 if pk.Key == "" {
43 logger.Warn("skipping public key with empty key value", "did", pk.Did, "rkey", pk.Rkey)
44 return nil
45 }
46
47 canonical, ok := normalizePublicKey(pk.Key)
48 if !ok {
49 logger.Warn("skipping malformed public key", "did", pk.Did, "rkey", pk.Rkey)
50 return nil
51 }
52 pk.Key = canonical
53
54 if pk.CreatedAt == "" {
55 pk.CreatedAt = time.Now().Format(time.RFC3339)
56 }
57
58 res, err := tx.Exec(
59 `insert or ignore into public_keys (did, key, rkey, created) values (?, ?, ?, ?)`,
60 pk.Did, pk.Key, pk.Rkey, pk.CreatedAt,
61 )
62 if err != nil {
63 return err
64 }
65
66 if rows, err := res.RowsAffected(); err == nil && rows == 0 {
67 logger.Warn("public key not stored, already registered to another did", "did", pk.Did, "rkey", pk.Rkey)
68 }
69
70 return nil
71}
72
73func (d *DB) DeletePublicKeyByRkey(did syntax.DID, rkey syntax.RecordKey) error {
74 if rkey == "" {
75 return nil
76 }
77
78 query := `delete from public_keys where did = ? and rkey = ?`
79 _, err := d.db.Exec(query, did, rkey)
80 return err
81}
82
83func (d *DB) ReplacePublicKeys(did syntax.DID, keys []PublicKey) error {
84 tx, err := d.db.Begin()
85 if err != nil {
86 return err
87 }
88 defer tx.Rollback()
89
90 if _, err := tx.Exec(`delete from public_keys where did = ?`, did); err != nil {
91 return err
92 }
93
94 if err := insertPublicKeys(tx, d.logger, keys); err != nil {
95 return err
96 }
97
98 return tx.Commit()
99}
100
101func insertPublicKeys(tx *sql.Tx, logger *slog.Logger, keys []PublicKey) error {
102 if len(keys) == 0 {
103 return nil
104 }
105
106 if err := insertPublicKey(tx, logger, keys[0]); err != nil {
107 return err
108 }
109
110 return insertPublicKeys(tx, logger, keys[1:])
111}
112
113func (pk *PublicKey) JSON() map[string]any {
114 return map[string]any{
115 "did": pk.Did,
116 "key": pk.Key,
117 "createdAt": pk.CreatedAt,
118 }
119}
120
121func normalizePublicKey(key string) (string, bool) {
122 parsed, comment, _, _, err := ssh.ParseAuthorizedKey([]byte(key))
123 if err != nil {
124 return "", false
125 }
126
127 canonical := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(parsed)))
128 if comment != "" {
129 canonical += " " + comment
130 }
131
132 return canonical, true
133}
134
135func (d *DB) DidForPublicKey(offered ssh.PublicKey) (syntax.DID, bool, error) {
136 prefix := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(offered)))
137
138 var did syntax.DID
139 err := d.db.QueryRow(
140 `select did from public_keys where key = ? or key like ? limit 1`,
141 prefix, prefix+" %",
142 ).Scan(&did)
143 if err == sql.ErrNoRows {
144 return "", false, nil
145 }
146 if err != nil {
147 return "", false, err
148 }
149
150 return did, true, nil
151}
152
153func (d *DB) GetAllPublicKeys() ([]PublicKey, error) {
154 var keys []PublicKey
155
156 rows, err := d.db.Query(`select key, did, created from public_keys`)
157 if err != nil {
158 return nil, err
159 }
160 defer rows.Close()
161
162 for rows.Next() {
163 var publicKey PublicKey
164 if err := rows.Scan(&publicKey.Key, &publicKey.Did, &publicKey.CreatedAt); err != nil {
165 return nil, err
166 }
167 keys = append(keys, publicKey)
168 }
169
170 if err := rows.Err(); err != nil {
171 return nil, err
172 }
173
174 return keys, nil
175}
176
177func (d *DB) GetPublicKeysPaginated(limit int, cursor string) ([]PublicKey, string, error) {
178 var keys []PublicKey
179
180 offset := 0
181 if cursor != "" {
182 if o, err := strconv.Atoi(cursor); err == nil && o >= 0 {
183 offset = o
184 }
185 }
186
187 query := `select key, did, created from public_keys order by created desc limit ? offset ?`
188 rows, err := d.db.Query(query, limit+1, offset) // +1 to check if there are more results
189 if err != nil {
190 return nil, "", err
191 }
192 defer rows.Close()
193
194 for rows.Next() {
195 var publicKey PublicKey
196 if err := rows.Scan(&publicKey.Key, &publicKey.Did, &publicKey.CreatedAt); err != nil {
197 return nil, "", err
198 }
199 keys = append(keys, publicKey)
200 }
201
202 if err := rows.Err(); err != nil {
203 return nil, "", err
204 }
205
206 // check if there are more results for pagination
207 var nextCursor string
208 if len(keys) > limit {
209 keys = keys[:limit] // remove the extra item
210 nextCursor = strconv.Itoa(offset + limit)
211 }
212
213 return keys, nextCursor, nil
214}