Monorepo for Tangled
tangled.org
1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 lexutil "github.com/bluesky-social/indigo/lex/util"
16 "github.com/ipfs/go-cid"
17 "tangled.org/core/appview/models"
18 "tangled.org/core/appview/pagination"
19 "tangled.org/core/orm"
20 "tangled.org/core/sets"
21)
22
23func comparePullSource(existing, new *models.PullSource) bool {
24 if existing == nil && new == nil {
25 return true
26 }
27 if existing == nil || new == nil {
28 return false
29 }
30 if existing.Branch != new.Branch {
31 return false
32 }
33 if existing.RepoDid == nil && new.RepoDid == nil {
34 return true
35 }
36 if existing.RepoDid == nil || new.RepoDid == nil {
37 return false
38 }
39 return *existing.RepoDid == *new.RepoDid
40}
41
42func compareSubmissions(existing, new []*models.PullSubmission) bool {
43 if len(existing) != len(new) {
44 return false
45 }
46 for i := range existing {
47 if existing[i].Blob.Ref.String() != new[i].Blob.Ref.String() {
48 return false
49 }
50 if existing[i].Blob.MimeType != new[i].Blob.MimeType {
51 return false
52 }
53 if existing[i].Blob.Size != new[i].Blob.Size {
54 return false
55 }
56 }
57 return true
58}
59
60func PutPull(tx *sql.Tx, pull *models.Pull) error {
61 // ensure sequence exists
62 _, err := tx.Exec(`
63 insert or ignore into repo_pull_seqs (repo_did, next_pull_id)
64 values (?, 1)
65 `, pull.RepoDid)
66 if err != nil {
67 return err
68 }
69
70 pulls, err := GetPulls(
71 tx,
72 orm.FilterEq("owner_did", pull.OwnerDid),
73 orm.FilterEq("rkey", pull.Rkey),
74 )
75 switch {
76 case err != nil:
77 return err
78 case len(pulls) == 0:
79 return createNewPull(tx, pull)
80 case len(pulls) != 1: // should be unreachable
81 return fmt.Errorf("invalid number of pulls returned: %d", len(pulls))
82 default:
83 existingPull := pulls[0]
84 if existingPull.State == models.PullMerged {
85 return nil
86 }
87
88 dependentOnEqual := (existingPull.DependentOn == nil && pull.DependentOn == nil) ||
89 (existingPull.DependentOn != nil && pull.DependentOn != nil && *existingPull.DependentOn == *pull.DependentOn)
90
91 pullSourceEqual := comparePullSource(existingPull.PullSource, pull.PullSource)
92 submissionsEqual := compareSubmissions(existingPull.Submissions, pull.Submissions)
93
94 if existingPull.Title == pull.Title &&
95 existingPull.Body == pull.Body &&
96 existingPull.TargetBranch == pull.TargetBranch &&
97 existingPull.RepoDid == pull.RepoDid &&
98 dependentOnEqual &&
99 pullSourceEqual &&
100 submissionsEqual {
101 return nil
102 }
103
104 isLonger := len(existingPull.Submissions) < len(pull.Submissions)
105 if isLonger {
106 isAppendOnly := compareSubmissions(existingPull.Submissions, pull.Submissions[:len(existingPull.Submissions)])
107 if !isAppendOnly {
108 return fmt.Errorf("the new pull does not treat submissions as append-only")
109 }
110 } else if !submissionsEqual {
111 return fmt.Errorf("the new pull does not treat submissions as append-only")
112 }
113
114 pull.ID = existingPull.ID
115 pull.PullId = existingPull.PullId
116 return updatePull(tx, pull, existingPull)
117 }
118}
119
120func createNewPull(tx *sql.Tx, pull *models.Pull) error {
121 _, err := tx.Exec(`
122 insert or ignore into repo_pull_seqs (repo_did, next_pull_id)
123 values (?, 1)
124 `, pull.RepoDid)
125 if err != nil {
126 return err
127 }
128
129 var nextId int
130 err = tx.QueryRow(`
131 update repo_pull_seqs
132 set next_pull_id = next_pull_id + 1
133 where repo_did = ?
134 returning next_pull_id - 1
135 `, pull.RepoDid).Scan(&nextId)
136 if err != nil {
137 return err
138 }
139
140 pull.PullId = nextId
141 pull.State = models.PullOpen
142
143 var sourceBranch, sourceRepoDid *string
144 if pull.PullSource != nil {
145 sourceBranch = &pull.PullSource.Branch
146 if pull.PullSource.RepoDid != nil {
147 x := string(*pull.PullSource.RepoDid)
148 sourceRepoDid = &x
149 }
150 }
151
152 result, err := tx.Exec(
153 `
154 insert into pulls (
155 repo_did,
156 owner_did,
157 pull_id,
158 title,
159 target_branch,
160 body,
161 rkey,
162 state,
163 dependent_on,
164 source_branch,
165 source_repo_did
166 )
167 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
168 pull.RepoDid,
169 pull.OwnerDid,
170 pull.PullId,
171 pull.Title,
172 pull.TargetBranch,
173 pull.Body,
174 pull.Rkey,
175 pull.State,
176 pull.DependentOn,
177 sourceBranch,
178 sourceRepoDid,
179 )
180 if err != nil {
181 return err
182 }
183
184 // Set the database primary key ID
185 id, err := result.LastInsertId()
186 if err != nil {
187 return err
188 }
189 pull.ID = int(id)
190
191 for i, s := range pull.Submissions {
192 _, err = tx.Exec(`
193 insert into pull_submissions (
194 pull_at,
195 round_number,
196 patch,
197 combined,
198 source_rev,
199 patch_blob_ref,
200 patch_blob_mime,
201 patch_blob_size
202 )
203 values (?, ?, ?, ?, ?, ?, ?, ?)
204 `,
205 pull.AtUri(),
206 i,
207 s.Patch,
208 s.Combined,
209 s.SourceRev,
210 s.Blob.Ref.String(),
211 s.Blob.MimeType,
212 s.Blob.Size,
213 )
214 if err != nil {
215 return err
216 }
217 }
218
219 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
220 return fmt.Errorf("put reference_links: %w", err)
221 }
222
223 return nil
224}
225
226func updatePull(tx *sql.Tx, pull *models.Pull, existingPull *models.Pull) error {
227 var sourceBranch, sourceRepoDid *string
228 if pull.PullSource != nil {
229 sourceBranch = &pull.PullSource.Branch
230 if pull.PullSource.RepoDid != nil {
231 x := string(*pull.PullSource.RepoDid)
232 sourceRepoDid = &x
233 }
234 }
235
236 _, err := tx.Exec(`
237 update pulls set
238 title = ?,
239 body = ?,
240 target_branch = ?,
241 dependent_on = ?,
242 source_branch = ?,
243 source_repo_did = ?
244 where owner_did = ? and rkey = ?
245 `, pull.Title, pull.Body, pull.TargetBranch, pull.DependentOn, sourceBranch, sourceRepoDid, pull.OwnerDid, pull.Rkey)
246 if err != nil {
247 return err
248 }
249
250 // insert new submissions (append-only)
251 for i := len(existingPull.Submissions); i < len(pull.Submissions); i++ {
252 s := pull.Submissions[i]
253 _, err = tx.Exec(`
254 insert into pull_submissions (
255 pull_at,
256 round_number,
257 patch,
258 combined,
259 source_rev,
260 patch_blob_ref,
261 patch_blob_mime,
262 patch_blob_size
263 )
264 values (?, ?, ?, ?, ?, ?, ?, ?)
265 `,
266 pull.AtUri(),
267 i,
268 s.Patch,
269 s.Combined,
270 s.SourceRev,
271 s.Blob.Ref.String(),
272 s.Blob.MimeType,
273 s.Blob.Size,
274 )
275 if err != nil {
276 return err
277 }
278 }
279
280 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
281 return fmt.Errorf("put reference_links: %w", err)
282 }
283 return nil
284}
285
286func NextPullId(e Execer, repoDid string) (int, error) {
287 var pullId int
288 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_did = ?`, repoDid).Scan(&pullId)
289 return pullId - 1, err
290}
291
292func GetPullsPaginated(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.Pull, error) {
293 pulls := make(map[syntax.ATURI]*models.Pull)
294
295 var conditions []string
296 var args []any
297 for _, filter := range filters {
298 conditions = append(conditions, filter.Condition())
299 args = append(args, filter.Arg()...)
300 }
301
302 whereClause := ""
303 if conditions != nil {
304 whereClause = " where " + strings.Join(conditions, " and ")
305 }
306 pageClause := ""
307 if page.Limit != 0 {
308 pageClause = fmt.Sprintf(
309 " limit %d offset %d ",
310 page.Limit,
311 page.Offset,
312 )
313 }
314
315 query := fmt.Sprintf(`
316 select
317 id,
318 owner_did,
319 repo_did,
320 pull_id,
321 created,
322 title,
323 state,
324 target_branch,
325 body,
326 rkey,
327 source_branch,
328 source_repo_did,
329 dependent_on
330 from
331 pulls
332 %s
333 order by
334 created desc
335 %s
336 `, whereClause, pageClause)
337
338 rows, err := e.Query(query, args...)
339 if err != nil {
340 return nil, err
341 }
342 defer rows.Close()
343
344 for rows.Next() {
345 var pull models.Pull
346 var createdAt string
347 var sourceBranch, sourceRepoDid, dependentOn sql.NullString
348 err := rows.Scan(
349 &pull.ID,
350 &pull.OwnerDid,
351 &pull.RepoDid,
352 &pull.PullId,
353 &createdAt,
354 &pull.Title,
355 &pull.State,
356 &pull.TargetBranch,
357 &pull.Body,
358 &pull.Rkey,
359 &sourceBranch,
360 &sourceRepoDid,
361 &dependentOn,
362 )
363 if err != nil {
364 return nil, err
365 }
366
367 createdTime, err := time.Parse(time.RFC3339, createdAt)
368 if err != nil {
369 return nil, err
370 }
371 pull.Created = createdTime
372
373 if sourceBranch.Valid {
374 pull.PullSource = &models.PullSource{
375 Branch: sourceBranch.String,
376 }
377 if sourceRepoDid.Valid {
378 sourceRepoDidParsed, err := syntax.ParseDID(sourceRepoDid.String)
379 if err != nil {
380 return nil, err
381 }
382 pull.PullSource.RepoDid = &sourceRepoDidParsed
383 }
384 }
385
386 if dependentOn.Valid {
387 x := syntax.ATURI(dependentOn.String)
388 pull.DependentOn = &x
389 }
390
391 pulls[pull.AtUri()] = &pull
392 }
393
394 var pullAts []syntax.ATURI
395 for _, p := range pulls {
396 pullAts = append(pullAts, p.AtUri())
397 }
398 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts))
399 if err != nil {
400 return nil, fmt.Errorf("failed to get submissions: %w", err)
401 }
402
403 for pullAt, submissions := range submissionsMap {
404 if p, ok := pulls[pullAt]; ok {
405 p.Submissions = submissions
406 }
407 }
408
409 // collect allLabels for each issue
410 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts))
411 if err != nil {
412 return nil, fmt.Errorf("failed to query labels: %w", err)
413 }
414 for pullAt, labels := range allLabels {
415 if p, ok := pulls[pullAt]; ok {
416 p.Labels = labels
417 }
418 }
419
420 // build up reverse mappings: p.Repo and p.PullSource.Repo
421 var repoDids []syntax.DID
422 for _, p := range pulls {
423 repoDids = append(repoDids, p.RepoDid)
424 if p.PullSource != nil && p.PullSource.RepoDid != nil {
425 repoDids = append(repoDids, *p.PullSource.RepoDid)
426 }
427 }
428
429 repos, err := GetRepos(e, orm.FilterIn("repo_did", repoDids))
430 if err != nil && !errors.Is(err, sql.ErrNoRows) {
431 return nil, fmt.Errorf("failed to get repos: %w", err)
432 }
433
434 repoMap := make(map[syntax.DID]*models.Repo)
435 for _, r := range repos {
436 repoMap[syntax.DID(r.RepoDid)] = &r
437 }
438
439 for _, p := range pulls {
440 if repo, ok := repoMap[p.RepoDid]; ok {
441 p.Repo = repo
442 }
443 if p.PullSource != nil && p.PullSource.RepoDid != nil {
444 if sourceRepo, ok := repoMap[*p.PullSource.RepoDid]; ok {
445 p.PullSource.Repo = sourceRepo
446 }
447 }
448 }
449
450 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts))
451 if err != nil {
452 return nil, fmt.Errorf("failed to query reference_links: %w", err)
453 }
454 for pullAt, references := range allReferences {
455 if pull, ok := pulls[pullAt]; ok {
456 pull.References = references
457 }
458 }
459
460 orderedByPullId := []*models.Pull{}
461 for _, p := range pulls {
462 orderedByPullId = append(orderedByPullId, p)
463 }
464 sort.Slice(orderedByPullId, func(i, j int) bool {
465 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
466 })
467
468 return orderedByPullId, nil
469}
470
471func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) {
472 return GetPullsPaginated(e, pagination.Page{}, filters...)
473}
474
475func GetPull(e Execer, filters ...orm.Filter) (*models.Pull, error) {
476 pulls, err := GetPullsPaginated(e, pagination.Page{Limit: 1}, filters...)
477 if err != nil {
478 return nil, err
479 }
480 if len(pulls) == 0 {
481 return nil, sql.ErrNoRows
482 }
483
484 return pulls[0], nil
485}
486
487// mapping from pull -> pull submissions
488func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
489 var conditions []string
490 var args []any
491 for _, filter := range filters {
492 conditions = append(conditions, filter.Condition())
493 args = append(args, filter.Arg()...)
494 }
495
496 whereClause := ""
497 if conditions != nil {
498 whereClause = " where " + strings.Join(conditions, " and ")
499 }
500
501 query := fmt.Sprintf(`
502 select
503 id,
504 pull_at,
505 round_number,
506 patch,
507 combined,
508 created,
509 source_rev,
510 patch_blob_ref,
511 patch_blob_mime,
512 patch_blob_size
513 from
514 pull_submissions
515 %s
516 order by
517 round_number asc
518 `, whereClause)
519
520 rows, err := e.Query(query, args...)
521 if err != nil {
522 return nil, err
523 }
524 defer rows.Close()
525
526 submissionMap := make(map[int]*models.PullSubmission)
527
528 for rows.Next() {
529 var submission models.PullSubmission
530 var submissionCreatedStr string
531 var submissionSourceRev, submissionCombined sql.Null[string]
532 var patchBlobRef, patchBlobMime sql.Null[string]
533 var patchBlobSize sql.Null[int64]
534 err := rows.Scan(
535 &submission.ID,
536 &submission.PullAt,
537 &submission.RoundNumber,
538 &submission.Patch,
539 &submissionCombined,
540 &submissionCreatedStr,
541 &submissionSourceRev,
542 &patchBlobRef,
543 &patchBlobMime,
544 &patchBlobSize,
545 )
546 if err != nil {
547 return nil, err
548 }
549
550 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
551 submission.Created = t
552 }
553
554 if submissionSourceRev.Valid {
555 submission.SourceRev = submissionSourceRev.V
556 }
557
558 if submissionCombined.Valid {
559 submission.Combined = submissionCombined.V
560 }
561
562 if patchBlobRef.Valid {
563 submission.Blob.Ref = lexutil.LexLink(cid.MustParse(patchBlobRef.V))
564 }
565
566 if patchBlobMime.Valid {
567 submission.Blob.MimeType = patchBlobMime.V
568 }
569
570 if patchBlobSize.Valid {
571 submission.Blob.Size = patchBlobSize.V
572 }
573
574 submissionMap[submission.ID] = &submission
575 }
576
577 if err := rows.Err(); err != nil {
578 return nil, err
579 }
580
581 // Get comments for all submissions using GetPullComments
582 submissionIds := slices.Collect(maps.Keys(submissionMap))
583 comments, err := GetPullComments(e, orm.FilterIn("submission_id", submissionIds))
584 if err != nil {
585 return nil, fmt.Errorf("failed to get pull comments: %w", err)
586 }
587 for _, comment := range comments {
588 if submission, ok := submissionMap[comment.SubmissionId]; ok {
589 submission.Comments = append(submission.Comments, comment)
590 }
591 }
592
593 // group the submissions by pull_at
594 m := make(map[syntax.ATURI][]*models.PullSubmission)
595 for _, s := range submissionMap {
596 m[s.PullAt] = append(m[s.PullAt], s)
597 }
598
599 // sort each one by round number
600 for _, s := range m {
601 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
602 return cmp.Compare(a.RoundNumber, b.RoundNumber)
603 })
604 }
605
606 return m, nil
607}
608
609func GetPullComments(e Execer, filters ...orm.Filter) ([]models.PullComment, error) {
610 var conditions []string
611 var args []any
612 for _, filter := range filters {
613 conditions = append(conditions, filter.Condition())
614 args = append(args, filter.Arg()...)
615 }
616
617 whereClause := ""
618 if conditions != nil {
619 whereClause = " where " + strings.Join(conditions, " and ")
620 }
621
622 query := fmt.Sprintf(`
623 select
624 id,
625 pull_id,
626 submission_id,
627 repo_did,
628 owner_did,
629 comment_at,
630 body,
631 created
632 from
633 pull_comments
634 %s
635 order by
636 created asc
637 `, whereClause)
638
639 rows, err := e.Query(query, args...)
640 if err != nil {
641 return nil, err
642 }
643 defer rows.Close()
644
645 commentMap := make(map[string]*models.PullComment)
646 for rows.Next() {
647 var comment models.PullComment
648 var createdAt string
649 err := rows.Scan(
650 &comment.ID,
651 &comment.PullId,
652 &comment.SubmissionId,
653 &comment.RepoDid,
654 &comment.OwnerDid,
655 &comment.CommentAt,
656 &comment.Body,
657 &createdAt,
658 )
659 if err != nil {
660 return nil, err
661 }
662
663 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
664 comment.Created = t
665 }
666
667 atUri := comment.AtUri().String()
668 commentMap[atUri] = &comment
669 }
670
671 if err := rows.Err(); err != nil {
672 return nil, err
673 }
674
675 // collect references for each comments
676 commentAts := slices.Collect(maps.Keys(commentMap))
677 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", commentAts))
678 if err != nil {
679 return nil, fmt.Errorf("failed to query reference_links: %w", err)
680 }
681 for commentAt, references := range allReferences {
682 if comment, ok := commentMap[commentAt.String()]; ok {
683 comment.References = references
684 }
685 }
686
687 var comments []models.PullComment
688 for _, c := range commentMap {
689 comments = append(comments, *c)
690 }
691
692 sort.Slice(comments, func(i, j int) bool {
693 return comments[i].Created.Before(comments[j].Created)
694 })
695
696 return comments, nil
697}
698
699// timeframe here is directly passed into the sql query filter, and any
700// timeframe in the past should be negative; e.g.: "-3 months"
701func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
702 var pulls []models.Pull
703
704 rows, err := e.Query(`
705 select
706 p.owner_did,
707 p.repo_did,
708 p.pull_id,
709 p.created,
710 p.title,
711 p.state,
712 r.did,
713 r.name,
714 r.knot,
715 r.rkey,
716 r.created
717 from
718 pulls p
719 join
720 repos r on p.repo_did = r.repo_did
721 where
722 p.owner_did = ? and p.created >= date ('now', ?)
723 order by
724 p.created desc`, did, timeframe)
725 if err != nil {
726 return nil, err
727 }
728 defer rows.Close()
729
730 for rows.Next() {
731 var pull models.Pull
732 var repo models.Repo
733 var pullCreatedAt, repoCreatedAt string
734 err := rows.Scan(
735 &pull.OwnerDid,
736 &pull.RepoDid,
737 &pull.PullId,
738 &pullCreatedAt,
739 &pull.Title,
740 &pull.State,
741 &repo.Did,
742 &repo.Name,
743 &repo.Knot,
744 &repo.Rkey,
745 &repoCreatedAt,
746 )
747 if err != nil {
748 return nil, err
749 }
750
751 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
752 if err != nil {
753 return nil, err
754 }
755 pull.Created = pullCreatedTime
756
757 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
758 if err != nil {
759 return nil, err
760 }
761 repo.Created = repoCreatedTime
762
763 pull.Repo = &repo
764
765 pulls = append(pulls, pull)
766 }
767
768 if err := rows.Err(); err != nil {
769 return nil, err
770 }
771
772 return pulls, nil
773}
774
775func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) {
776 query := `insert into pull_comments (owner_did, repo_did, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
777 res, err := tx.Exec(
778 query,
779 comment.OwnerDid,
780 comment.RepoDid,
781 comment.SubmissionId,
782 comment.CommentAt,
783 comment.PullId,
784 comment.Body,
785 )
786 if err != nil {
787 return 0, err
788 }
789
790 i, err := res.LastInsertId()
791 if err != nil {
792 return 0, err
793 }
794
795 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil {
796 return 0, fmt.Errorf("put reference_links: %w", err)
797 }
798
799 return i, nil
800}
801
802// use with transaction
803func SetPullsState(e Execer, pullState models.PullState, filters ...orm.Filter) error {
804 var conditions []string
805 var args []any
806
807 args = append(args, pullState)
808 for _, filter := range filters {
809 conditions = append(conditions, filter.Condition())
810 args = append(args, filter.Arg()...)
811 }
812 args = append(args, models.PullAbandoned) // only update state of non-deleted pulls
813 args = append(args, models.PullMerged) // only update state of non-merged pulls
814
815 whereClause := ""
816 if conditions != nil {
817 whereClause = " where " + strings.Join(conditions, " and ")
818 }
819
820 query := fmt.Sprintf("update pulls set state = ? %s and state <> ? and state <> ?", whereClause)
821
822 _, err := e.Exec(query, args...)
823 return err
824}
825
826func ClosePulls(e Execer, filters ...orm.Filter) error {
827 return SetPullsState(e, models.PullClosed, filters...)
828}
829
830func ReopenPulls(e Execer, filters ...orm.Filter) error {
831 return SetPullsState(e, models.PullOpen, filters...)
832}
833
834func MergePulls(e Execer, filters ...orm.Filter) error {
835 return SetPullsState(e, models.PullMerged, filters...)
836}
837
838func AbandonPulls(e Execer, filters ...orm.Filter) error {
839 return SetPullsState(e, models.PullAbandoned, filters...)
840}
841
842func ResubmitPull(
843 e Execer,
844 pullAt syntax.ATURI,
845 newRoundNumber int,
846 newPatch string,
847 combinedPatch string,
848 newSourceRev string,
849 blob *lexutil.LexBlob,
850) error {
851 _, err := e.Exec(`
852 insert into pull_submissions (
853 pull_at,
854 round_number,
855 patch,
856 combined,
857 source_rev,
858 patch_blob_ref,
859 patch_blob_mime,
860 patch_blob_size
861 )
862 values (?, ?, ?, ?, ?, ?, ?, ?)
863 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev, blob.Ref.String(), blob.MimeType, blob.Size)
864
865 return err
866}
867
868func SetDependentOn(e Execer, dependentOn syntax.ATURI, filters ...orm.Filter) error {
869 var conditions []string
870 var args []any
871
872 args = append(args, dependentOn)
873
874 for _, filter := range filters {
875 conditions = append(conditions, filter.Condition())
876 args = append(args, filter.Arg()...)
877 }
878
879 whereClause := ""
880 if conditions != nil {
881 whereClause = " where " + strings.Join(conditions, " and ")
882 }
883
884 query := fmt.Sprintf("update pulls set dependent_on = ? %s", whereClause)
885 _, err := e.Exec(query, args...)
886
887 return err
888}
889
890func GetPullCount(e Execer, repoDid string) (models.PullCount, error) {
891 row := e.QueryRow(`
892 select
893 count(case when state = ? then 1 end) as open_count,
894 count(case when state = ? then 1 end) as merged_count,
895 count(case when state = ? then 1 end) as closed_count,
896 count(case when state = ? then 1 end) as deleted_count
897 from pulls
898 where repo_did = ?`,
899 models.PullOpen,
900 models.PullMerged,
901 models.PullClosed,
902 models.PullAbandoned,
903 repoDid,
904 )
905
906 var count models.PullCount
907 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
908 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
909 }
910
911 return count, nil
912}
913
914// change-id dependent_on
915//
916// 4 w ,-------- at_uri(z) (TOP)
917// 3 z <----',------- at_uri(y)
918// 2 y <-----',------ at_uri(x)
919// 1 x <------' nil (BOT)
920//
921// `w` has no dependents, so it is the top of the stack
922//
923// this unfortunately does a db query for *each* pull of the stack,
924// ideally this would be a recursive query, but in the interest of implementation simplicity,
925// we took the less performant route
926//
927// TODO: make this less bad
928func GetStack(e Execer, atUri syntax.ATURI) (models.Stack, error) {
929 // first get the pull for the given at-uri
930 pull, err := GetPull(e, orm.FilterEq("at_uri", atUri))
931 if err != nil {
932 return nil, err
933 }
934
935 // Collect all pulls in the stack by traversing up and down
936 allPulls := []*models.Pull{pull}
937 visited := sets.New[syntax.ATURI]()
938
939 // Traverse up to find all dependents
940 current := pull
941 for {
942 dependent, err := GetPull(e,
943 orm.FilterEq("dependent_on", current.AtUri()),
944 orm.FilterNotEq("state", models.PullAbandoned),
945 )
946 if err != nil || dependent == nil {
947 break
948 }
949 if visited.Contains(dependent.AtUri()) {
950 return allPulls, fmt.Errorf("circular dependency detected in stack")
951 }
952 allPulls = append(allPulls, dependent)
953 visited.Insert(dependent.AtUri())
954 current = dependent
955 }
956
957 // Traverse down to find all dependencies
958 current = pull
959 for current.DependentOn != nil {
960 dependency, err := GetPull(
961 e,
962 orm.FilterEq("at_uri", current.DependentOn),
963 orm.FilterNotEq("state", models.PullAbandoned),
964 )
965
966 if err != nil {
967 return allPulls, fmt.Errorf("failed to find parent pull request, stack is malformed, missing PR: %s", current.DependentOn)
968 }
969 if visited.Contains(dependency.AtUri()) {
970 return allPulls, fmt.Errorf("circular dependency detected in stack")
971 }
972 allPulls = append(allPulls, dependency)
973 visited.Insert(dependency.AtUri())
974 current = dependency
975 }
976
977 // sort the list: find the top and build ordered list
978 atUriMap := make(map[syntax.ATURI]*models.Pull, len(allPulls))
979 dependentMap := make(map[syntax.ATURI]*models.Pull, len(allPulls))
980
981 for _, p := range allPulls {
982 atUriMap[p.AtUri()] = p
983 if p.DependentOn != nil {
984 dependentMap[*p.DependentOn] = p
985 }
986 }
987
988 // the top of the stack is the pull that no other pull depends on
989 var topPull *models.Pull
990 for _, maybeTop := range allPulls {
991 if _, ok := dependentMap[maybeTop.AtUri()]; !ok {
992 topPull = maybeTop
993 break
994 }
995 }
996
997 pulls := []*models.Pull{}
998 for {
999 pulls = append(pulls, topPull)
1000 if topPull.DependentOn != nil {
1001 if next, ok := atUriMap[*topPull.DependentOn]; ok {
1002 topPull = next
1003 } else {
1004 return pulls, fmt.Errorf("failed to find parent pull request, stack is malformed")
1005 }
1006 } else {
1007 break
1008 }
1009 }
1010
1011 return pulls, nil
1012}
1013
1014func GetAbandonedPulls(e Execer, atUri syntax.ATURI) ([]*models.Pull, error) {
1015 stack, err := GetStack(e, atUri)
1016 if err != nil {
1017 return nil, err
1018 }
1019
1020 var abandoned []*models.Pull
1021 for _, p := range stack {
1022 if p.State == models.PullAbandoned {
1023 abandoned = append(abandoned, p)
1024 }
1025 }
1026
1027 return abandoned, nil
1028}