forked from
willdot.net/cocoon
A fork of the Cocoon PDS that is being made more distributed.
1package server
2
3import (
4 "context"
5 "errors"
6 "slices"
7 "strings"
8
9 "github.com/bluesky-social/indigo/atproto/syntax"
10 "github.com/gorilla/sessions"
11 "github.com/haileyok/cocoon/models"
12 "gorm.io/gorm"
13)
14
// Keys used in a session's Values map.
const (
	// sessionDidKey holds a single DID string: the currently active account.
	sessionDidKey = "did"
	// sessionDidsKey holds the list of DIDs attached to the session.
	sessionDidsKey = "dids"
)
19
// normalizeSessionDids returns a new slice containing the entries of dids with
// empty strings and repeated values removed. The first occurrence of each DID
// wins, so the relative order of the input is preserved. The input slice is
// never modified, and the result is always non-nil (possibly empty).
func normalizeSessionDids(dids []string) []string {
	unique := make([]string, 0, len(dids))
	for i := range dids {
		candidate := dids[i]
		switch {
		case candidate == "":
			// Blank entries carry no account; drop them.
		case slices.Contains(unique, candidate):
			// Already kept an earlier occurrence.
		default:
			unique = append(unique, candidate)
		}
	}
	return unique
}
30
31func getSessionDids(sess *sessions.Session) []string {
32 if sess == nil {
33 return nil
34 }
35
36 if val, ok := sess.Values[sessionDidsKey]; ok {
37 switch dids := val.(type) {
38 case []string:
39 return normalizeSessionDids(dids)
40 case []any:
41 out := make([]string, 0, len(dids))
42 for _, did := range dids {
43 if s, ok := did.(string); ok {
44 out = append(out, s)
45 }
46 }
47 return normalizeSessionDids(out)
48 }
49 }
50
51 if did, ok := sess.Values[sessionDidKey].(string); ok && did != "" {
52 return []string{did}
53 }
54
55 return nil
56}
57
58func setSessionDids(sess *sessions.Session, dids []string) {
59 if sess == nil {
60 return
61 }
62
63 normalized := normalizeSessionDids(dids)
64 if len(normalized) == 0 {
65 delete(sess.Values, sessionDidKey)
66 delete(sess.Values, sessionDidsKey)
67 return
68 }
69
70 sess.Values[sessionDidsKey] = normalized
71 if activeDid, ok := sess.Values[sessionDidKey].(string); !ok || !slices.Contains(normalized, activeDid) {
72 sess.Values[sessionDidKey] = normalized[0]
73 }
74}
75
76func getActiveSessionDid(sess *sessions.Session) string {
77 if sess == nil {
78 return ""
79 }
80
81 dids := getSessionDids(sess)
82 if len(dids) == 0 {
83 return ""
84 }
85
86 if activeDid, ok := sess.Values[sessionDidKey].(string); ok && slices.Contains(dids, activeDid) {
87 return activeDid
88 }
89 return dids[0]
90}
91
92func setActiveSessionDid(sess *sessions.Session, did string) bool {
93 if sess == nil || did == "" {
94 return false
95 }
96
97 dids := getSessionDids(sess)
98 if !slices.Contains(dids, did) {
99 dids = append(dids, did)
100 }
101 setSessionDids(sess, dids)
102
103 current, _ := sess.Values[sessionDidKey].(string)
104 if current == did {
105 return false
106 }
107 sess.Values[sessionDidKey] = did
108 return true
109}
110
111func removeSessionDid(sess *sessions.Session, did string) {
112 if sess == nil || did == "" {
113 return
114 }
115
116 next := make([]string, 0)
117 for _, existingDid := range getSessionDids(sess) {
118 if existingDid != did {
119 next = append(next, existingDid)
120 }
121 }
122 setSessionDids(sess, next)
123}
124
125func (s *Server) getSessionAccountActors(ctx context.Context, sess *sessions.Session) ([]models.RepoActor, bool, error) {
126 changed := false
127 validDids := make([]string, 0)
128 var accounts []models.RepoActor
129 for _, did := range getSessionDids(sess) {
130 repo, err := s.getRepoActorByDid(ctx, did)
131 if err != nil {
132 if errors.Is(err, gorm.ErrRecordNotFound) {
133 changed = true
134 continue
135 }
136 return nil, changed, err
137 }
138 validDids = append(validDids, did)
139 accounts = append(accounts, *repo)
140 }
141
142 if changed {
143 setSessionDids(sess, validDids)
144 }
145 return accounts, changed, nil
146}
147
148func (s *Server) resolveLoginHintToDid(ctx context.Context, loginHint string) (string, error) {
149 loginHint = strings.TrimSpace(loginHint)
150 if loginHint == "" {
151 return "", gorm.ErrRecordNotFound
152 }
153
154 if _, err := syntax.ParseDID(loginHint); err == nil {
155 return loginHint, nil
156 }
157
158 normalizedHandle := strings.ToLower(loginHint)
159 if _, err := syntax.ParseHandle(normalizedHandle); err == nil {
160 actor, err := s.getActorByHandle(ctx, normalizedHandle)
161 if err != nil {
162 return "", err
163 }
164 return actor.Did, nil
165 }
166
167 return "", gorm.ErrRecordNotFound
168}