Caddy module that requires AT Protocol (atproto) authentication and restricts routes to specific DIDs
1package db
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8
9 "github.com/bluesky-social/indigo/atproto/atcrypto"
10 "github.com/bluesky-social/indigo/atproto/auth/oauth"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 _ "github.com/mattn/go-sqlite3"
13)
14
// Compile-time check that *Store satisfies oauth.ClientAuthStore.
var _ oauth.ClientAuthStore = (*Store)(nil)
17
// Store handles SQLite persistence for the plugin. It wraps a single
// database/sql handle; its methods read and write the auth_requests,
// auth_sessions and system_keys tables.
type Store struct {
	// db is the database handle; NewStore opens it with the "sqlite3" driver.
	db *sql.DB
}
22
23// NewStore initializes a new SQLite-backed storage.
24func NewStore(path string) (*Store, error) {
25 // Enable WAL mode for better concurrency
26 db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL&_busy_timeout=5000")
27 if err != nil {
28 return nil, fmt.Errorf("failed to open database: %w", err)
29 }
30
31 if err := db.Ping(); err != nil {
32 return nil, fmt.Errorf("failed to ping database: %w", err)
33 }
34
35 store := &Store{db: db}
36 if err := store.initSchema(); err != nil {
37 return nil, fmt.Errorf("failed to initialize schema: %w", err)
38 }
39
40 return store, nil
41}
42
43func (s *Store) initSchema() error {
44 const schema = `
45 CREATE TABLE IF NOT EXISTS auth_requests (
46 state TEXT PRIMARY KEY,
47 data TEXT NOT NULL,
48 created_at DATETIME DEFAULT CURRENT_TIMESTAMP
49 );
50
51 CREATE TABLE IF NOT EXISTS auth_sessions (
52 did TEXT NOT NULL,
53 session_id TEXT NOT NULL,
54 data TEXT NOT NULL,
55 updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
56 PRIMARY KEY (did, session_id)
57 );
58
59 CREATE TABLE IF NOT EXISTS system_keys (
60 id TEXT PRIMARY KEY,
61 key_data BLOB NOT NULL,
62 created_at DATETIME DEFAULT CURRENT_TIMESTAMP
63 );
64 `
65 _, err := s.db.Exec(schema)
66 return err
67}
68
// Close closes the underlying database handle and returns the error, if any,
// reported by (*sql.DB).Close.
func (s *Store) Close() error {
	return s.db.Close()
}
73
74// GetClientKey retrieves the main client key, generating it if it doesn't exist.
75// Returns the private key and its ID (a hash of the public key or a random string).
76func (s *Store) GetClientKey(ctx context.Context) (atcrypto.PrivateKey, string, error) {
77 var keyData []byte
78 err := s.db.QueryRowContext(ctx, "SELECT key_data FROM system_keys WHERE id = 'client_key'").Scan(&keyData)
79 if err == sql.ErrNoRows {
80 // Generate a new P-256 key
81 pk, err := atcrypto.GeneratePrivateKeyP256()
82 if err != nil {
83 return nil, "", fmt.Errorf("failed to generate new client key: %w", err)
84 }
85
86 keyData = pk.Bytes()
87
88 _, err = s.db.ExecContext(ctx, "INSERT INTO system_keys (id, key_data) VALUES ('client_key', ?)", keyData)
89 if err != nil {
90 return nil, "", fmt.Errorf("failed to save generated client key: %w", err)
91 }
92
93 return pk, "client_key", nil
94 } else if err != nil {
95 return nil, "", fmt.Errorf("failed to load client key: %w", err)
96 }
97
98 pk, err := atcrypto.ParsePrivateBytesP256(keyData)
99 if err != nil {
100 return nil, "", fmt.Errorf("failed to parse existing client key: %w", err)
101 }
102
103 return pk, "client_key", nil
104}
105
106// GetLatestSession returns the most recently updated session for a DID.
107func (s *Store) GetLatestSession(ctx context.Context, did syntax.DID) (*oauth.ClientSessionData, error) {
108 var dataStr string
109 err := s.db.QueryRowContext(ctx, "SELECT data FROM auth_sessions WHERE did = ? ORDER BY updated_at DESC LIMIT 1", did.String()).Scan(&dataStr)
110 if err != nil {
111 if err == sql.ErrNoRows {
112 return nil, nil
113 }
114 return nil, fmt.Errorf("failed to query latest session: %w", err)
115 }
116
117 var sessionData oauth.ClientSessionData
118 if err := json.Unmarshal([]byte(dataStr), &sessionData); err != nil {
119 return nil, fmt.Errorf("failed to parse session data: %w", err)
120 }
121
122 return &sessionData, nil
123}
124
125// GetSession retrieves session data from the database.
126func (s *Store) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
127 var dataStr string
128 err := s.db.QueryRowContext(ctx, "SELECT data FROM auth_sessions WHERE did = ? AND session_id = ?", did.String(), sessionID).Scan(&dataStr)
129 if err != nil {
130 if err == sql.ErrNoRows {
131 // Some oauth methods in indigo might expect a specific error if not found, let's return it as nil/not found or check indigo docs.
132 // Indigo's memstore returns (nil, nil) or custom error. We'll return nil for the session and potentially an error if we need to.
133 // Currently indigo's oauth store interface doesn't strictly dictate `ErrNotFound`, but usually `nil, nil` or `nil, Err` is handled. Let's return nil, nil.
134 return nil, nil
135 }
136 return nil, fmt.Errorf("failed to query session: %w", err)
137 }
138
139 var sessionData oauth.ClientSessionData
140 if err := json.Unmarshal([]byte(dataStr), &sessionData); err != nil {
141 return nil, fmt.Errorf("failed to parse session data: %w", err)
142 }
143
144 return &sessionData, nil
145}
146
147// SaveSession saves session data into the database.
148func (s *Store) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
149 dataBytes, err := json.Marshal(sess)
150 if err != nil {
151 return fmt.Errorf("failed to serialize session data: %w", err)
152 }
153
154 _, err = s.db.ExecContext(ctx, `
155 INSERT INTO auth_sessions (did, session_id, data, updated_at)
156 VALUES (?, ?, ?, CURRENT_TIMESTAMP)
157 ON CONFLICT(did, session_id) DO UPDATE SET data = excluded.data, updated_at = CURRENT_TIMESTAMP
158 `, sess.AccountDID.String(), sess.SessionID, string(dataBytes))
159 if err != nil {
160 return fmt.Errorf("failed to save session: %w", err)
161 }
162
163 return nil
164}
165
166// DeleteSession removes session data from the database.
167func (s *Store) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
168 _, err := s.db.ExecContext(ctx, "DELETE FROM auth_sessions WHERE did = ? AND session_id = ?", did.String(), sessionID)
169 if err != nil {
170 return fmt.Errorf("failed to delete session: %w", err)
171 }
172 return nil
173}
174
175// GetAuthRequestInfo retrieves the auth request data by state.
176func (s *Store) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
177 var dataStr string
178 err := s.db.QueryRowContext(ctx, "SELECT data FROM auth_requests WHERE state = ?", state).Scan(&dataStr)
179 if err != nil {
180 if err == sql.ErrNoRows {
181 // AuthRequestData not found
182 return nil, fmt.Errorf("auth request info not found for state")
183 }
184 return nil, fmt.Errorf("failed to query auth request: %w", err)
185 }
186
187 var reqData oauth.AuthRequestData
188 if err := json.Unmarshal([]byte(dataStr), &reqData); err != nil {
189 return nil, fmt.Errorf("failed to parse auth request data: %w", err)
190 }
191
192 return &reqData, nil
193}
194
195// SaveAuthRequestInfo saves the auth request info.
196func (s *Store) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
197 dataBytes, err := json.Marshal(info)
198 if err != nil {
199 return fmt.Errorf("failed to serialize auth request data: %w", err)
200 }
201
202 // Creating is fine. It shouldn't exist, but we can do an INSERT OR REPLACE just in case.
203 _, err = s.db.ExecContext(ctx, `
204 INSERT OR REPLACE INTO auth_requests (state, data, created_at)
205 VALUES (?, ?, CURRENT_TIMESTAMP)
206 `, info.State, string(dataBytes))
207 if err != nil {
208 return fmt.Errorf("failed to save auth request: %w", err)
209 }
210
211 return nil
212}
213
214// DeleteAuthRequestInfo removes the auth request info after it's been used or expired.
215func (s *Store) DeleteAuthRequestInfo(ctx context.Context, state string) error {
216 _, err := s.db.ExecContext(ctx, "DELETE FROM auth_requests WHERE state = ?", state)
217 if err != nil {
218 return fmt.Errorf("failed to delete auth request: %w", err)
219 }
220 return nil
221}