forked from
willdot.net/cocoon
A fork of the Cocoon PDS that is being made more distributed.
1package server
2
3import (
4 "errors"
5 "fmt"
6 "net/url"
7 "slices"
8 "strings"
9 "time"
10
11 "github.com/Azure/go-autorest/autorest/to"
12 "github.com/haileyok/cocoon/internal/helpers"
13 "github.com/haileyok/cocoon/oauth"
14 "github.com/haileyok/cocoon/oauth/constants"
15 "github.com/haileyok/cocoon/oauth/provider"
16 "github.com/labstack/echo-contrib/session"
17 "github.com/labstack/echo/v4"
18)
19
// HandleOauthAuthorizeGetInput holds the query parameter that
// handleOauthAuthorizeGet binds first, before deciding whether to load an
// existing authorization request or build a new one from the query string.
type HandleOauthAuthorizeGetInput struct {
	// RequestUri is an encoded request_uri identifying a stored
	// authorization request. It is empty when the client sent the
	// authorization parameters directly in the query string instead.
	RequestUri string `query:"request_uri"`
}
23
24func (s *Server) handleOauthAuthorizeGet(e echo.Context) error {
25 ctx := e.Request().Context()
26
27 logger := s.logger.With("name", "handleOauthAuthorizeGet")
28
29 var input HandleOauthAuthorizeGetInput
30 if err := e.Bind(&input); err != nil {
31 logger.Error("error binding request", "err", err)
32 return fmt.Errorf("error binding request")
33 }
34
35 var reqId string
36 if input.RequestUri != "" {
37 id, err := oauth.DecodeRequestUri(input.RequestUri)
38 if err != nil {
39 logger.Error("no request uri found in input", "url", e.Request().URL.String())
40 return helpers.InputError(e, to.StringPtr("no request uri"))
41 }
42 reqId = id
43 } else {
44 var parRequest provider.ParRequest
45 if err := e.Bind(&parRequest); err != nil {
46 s.logger.Error("error binding for standard auth request", "error", err)
47 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
48 }
49
50 if err := e.Validate(parRequest); err != nil {
51 // render page for logged out dev
52 if s.config.Version == "dev" && parRequest.ClientID == "" {
53 return e.Render(200, "authorize.html", map[string]any{
54 "Scopes": []string{"atproto", "transition:generic"},
55 "AppName": "DEV MODE AUTHORIZATION PAGE",
56 "Handle": "paula.cocoon.social",
57 "RequestUri": "",
58 "Accounts": []string{},
59 "ActiveDid": "",
60 })
61 }
62 return helpers.InputError(e, to.StringPtr("no request uri and invalid parameters"))
63 }
64
65 client, clientAuth, err := s.oauthProvider.AuthenticateClient(ctx, parRequest.AuthenticateClientRequestBase, nil, &provider.AuthenticateClientOptions{
66 AllowMissingDpopProof: true,
67 })
68 if err != nil {
69 s.logger.Error("error authenticating client in standard request", "client_id", parRequest.ClientID, "error", err)
70 return helpers.ServerError(e, to.StringPtr(err.Error()))
71 }
72
73 if parRequest.DpopJkt == nil {
74 if client.Metadata.DpopBoundAccessTokens {
75 }
76 } else {
77 if !client.Metadata.DpopBoundAccessTokens {
78 msg := "dpop bound access tokens are not enabled for this client"
79 return helpers.InputError(e, &msg)
80 }
81 }
82
83 eat := time.Now().Add(constants.ParExpiresIn)
84 id := oauth.GenerateRequestId()
85
86 authRequest := &provider.OauthAuthorizationRequest{
87 RequestId: id,
88 ClientId: client.Metadata.ClientID,
89 ClientAuth: *clientAuth,
90 Parameters: parRequest,
91 ExpiresAt: eat,
92 }
93
94 if err := s.db.Create(ctx, authRequest, nil).Error; err != nil {
95 s.logger.Error("error creating auth request in db", "error", err)
96 return helpers.ServerError(e, nil)
97 }
98
99 input.RequestUri = oauth.EncodeRequestUri(id)
100 reqId = id
101
102 }
103
104 var req provider.OauthAuthorizationRequest
105 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil {
106 return helpers.ServerError(e, to.StringPtr(err.Error()))
107 }
108
109 clientId := e.QueryParam("client_id")
110 if clientId != req.ClientId {
111 return helpers.InputError(e, to.StringPtr("client id does not match the client id for the supplied request"))
112 }
113
114 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), req.ClientId)
115 if err != nil {
116 return helpers.ServerError(e, to.StringPtr(err.Error()))
117 }
118
119 sess, err := session.Get(s.config.SessionCookieKey, e)
120 if err != nil {
121 return helpers.ServerError(e, to.StringPtr(err.Error()))
122 }
123
124 if req.Parameters.LoginHint != nil && *req.Parameters.LoginHint != "" {
125 did, err := s.resolveLoginHintToDid(ctx, *req.Parameters.LoginHint)
126 if err != nil || !slices.Contains(getSessionDids(sess), did) {
127 return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode())
128 }
129
130 setActiveSessionDid(sess, did)
131 applyAccountSessionOptions(sess, int(AccountSessionMaxAge.Seconds()))
132 if err := sess.Save(e.Request(), e.Response()); err != nil {
133 return helpers.ServerError(e, to.StringPtr(err.Error()))
134 }
135 }
136
137 repo, _, accounts, err := s.getSessionRepoAndAccountsFromSessionOrErr(e, ctx, sess)
138 if err != nil {
139 if !errors.Is(err, ErrSessionUnauthenticated) {
140 return helpers.ServerError(e, to.StringPtr(err.Error()))
141 }
142 return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode())
143 }
144
145 scopes := strings.Split(req.Parameters.Scope, " ")
146 appName := client.Metadata.ClientName
147
148 data := map[string]any{
149 "Scopes": scopes,
150 "AppName": appName,
151 "RequestUri": input.RequestUri,
152 "QueryParams": e.QueryParams().Encode(),
153 "Handle": repo.Actor.Handle,
154 "Accounts": accounts,
155 "ActiveDid": repo.Repo.Did,
156 }
157
158 return e.Render(200, "authorize.html", data)
159}
160
// OauthAuthorizePostRequest is the form body bound by handleOauthAuthorizePost.
type OauthAuthorizePostRequest struct {
	// RequestUri is the encoded request_uri of the stored authorization
	// request being decided.
	RequestUri string `form:"request_uri"`
	// AcceptOrRejct carries the user's decision. The handler treats "reject"
	// specially and any other value as acceptance.
	// NOTE(review): the field name is misspelled; renaming it would require
	// updating every reference in the package.
	AcceptOrRejct string `form:"accept_or_reject"`
}
165
166func (s *Server) handleOauthAuthorizePost(e echo.Context) error {
167 ctx := e.Request().Context()
168 logger := s.logger.With("name", "handleOauthAuthorizePost")
169
170 repo, _, err := s.getSessionRepoOrErr(e)
171 if err != nil {
172 if !errors.Is(err, ErrSessionUnauthenticated) {
173 return helpers.ServerError(e, to.StringPtr(err.Error()))
174 }
175 return e.Redirect(303, "/account/signin")
176 }
177
178 var req OauthAuthorizePostRequest
179 if err := e.Bind(&req); err != nil {
180 logger.Error("error binding authorize post request", "error", err)
181 return helpers.InputError(e, nil)
182 }
183
184 reqId, err := oauth.DecodeRequestUri(req.RequestUri)
185 if err != nil {
186 return helpers.InputError(e, to.StringPtr(err.Error()))
187 }
188
189 var authReq provider.OauthAuthorizationRequest
190 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil {
191 return helpers.ServerError(e, to.StringPtr(err.Error()))
192 }
193
194 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), authReq.ClientId)
195 if err != nil {
196 return helpers.ServerError(e, to.StringPtr(err.Error()))
197 }
198
199 // TODO: figure out how im supposed to actually redirect
200 if req.AcceptOrRejct == "reject" {
201 return e.Redirect(303, client.Metadata.ClientURI)
202 }
203
204 if time.Now().After(authReq.ExpiresAt) {
205 return helpers.InputError(e, to.StringPtr("the request has expired"))
206 }
207
208 if authReq.Sub != nil || authReq.Code != nil {
209 return helpers.InputError(e, to.StringPtr("this request was already authorized"))
210 }
211
212 code := oauth.GenerateCode()
213
214 if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil {
215 logger.Error("error updating authorization request", "error", err)
216 return helpers.ServerError(e, nil)
217 }
218
219 q := url.Values{}
220 q.Set("state", authReq.Parameters.State)
221 q.Set("iss", "https://"+s.config.Hostname)
222 q.Set("code", code)
223
224 hashOrQuestion := "?"
225 if authReq.Parameters.ResponseMode != nil {
226 switch *authReq.Parameters.ResponseMode {
227 case "fragment":
228 hashOrQuestion = "#"
229 case "query":
230 // do nothing
231 break
232 default:
233 if authReq.Parameters.ResponseType != "code" {
234 hashOrQuestion = "#"
235 }
236 }
237 } else {
238 if authReq.Parameters.ResponseType != "code" {
239 hashOrQuestion = "#"
240 }
241 }
242
243 return e.Redirect(303, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode())
244}