Caddy module that requires AT Protocol (atproto) authentication and restricts routes to specific DIDs
1package db
2
3import (
4 "context"
5 "crypto/rand"
6 "database/sql"
7 "encoding/hex"
8 "encoding/json"
9 "fmt"
10 "os"
11 "path/filepath"
12 "sync/atomic"
13
14 "github.com/bluesky-social/indigo/atproto/atcrypto"
15 "github.com/bluesky-social/indigo/atproto/auth/oauth"
16 "github.com/bluesky-social/indigo/atproto/syntax"
17 _ "github.com/mattn/go-sqlite3"
18)
19
20// Ensure DB implements ClientAuthStore
21var _ oauth.ClientAuthStore = (*Store)(nil)
22
// Store persists the plugin's state in SQLite: pending OAuth auth requests
// (auth_requests), OAuth sessions (auth_sessions), and long-lived secrets
// such as the client key and cookie secret (system_keys).
type Store struct {
	// db is the connection pool opened by NewStore and released by Close.
	db *sql.DB

	// cleanupCounter counts SaveAuthRequestInfo calls and decides when a
	// background Cleanup is triggered. It is only touched via sync/atomic.
	cleanupCounter uint32
}
28
29// NewStore initializes a new SQLite-backed storage.
30func NewStore(path string) (*Store, error) {
31 if dir := filepath.Dir(path); dir != "" && dir != "." {
32 if err := os.MkdirAll(dir, 0755); err != nil {
33 return nil, fmt.Errorf("failed to create directory for database: %w", err)
34 }
35 }
36
37 // Enable WAL mode for better concurrency
38 db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL&_busy_timeout=5000")
39 if err != nil {
40 return nil, fmt.Errorf("failed to open database: %w", err)
41 }
42
43 if err := db.Ping(); err != nil {
44 return nil, fmt.Errorf("failed to ping database: %w", err)
45 }
46
47 store := &Store{db: db}
48 if err := store.initSchema(); err != nil {
49 return nil, fmt.Errorf("failed to initialize schema: %w", err)
50 }
51
52 return store, nil
53}
54
55func (s *Store) initSchema() error {
56 const schema = `
57 CREATE TABLE IF NOT EXISTS auth_requests (
58 state TEXT PRIMARY KEY,
59 data TEXT NOT NULL,
60 created_at DATETIME DEFAULT CURRENT_TIMESTAMP
61 );
62
63 CREATE TABLE IF NOT EXISTS auth_sessions (
64 did TEXT NOT NULL,
65 session_id TEXT NOT NULL,
66 data TEXT NOT NULL,
67 updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
68 PRIMARY KEY (did, session_id)
69 );
70
71 CREATE TABLE IF NOT EXISTS system_keys (
72 id TEXT PRIMARY KEY,
73 key_data BLOB NOT NULL,
74 created_at DATETIME DEFAULT CURRENT_TIMESTAMP
75 );
76 `
77 _, err := s.db.Exec(schema)
78 return err
79}
80
81// Close closes the underlying database.
82func (s *Store) Close() error {
83 return s.db.Close()
84}
85
86// GetClientKey retrieves the main client key, generating it if it doesn't exist.
87// Returns the private key and its ID (a hash of the public key or a random string).
88func (s *Store) GetClientKey(ctx context.Context) (atcrypto.PrivateKey, string, error) {
89 var keyData []byte
90 err := s.db.QueryRowContext(ctx, "SELECT key_data FROM system_keys WHERE id = 'client_key'").Scan(&keyData)
91 if err == sql.ErrNoRows {
92 // Generate a new P-256 key
93 pk, err := atcrypto.GeneratePrivateKeyP256()
94 if err != nil {
95 return nil, "", fmt.Errorf("failed to generate new client key: %w", err)
96 }
97
98 keyData = pk.Bytes()
99
100 _, err = s.db.ExecContext(ctx, "INSERT INTO system_keys (id, key_data) VALUES ('client_key', ?)", keyData)
101 if err != nil {
102 return nil, "", fmt.Errorf("failed to save generated client key: %w", err)
103 }
104
105 return pk, "client_key", nil
106 } else if err != nil {
107 return nil, "", fmt.Errorf("failed to load client key: %w", err)
108 }
109
110 pk, err := atcrypto.ParsePrivateBytesP256(keyData)
111 if err != nil {
112 return nil, "", fmt.Errorf("failed to parse existing client key: %w", err)
113 }
114
115 return pk, "client_key", nil
116}
117
118// GetCookieSecret retrieves or generates the cookie secret.
119func (s *Store) GetCookieSecret(ctx context.Context) (string, error) {
120 var secret []byte
121 err := s.db.QueryRowContext(ctx, "SELECT key_data FROM system_keys WHERE id = 'cookie_secret'").Scan(&secret)
122 if err == sql.ErrNoRows {
123 // Generate new random 32 byte secret
124 rawSecret := make([]byte, 32)
125 if _, err := rand.Read(rawSecret); err != nil {
126 return "", fmt.Errorf("failed to generate cookie secret: %w", err)
127 }
128 secretStr := hex.EncodeToString(rawSecret)
129
130 _, err = s.db.ExecContext(ctx, "INSERT INTO system_keys (id, key_data) VALUES ('cookie_secret', ?)", []byte(secretStr))
131 if err != nil {
132 return "", fmt.Errorf("failed to save cookie secret: %w", err)
133 }
134 return secretStr, nil
135 } else if err != nil {
136 return "", fmt.Errorf("failed to load cookie secret: %w", err)
137 }
138
139 return string(secret), nil
140}
141
142// GetSession retrieves session data from the database.
143func (s *Store) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
144 var dataStr string
145 err := s.db.QueryRowContext(ctx, "SELECT data FROM auth_sessions WHERE did = ? AND session_id = ?", did.String(), sessionID).Scan(&dataStr)
146 if err != nil {
147 if err == sql.ErrNoRows {
148 return nil, nil
149 }
150 return nil, fmt.Errorf("failed to query session: %w", err)
151 }
152
153 var sessionData oauth.ClientSessionData
154 if err := json.Unmarshal([]byte(dataStr), &sessionData); err != nil {
155 return nil, fmt.Errorf("failed to parse session data: %w", err)
156 }
157
158 return &sessionData, nil
159}
160
161// SaveSession saves session data into the database.
162func (s *Store) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
163 dataBytes, err := json.Marshal(sess)
164 if err != nil {
165 return fmt.Errorf("failed to serialize session data: %w", err)
166 }
167
168 _, err = s.db.ExecContext(ctx, `
169 INSERT INTO auth_sessions (did, session_id, data, updated_at)
170 VALUES (?, ?, ?, CURRENT_TIMESTAMP)
171 ON CONFLICT(did, session_id) DO UPDATE SET data = excluded.data, updated_at = CURRENT_TIMESTAMP
172 `, sess.AccountDID.String(), sess.SessionID, string(dataBytes))
173 if err != nil {
174 return fmt.Errorf("failed to save session: %w", err)
175 }
176
177 return nil
178}
179
180// DeleteSession removes session data from the database.
181func (s *Store) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
182 _, err := s.db.ExecContext(ctx, "DELETE FROM auth_sessions WHERE did = ? AND session_id = ?", did.String(), sessionID)
183 if err != nil {
184 return fmt.Errorf("failed to delete session: %w", err)
185 }
186 return nil
187}
188
189// Cleanup removes expired auth requests and stale sessions to prevent database bloat.
190func (s *Store) Cleanup(ctx context.Context) error {
191 // Delete auth_requests older than 1 hour
192 _, err := s.db.ExecContext(ctx, "DELETE FROM auth_requests WHERE created_at < datetime('now', '-1 hour')")
193 if err != nil {
194 return fmt.Errorf("cleanup auth_requests failed: %w", err)
195 }
196
197 // Delete auth_sessions older than 30 days
198 _, err = s.db.ExecContext(ctx, "DELETE FROM auth_sessions WHERE updated_at < datetime('now', '-30 days')")
199 if err != nil {
200 return fmt.Errorf("cleanup auth_sessions failed: %w", err)
201 }
202
203 return nil
204}
205
206// GetAuthRequestInfo retrieves the auth request data by state.
207func (s *Store) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
208 var dataStr string
209 err := s.db.QueryRowContext(ctx, "SELECT data FROM auth_requests WHERE state = ?", state).Scan(&dataStr)
210 if err != nil {
211 if err == sql.ErrNoRows {
212 // AuthRequestData not found
213 return nil, fmt.Errorf("auth request info not found for state")
214 }
215 return nil, fmt.Errorf("failed to query auth request: %w", err)
216 }
217
218 var reqData oauth.AuthRequestData
219 if err := json.Unmarshal([]byte(dataStr), &reqData); err != nil {
220 return nil, fmt.Errorf("failed to parse auth request data: %w", err)
221 }
222
223 return &reqData, nil
224}
225
226// SaveAuthRequestInfo saves the auth request info.
227func (s *Store) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
228 dataBytes, err := json.Marshal(info)
229 if err != nil {
230 return fmt.Errorf("failed to serialize auth request data: %w", err)
231 }
232
233 // It shouldn't exist, but we can do an INSERT OR REPLACE just in case.
234 _, err = s.db.ExecContext(ctx, `
235 INSERT OR REPLACE INTO auth_requests (state, data, created_at)
236 VALUES (?, ?, CURRENT_TIMESTAMP)
237 `, info.State, string(dataBytes))
238 if err != nil {
239 return fmt.Errorf("failed to save auth request: %w", err)
240 }
241
242 // Trigger occasional cleanup of stale auth requests and sessions
243 if atomic.AddUint32(&s.cleanupCounter, 1)%100 == 0 {
244 go func() {
245 _ = s.Cleanup(context.Background())
246 }()
247 }
248
249 return nil
250}
251
252// DeleteAuthRequestInfo removes the auth request info after it's been used or expired.
253func (s *Store) DeleteAuthRequestInfo(ctx context.Context, state string) error {
254 _, err := s.db.ExecContext(ctx, "DELETE FROM auth_requests WHERE state = ?", state)
255 if err != nil {
256 return fmt.Errorf("failed to delete auth request: %w", err)
257 }
258 return nil
259}