Monorepo for Tangled
tangled.org
1package db
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "strings"
9 "time"
10
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 "tangled.org/core/appview/models"
13 "tangled.org/core/appview/pagination"
14 "tangled.org/core/orm"
15)
16
17func CreateNotification(e Execer, notification *models.Notification) error {
18 query := `
19 INSERT INTO notifications (recipient_did, actor_did, type, entity_type, entity_id, read, repo_id, issue_id, pull_id)
20 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
21 `
22
23 result, err := e.Exec(query,
24 notification.RecipientDid,
25 notification.ActorDid,
26 string(notification.Type),
27 notification.EntityType,
28 notification.EntityId,
29 notification.Read,
30 notification.RepoId,
31 notification.IssueId,
32 notification.PullId,
33 )
34 if err != nil {
35 return fmt.Errorf("failed to create notification: %w", err)
36 }
37
38 id, err := result.LastInsertId()
39 if err != nil {
40 return fmt.Errorf("failed to get notification ID: %w", err)
41 }
42
43 notification.ID = id
44 return nil
45}
46
47// GetNotificationsPaginated retrieves notifications with filters and pagination
48func GetNotificationsPaginated(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.Notification, error) {
49 var conditions []string
50 var args []any
51
52 for _, filter := range filters {
53 conditions = append(conditions, filter.Condition())
54 args = append(args, filter.Arg()...)
55 }
56
57 whereClause := ""
58 if len(conditions) > 0 {
59 whereClause = "WHERE " + conditions[0]
60 for _, condition := range conditions[1:] {
61 whereClause += " AND " + condition
62 }
63 }
64 pageClause := ""
65 if page.Limit > 0 {
66 pageClause = " limit ? offset ? "
67 args = append(args, page.Limit, page.Offset)
68 }
69
70 query := fmt.Sprintf(`
71 select id, recipient_did, actor_did, type, entity_type, entity_id, read, created, repo_id, issue_id, pull_id
72 from notifications
73 %s
74 order by created desc
75 %s
76 `, whereClause, pageClause)
77
78 rows, err := e.QueryContext(context.Background(), query, args...)
79 if err != nil {
80 return nil, fmt.Errorf("failed to query notifications: %w", err)
81 }
82 defer rows.Close()
83
84 var notifications []*models.Notification
85 for rows.Next() {
86 var n models.Notification
87 var typeStr string
88 var createdStr string
89 err := rows.Scan(
90 &n.ID,
91 &n.RecipientDid,
92 &n.ActorDid,
93 &typeStr,
94 &n.EntityType,
95 &n.EntityId,
96 &n.Read,
97 &createdStr,
98 &n.RepoId,
99 &n.IssueId,
100 &n.PullId,
101 )
102 if err != nil {
103 return nil, fmt.Errorf("failed to scan notification: %w", err)
104 }
105 n.Type = models.NotificationType(typeStr)
106 n.Created, err = time.Parse(time.RFC3339, createdStr)
107 if err != nil {
108 return nil, fmt.Errorf("failed to parse created timestamp: %w", err)
109 }
110 notifications = append(notifications, &n)
111 }
112
113 return notifications, nil
114}
115
116func GetNotificationWithEntity(e Execer, notificationID int64, userDID string) (*models.NotificationWithEntity, error) {
117 results, err := GetNotificationsWithEntities(e, pagination.Page{Limit: 1, Offset: 0},
118 orm.FilterEq("n.id", notificationID),
119 orm.FilterEq("n.recipient_did", userDID),
120 )
121 if err != nil {
122 return nil, err
123 }
124 if len(results) == 0 {
125 return nil, fmt.Errorf("notification not found")
126 }
127 return results[0], nil
128}
129
130// GetNotificationsWithEntities retrieves notifications with their related entities
131func GetNotificationsWithEntities(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.NotificationWithEntity, error) {
132 var conditions []string
133 var args []any
134
135 for _, filter := range filters {
136 conditions = append(conditions, filter.Condition())
137 args = append(args, filter.Arg()...)
138 }
139
140 whereClause := ""
141 if len(conditions) > 0 {
142 whereClause = "WHERE " + conditions[0]
143 for _, condition := range conditions[1:] {
144 whereClause += " AND " + condition
145 }
146 }
147
148 query := fmt.Sprintf(`
149 select
150 n.id, n.recipient_did, n.actor_did, n.type, n.entity_type, n.entity_id,
151 n.read, n.created, n.repo_id, n.issue_id, n.pull_id,
152 r.id as r_id, r.did as r_did, r.rkey as r_rkey, r.name as r_name, r.description as r_description, r.website as r_website, r.topics as r_topics,
153 i.id as i_id, i.did as i_did, i.issue_id as i_issue_id, i.title as i_title, i.open as i_open,
154 p.id as p_id, p.owner_did as p_owner_did, p.pull_id as p_pull_id, p.title as p_title, p.state as p_state
155 from notifications n
156 left join repos r on n.repo_id = r.id
157 left join issues i on n.issue_id = i.id
158 left join pulls p on n.pull_id = p.id
159 %s
160 order by n.created desc
161 limit ? offset ?
162 `, whereClause)
163
164 args = append(args, page.Limit, page.Offset)
165
166 rows, err := e.QueryContext(context.Background(), query, args...)
167 if err != nil {
168 return nil, fmt.Errorf("failed to query notifications with entities: %w", err)
169 }
170 defer rows.Close()
171
172 var notifications []*models.NotificationWithEntity
173 for rows.Next() {
174 var n models.Notification
175 var typeStr string
176 var createdStr string
177 var repo models.Repo
178 var issue models.Issue
179 var pull models.Pull
180 var rId, iId, pId sql.NullInt64
181 var rDid, rRkey, rName, rDescription, rWebsite, rTopicStr sql.NullString
182 var iDid sql.NullString
183 var iIssueId sql.NullInt64
184 var iTitle sql.NullString
185 var iOpen sql.NullBool
186 var pOwnerDid sql.NullString
187 var pPullId sql.NullInt64
188 var pTitle sql.NullString
189 var pState sql.NullInt64
190
191 err := rows.Scan(
192 &n.ID, &n.RecipientDid, &n.ActorDid, &typeStr, &n.EntityType, &n.EntityId,
193 &n.Read, &createdStr, &n.RepoId, &n.IssueId, &n.PullId,
194 &rId, &rDid, &rRkey, &rName, &rDescription, &rWebsite, &rTopicStr,
195 &iId, &iDid, &iIssueId, &iTitle, &iOpen,
196 &pId, &pOwnerDid, &pPullId, &pTitle, &pState,
197 )
198 if err != nil {
199 return nil, fmt.Errorf("failed to scan notification with entities: %w", err)
200 }
201
202 n.Type = models.NotificationType(typeStr)
203 n.Created, err = time.Parse(time.RFC3339, createdStr)
204 if err != nil {
205 return nil, fmt.Errorf("failed to parse created timestamp: %w", err)
206 }
207
208 entry := &models.NotificationWithEntity{Notification: &n}
209
210 // populate repo if present
211 if rId.Valid {
212 repo.Id = rId.Int64
213 if rDid.Valid {
214 repo.Did = rDid.String
215 }
216 if rRkey.Valid {
217 repo.Rkey = rRkey.String
218 }
219 if rName.Valid {
220 repo.Name = rName.String
221 }
222 if rDescription.Valid {
223 repo.Description = rDescription.String
224 }
225 if rWebsite.Valid {
226 repo.Website = rWebsite.String
227 }
228 if rTopicStr.Valid {
229 repo.Topics = strings.Fields(rTopicStr.String)
230 }
231 entry.Repo = &repo
232 }
233
234 // populate issue if present
235 if iId.Valid {
236 issue.Id = iId.Int64
237 if iDid.Valid {
238 issue.Did = iDid.String
239 }
240 if iIssueId.Valid {
241 issue.IssueId = int(iIssueId.Int64)
242 }
243 if iTitle.Valid {
244 issue.Title = iTitle.String
245 }
246 if iOpen.Valid {
247 issue.Open = iOpen.Bool
248 }
249 entry.Issue = &issue
250 }
251
252 // populate pull if present
253 if pId.Valid {
254 pull.ID = int(pId.Int64)
255 if pOwnerDid.Valid {
256 pull.OwnerDid = pOwnerDid.String
257 }
258 if pPullId.Valid {
259 pull.PullId = int(pPullId.Int64)
260 }
261 if pTitle.Valid {
262 pull.Title = pTitle.String
263 }
264 if pState.Valid {
265 pull.State = models.PullState(pState.Int64)
266 }
267 entry.Pull = &pull
268 }
269
270 notifications = append(notifications, entry)
271 }
272
273 return notifications, nil
274}
275
276// GetNotifications retrieves notifications with filters
277func GetNotifications(e Execer, filters ...orm.Filter) ([]*models.Notification, error) {
278 return GetNotificationsPaginated(e, pagination.FirstPage(), filters...)
279}
280
281func CountNotifications(e Execer, filters ...orm.Filter) (int64, error) {
282 var conditions []string
283 var args []any
284 for _, filter := range filters {
285 conditions = append(conditions, filter.Condition())
286 args = append(args, filter.Arg()...)
287 }
288
289 whereClause := ""
290 if conditions != nil {
291 whereClause = " where " + strings.Join(conditions, " and ")
292 }
293
294 query := fmt.Sprintf(`select count(1) from notifications %s`, whereClause)
295 var count int64
296 err := e.QueryRow(query, args...).Scan(&count)
297
298 if !errors.Is(err, sql.ErrNoRows) && err != nil {
299 return 0, err
300 }
301
302 return count, nil
303}
304
305func MarkNotificationRead(e Execer, notificationID int64, userDID string) error {
306 idFilter := orm.FilterEq("id", notificationID)
307 recipientFilter := orm.FilterEq("recipient_did", userDID)
308
309 query := fmt.Sprintf(`
310 UPDATE notifications
311 SET read = 1
312 WHERE %s AND %s
313 `, idFilter.Condition(), recipientFilter.Condition())
314
315 args := append(idFilter.Arg(), recipientFilter.Arg()...)
316
317 result, err := e.Exec(query, args...)
318 if err != nil {
319 return fmt.Errorf("failed to mark notification as read: %w", err)
320 }
321
322 rowsAffected, err := result.RowsAffected()
323 if err != nil {
324 return fmt.Errorf("failed to get rows affected: %w", err)
325 }
326
327 if rowsAffected == 0 {
328 return fmt.Errorf("notification not found or access denied")
329 }
330
331 return nil
332}
333
334func MarkNotificationsReadForIssue(e Execer, userDID, repoDid string, issueNum int) error {
335 query := `
336 update notifications set read = 1
337 where recipient_did = ?
338 and read = 0
339 and issue_id = (select id from issues where repo_did = ? and issue_id = ?)
340 `
341 _, err := e.Exec(query, userDID, repoDid, issueNum)
342 return err
343}
344
345func MarkNotificationsReadForPull(e Execer, userDID, repoDid string, pullNum int) error {
346 query := `
347 update notifications set read = 1
348 where recipient_did = ?
349 and read = 0
350 and pull_id = (select p.id from pulls p where p.pull_id = ? and p.repo_did = ?)
351 `
352 _, err := e.Exec(query, userDID, pullNum, repoDid)
353 return err
354}
355
356func MarkNotificationUnread(e Execer, notificationID int64, userDID string) error {
357 idFilter := orm.FilterEq("id", notificationID)
358 recipientFilter := orm.FilterEq("recipient_did", userDID)
359
360 query := fmt.Sprintf(`
361 UPDATE notifications
362 SET read = 0
363 WHERE %s AND %s
364 `, idFilter.Condition(), recipientFilter.Condition())
365
366 args := append(idFilter.Arg(), recipientFilter.Arg()...)
367
368 result, err := e.Exec(query, args...)
369 if err != nil {
370 return fmt.Errorf("failed to mark notification as unread: %w", err)
371 }
372
373 rowsAffected, err := result.RowsAffected()
374 if err != nil {
375 return fmt.Errorf("failed to get rows affected: %w", err)
376 }
377
378 if rowsAffected == 0 {
379 return fmt.Errorf("notification not found or access denied")
380 }
381
382 return nil
383}
384
385func MarkAllNotificationsRead(e Execer, userDID string) error {
386 recipientFilter := orm.FilterEq("recipient_did", userDID)
387 readFilter := orm.FilterEq("read", 0)
388
389 query := fmt.Sprintf(`
390 UPDATE notifications
391 SET read = 1
392 WHERE %s AND %s
393 `, recipientFilter.Condition(), readFilter.Condition())
394
395 args := append(recipientFilter.Arg(), readFilter.Arg()...)
396
397 _, err := e.Exec(query, args...)
398 if err != nil {
399 return fmt.Errorf("failed to mark all notifications as read: %w", err)
400 }
401
402 return nil
403}
404
405func DeleteNotification(e Execer, notificationID int64, userDID string) error {
406 idFilter := orm.FilterEq("id", notificationID)
407 recipientFilter := orm.FilterEq("recipient_did", userDID)
408
409 query := fmt.Sprintf(`
410 DELETE FROM notifications
411 WHERE %s AND %s
412 `, idFilter.Condition(), recipientFilter.Condition())
413
414 args := append(idFilter.Arg(), recipientFilter.Arg()...)
415
416 result, err := e.Exec(query, args...)
417 if err != nil {
418 return fmt.Errorf("failed to delete notification: %w", err)
419 }
420
421 rowsAffected, err := result.RowsAffected()
422 if err != nil {
423 return fmt.Errorf("failed to get rows affected: %w", err)
424 }
425
426 if rowsAffected == 0 {
427 return fmt.Errorf("notification not found or access denied")
428 }
429
430 return nil
431}
432
433func GetNotificationPreference(e Execer, userDid string) (*models.NotificationPreferences, error) {
434 prefs, err := GetNotificationPreferences(e, orm.FilterEq("user_did", userDid))
435 if err != nil {
436 return nil, err
437 }
438
439 p, ok := prefs[syntax.DID(userDid)]
440 if !ok {
441 return models.DefaultNotificationPreferences(syntax.DID(userDid)), nil
442 }
443
444 return p, nil
445}
446
447func GetNotificationPreferences(e Execer, filters ...orm.Filter) (map[syntax.DID]*models.NotificationPreferences, error) {
448 prefsMap := make(map[syntax.DID]*models.NotificationPreferences)
449
450 var conditions []string
451 var args []any
452 for _, filter := range filters {
453 conditions = append(conditions, filter.Condition())
454 args = append(args, filter.Arg()...)
455 }
456
457 whereClause := ""
458 if conditions != nil {
459 whereClause = " where " + strings.Join(conditions, " and ")
460 }
461
462 query := fmt.Sprintf(`
463 select
464 id,
465 user_did,
466 repo_starred,
467 issue_created,
468 issue_commented,
469 pull_created,
470 pull_commented,
471 followed,
472 user_mentioned,
473 pull_merged,
474 issue_closed,
475 email_notifications
476 from
477 notification_preferences
478 %s
479 `, whereClause)
480
481 rows, err := e.Query(query, args...)
482 if err != nil {
483 return nil, err
484 }
485 defer rows.Close()
486
487 for rows.Next() {
488 var prefs models.NotificationPreferences
489 if err := rows.Scan(
490 &prefs.ID,
491 &prefs.UserDid,
492 &prefs.RepoStarred,
493 &prefs.IssueCreated,
494 &prefs.IssueCommented,
495 &prefs.PullCreated,
496 &prefs.PullCommented,
497 &prefs.Followed,
498 &prefs.UserMentioned,
499 &prefs.PullMerged,
500 &prefs.IssueClosed,
501 &prefs.EmailNotifications,
502 ); err != nil {
503 return nil, err
504 }
505
506 prefsMap[prefs.UserDid] = &prefs
507 }
508
509 if err := rows.Err(); err != nil {
510 return nil, err
511 }
512
513 return prefsMap, nil
514}
515
516func (d *DB) UpdateNotificationPreferences(ctx context.Context, prefs *models.NotificationPreferences) error {
517 query := `
518 INSERT OR REPLACE INTO notification_preferences
519 (user_did, repo_starred, issue_created, issue_commented, pull_created,
520 pull_commented, followed, user_mentioned, pull_merged, issue_closed,
521 email_notifications)
522 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
523 `
524
525 result, err := d.DB.ExecContext(ctx, query,
526 prefs.UserDid,
527 prefs.RepoStarred,
528 prefs.IssueCreated,
529 prefs.IssueCommented,
530 prefs.PullCreated,
531 prefs.PullCommented,
532 prefs.Followed,
533 prefs.UserMentioned,
534 prefs.PullMerged,
535 prefs.IssueClosed,
536 prefs.EmailNotifications,
537 )
538 if err != nil {
539 return fmt.Errorf("failed to update notification preferences: %w", err)
540 }
541
542 if prefs.ID == 0 {
543 id, err := result.LastInsertId()
544 if err != nil {
545 return fmt.Errorf("failed to get preferences ID: %w", err)
546 }
547 prefs.ID = id
548 }
549
550 return nil
551}
552
553func (d *DB) ClearOldNotifications(ctx context.Context, olderThan time.Duration) error {
554 cutoff := time.Now().Add(-olderThan)
555 createdFilter := orm.FilterLte("created", cutoff)
556
557 query := fmt.Sprintf(`
558 DELETE FROM notifications
559 WHERE %s
560 `, createdFilter.Condition())
561
562 _, err := d.DB.ExecContext(ctx, query, createdFilter.Arg()...)
563 if err != nil {
564 return fmt.Errorf("failed to cleanup old notifications: %w", err)
565 }
566
567 return nil
568}