forked from
willdot.net/cocoon
A fork of the Cocoon PDS but being made more distributed.
1package server
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "time"
9
10 "github.com/bluesky-social/indigo/atproto/syntax"
11 "github.com/gorilla/sessions"
12 "github.com/haileyok/cocoon/internal/helpers"
13 "github.com/haileyok/cocoon/models"
14 "github.com/labstack/echo-contrib/session"
15 "github.com/labstack/echo/v4"
16 "golang.org/x/crypto/bcrypt"
17 "gorm.io/gorm"
18)
19
// OauthSigninInput is the form payload bound by the account sign-in POST
// handler.
type OauthSigninInput struct {
	// Username identifies the account; the handler accepts a DID, a handle,
	// or (as a fallback) an email address.
	Username string `form:"username"`

	// Password is the plaintext password, compared against the stored bcrypt hash.
	Password string `form:"password"`

	// AuthFactorToken is the two-factor code, required only for accounts
	// with two-factor auth enabled.
	AuthFactorToken string `form:"token"`

	// QueryParams is the encoded query string the sign-in page was opened
	// with; it is forwarded so the flow can continue after signing in.
	QueryParams string `form:"query_params"`
}
26
// ErrSessionUnauthenticated is wrapped (via %w) by the session helpers when
// the account session has no usable active account: either no active DID is
// stored, or the stored DID is not among the session's accounts. Callers
// detect it with errors.Is.
var ErrSessionUnauthenticated = errors.New("session is unauthenticated")
28
29func (s *Server) getSessionRepoAndAccountsOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, []models.RepoActor, error) {
30 ctx := e.Request().Context()
31 sess, err := session.Get(s.config.SessionCookieKey, e)
32 if err != nil {
33 return nil, nil, nil, err
34 }
35
36 return s.getSessionRepoAndAccountsFromSessionOrErr(e, ctx, sess)
37}
38
// getSessionRepoAndAccountsFromSessionOrErr resolves the accounts recorded in
// sess and selects the active one.
//
// When getSessionAccountActors reports changed == true, the account-session
// options are re-applied and the session is saved before continuing; a save
// failure is returned as-is.
//
// Errors:
//   - a nil sess yields a plain "session is nil" error and no other values;
//   - errors from getSessionAccountActors are returned along with sess;
//   - if no active DID is stored, or the stored DID is not among the
//     resolved accounts, the error wraps ErrSessionUnauthenticated and the
//     session and account list are still returned.
func (s *Server) getSessionRepoAndAccountsFromSessionOrErr(e echo.Context, ctx context.Context, sess *sessions.Session) (*models.RepoActor, *sessions.Session, []models.RepoActor, error) {
	if sess == nil {
		return nil, nil, nil, errors.New("session is nil")
	}

	actors, changed, lookupErr := s.getSessionAccountActors(ctx, sess)
	if lookupErr != nil {
		return nil, sess, nil, lookupErr
	}

	if changed {
		applyAccountSessionOptions(sess, int(AccountSessionMaxAge.Seconds()))
		if saveErr := sess.Save(e.Request(), e.Response()); saveErr != nil {
			return nil, sess, nil, saveErr
		}
	}

	// Both "not signed in" outcomes share the sentinel so callers can use errors.Is.
	unauthenticated := func(reason string) error {
		return fmt.Errorf("%w: %s", ErrSessionUnauthenticated, reason)
	}

	activeDid := getActiveSessionDid(sess)
	if activeDid == "" {
		return nil, sess, actors, unauthenticated("did was not set in session")
	}

	// Return a pointer into the slice itself (not a copy of the element).
	for idx := 0; idx < len(actors); idx++ {
		if candidate := &actors[idx]; candidate.Repo.Did == activeDid {
			return candidate, sess, actors, nil
		}
	}

	return nil, sess, actors, unauthenticated("did was not found in session accounts")
}
68
69func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) {
70 repo, sess, _, err := s.getSessionRepoAndAccountsOrErr(e)
71 return repo, sess, err
72}
73
74func getFlashesFromSession(e echo.Context, sess *sessions.Session) map[string]any {
75 defer sess.Save(e.Request(), e.Response())
76 return map[string]any{
77 "errors": sess.Flashes("error"),
78 "successes": sess.Flashes("success"),
79 "tokenrequired": sess.Flashes("tokenrequired"),
80 }
81}
82
// handleAccountSigninGet renders the account sign-in page.
//
// If the visitor already has an active account session and the request has
// no query string, they are redirected to /account. A non-empty query string
// is presumably an in-progress OAuth authorization, which the sign-in form
// forwards to /oauth/authorize. Otherwise signin.html is rendered with:
//   - pending flash messages;
//   - the re-encoded query parameters;
//   - the accounts stored in the session;
//   - the active account's DID (empty if none).
//
// Session errors other than ErrSessionUnauthenticated are logged and answered
// with a server error.
func (s *Server) handleAccountSigninGet(e echo.Context) error {
	logger := s.logger.With("name", "handleAccountSigninGet")

	repo, sess, accounts, err := s.getSessionRepoAndAccountsOrErr(e)
	switch {
	case err == nil:
		// Already signed in with nothing to continue: skip the form.
		if e.QueryString() == "" {
			return e.Redirect(303, "/account")
		}
	case !errors.Is(err, ErrSessionUnauthenticated):
		logger.Error("error loading account session", "error", err)
		return helpers.ServerError(e, nil)
	}

	// Flashes are read from the session, so the page cannot render without one.
	if sess == nil {
		logger.Error("no session available for sign in page")
		return helpers.ServerError(e, nil)
	}

	activeDid := ""
	if repo != nil {
		activeDid = repo.Repo.Did
	}

	return e.Render(200, "signin.html", map[string]any{
		"flashes":     getFlashesFromSession(e, sess),
		"QueryParams": e.QueryParams().Encode(),
		"Accounts":    accounts,
		"ActiveDid":   activeDid,
	})
}
108
// handleAccountSigninPost processes the account sign-in form.
//
// The submitted identifier is trimmed and lower-cased, then matched as a DID,
// a handle, or (as a fallback) an email address. On failure the user is sent
// back to /account/signin, keeping the original query parameters, with a
// flash message explaining why.
//
// For accounts with two-factor auth enabled, a code is created and sent, and
// the user is prompted for it whenever no token was submitted or no code is
// stored yet. A submitted token that does not match, or has expired, gets an
// invalid/expired token error response.
//
// On success the session becomes the active account session. The user is
// then redirected to /oauth/authorize (when query parameters were supplied)
// or to /account.
func (s *Server) handleAccountSigninPost(e echo.Context) error {
	ctx := e.Request().Context()
	logger := s.logger.With("name", "handleAccountSigninPost")

	var req OauthSigninInput
	if err := e.Bind(&req); err != nil {
		logger.Error("error binding sign in req", "error", err)
		return helpers.ServerError(e, nil)
	}

	// An error returned alongside a usable session is tolerated, as before, so
	// the user can still sign in. NOTE(review): with cookie stores this is
	// typically an undecodable cookie; confirm against the configured store.
	// Without any session there is nowhere to record flashes or the sign-in,
	// and the code below would dereference nil.
	sess, err := session.Get(s.config.SessionCookieKey, e)
	if sess == nil {
		logger.Error("error getting session", "error", err)
		return helpers.ServerError(e, nil)
	}

	req.Username = strings.ToLower(strings.TrimSpace(req.Username))

	queryParams := ""
	if req.QueryParams != "" {
		queryParams = "?" + req.QueryParams
	}

	// backToSignin records a flash message under key and redirects to the
	// sign-in form, preserving the query parameters.
	backToSignin := func(msg, key string) error {
		sess.AddFlash(msg, key)
		if saveErr := sess.Save(e.Request(), e.Response()); saveErr != nil {
			logger.Error("error saving session", "error", saveErr)
		}
		return e.Redirect(303, "/account/signin"+queryParams)
	}

	// TODO: we should make this a helper since we do it for the base create_session as well
	lookupQuery := "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?"
	if _, perr := syntax.ParseDID(req.Username); perr == nil {
		lookupQuery = "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?"
	} else if _, perr := syntax.ParseHandle(req.Username); perr == nil {
		lookupQuery = "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?"
	}

	var repo models.RepoActor
	err = s.db.Raw(ctx, lookupQuery, nil, req.Username).Scan(&repo).Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return backToSignin("Handle or password is incorrect", "error")
	case err != nil:
		logger.Error("error looking up account for sign in", "error", err)
		return backToSignin("Something went wrong!", "error")
	case len(repo.Password) == 0:
		// Scan into a struct leaves it zero-valued when no row matches, rather
		// than reporting ErrRecordNotFound (gorm only returns that from
		// First/Take/Last). No stored hash means there is no account to
		// authenticate against.
		return backToSignin("Handle or password is incorrect", "error")
	}

	err = bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password))
	switch {
	case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
		return backToSignin("Handle or password is incorrect", "error")
	case err != nil:
		// Any other error means the stored hash itself is unusable.
		logger.Error("error comparing password hash", "error", err)
		return backToSignin("Something went wrong!", "error")
	}

	if repo.TwoFactorType != models.TwoFactorTypeNone {
		// No token submitted yet, or no outstanding code to compare it with:
		// issue a fresh code and prompt the user to enter it.
		if req.AuthFactorToken == "" || repo.TwoFactorCode == nil || repo.TwoFactorCodeExpiresAt == nil {
			if err := s.createAndSendTwoFactorCode(ctx, repo); err != nil {
				logger.Error("error sending two factor code", "error", err)
				return backToSignin("Something went wrong!", "error")
			}
			return backToSignin("requires 2FA token", "tokenrequired")
		}

		if *repo.TwoFactorCode != req.AuthFactorToken {
			return helpers.InvalidTokenError(e)
		}
		if time.Now().UTC().After(*repo.TwoFactorCodeExpiresAt) {
			return helpers.ExpiredTokenError(e)
		}
	}

	applyAccountSessionOptions(sess, int(AccountSessionMaxAge.Seconds()))
	setActiveSessionDid(sess, repo.Repo.Did)

	if err := sess.Save(e.Request(), e.Response()); err != nil {
		return err
	}

	if queryParams != "" {
		return e.Redirect(303, "/oauth/authorize"+queryParams)
	}
	return e.Redirect(303, "/account")
}