A fork of the Cocoon PDS that is being reworked to be more distributed.
0

Configure Feed

Select the types of activity you want to include in your feed.

at main 7.6 kB View raw
1package server 2 3import ( 4 "errors" 5 "fmt" 6 "net/url" 7 "slices" 8 "strings" 9 "time" 10 11 "github.com/Azure/go-autorest/autorest/to" 12 "github.com/haileyok/cocoon/internal/helpers" 13 "github.com/haileyok/cocoon/oauth" 14 "github.com/haileyok/cocoon/oauth/constants" 15 "github.com/haileyok/cocoon/oauth/provider" 16 "github.com/labstack/echo-contrib/session" 17 "github.com/labstack/echo/v4" 18) 19 20type HandleOauthAuthorizeGetInput struct { 21 RequestUri string `query:"request_uri"` 22} 23 24func (s *Server) handleOauthAuthorizeGet(e echo.Context) error { 25 ctx := e.Request().Context() 26 27 logger := s.logger.With("name", "handleOauthAuthorizeGet") 28 29 var input HandleOauthAuthorizeGetInput 30 if err := e.Bind(&input); err != nil { 31 logger.Error("error binding request", "err", err) 32 return fmt.Errorf("error binding request") 33 } 34 35 var reqId string 36 if input.RequestUri != "" { 37 id, err := oauth.DecodeRequestUri(input.RequestUri) 38 if err != nil { 39 logger.Error("no request uri found in input", "url", e.Request().URL.String()) 40 return helpers.InputError(e, to.StringPtr("no request uri")) 41 } 42 reqId = id 43 } else { 44 var parRequest provider.ParRequest 45 if err := e.Bind(&parRequest); err != nil { 46 s.logger.Error("error binding for standard auth request", "error", err) 47 return helpers.InputError(e, to.StringPtr("InvalidRequest")) 48 } 49 50 if err := e.Validate(parRequest); err != nil { 51 // render page for logged out dev 52 if s.config.Version == "dev" && parRequest.ClientID == "" { 53 return e.Render(200, "authorize.html", map[string]any{ 54 "Scopes": []string{"atproto", "transition:generic"}, 55 "AppName": "DEV MODE AUTHORIZATION PAGE", 56 "Handle": "paula.cocoon.social", 57 "RequestUri": "", 58 "Accounts": []string{}, 59 "ActiveDid": "", 60 }) 61 } 62 return helpers.InputError(e, to.StringPtr("no request uri and invalid parameters")) 63 } 64 65 client, clientAuth, err := s.oauthProvider.AuthenticateClient(ctx, 
parRequest.AuthenticateClientRequestBase, nil, &provider.AuthenticateClientOptions{ 66 AllowMissingDpopProof: true, 67 }) 68 if err != nil { 69 s.logger.Error("error authenticating client in standard request", "client_id", parRequest.ClientID, "error", err) 70 return helpers.ServerError(e, to.StringPtr(err.Error())) 71 } 72 73 if parRequest.DpopJkt == nil { 74 if client.Metadata.DpopBoundAccessTokens { 75 } 76 } else { 77 if !client.Metadata.DpopBoundAccessTokens { 78 msg := "dpop bound access tokens are not enabled for this client" 79 return helpers.InputError(e, &msg) 80 } 81 } 82 83 eat := time.Now().Add(constants.ParExpiresIn) 84 id := oauth.GenerateRequestId() 85 86 authRequest := &provider.OauthAuthorizationRequest{ 87 RequestId: id, 88 ClientId: client.Metadata.ClientID, 89 ClientAuth: *clientAuth, 90 Parameters: parRequest, 91 ExpiresAt: eat, 92 } 93 94 if err := s.db.Create(ctx, authRequest, nil).Error; err != nil { 95 s.logger.Error("error creating auth request in db", "error", err) 96 return helpers.ServerError(e, nil) 97 } 98 99 input.RequestUri = oauth.EncodeRequestUri(id) 100 reqId = id 101 102 } 103 104 var req provider.OauthAuthorizationRequest 105 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { 106 return helpers.ServerError(e, to.StringPtr(err.Error())) 107 } 108 109 clientId := e.QueryParam("client_id") 110 if clientId != req.ClientId { 111 return helpers.InputError(e, to.StringPtr("client id does not match the client id for the supplied request")) 112 } 113 114 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), req.ClientId) 115 if err != nil { 116 return helpers.ServerError(e, to.StringPtr(err.Error())) 117 } 118 119 sess, err := session.Get(s.config.SessionCookieKey, e) 120 if err != nil { 121 return helpers.ServerError(e, to.StringPtr(err.Error())) 122 } 123 124 if req.Parameters.LoginHint != nil && *req.Parameters.LoginHint != "" 
{ 125 did, err := s.resolveLoginHintToDid(ctx, *req.Parameters.LoginHint) 126 if err != nil || !slices.Contains(getSessionDids(sess), did) { 127 return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode()) 128 } 129 130 setActiveSessionDid(sess, did) 131 applyAccountSessionOptions(sess, int(AccountSessionMaxAge.Seconds())) 132 if err := sess.Save(e.Request(), e.Response()); err != nil { 133 return helpers.ServerError(e, to.StringPtr(err.Error())) 134 } 135 } 136 137 repo, _, accounts, err := s.getSessionRepoAndAccountsFromSessionOrErr(e, ctx, sess) 138 if err != nil { 139 if !errors.Is(err, ErrSessionUnauthenticated) { 140 return helpers.ServerError(e, to.StringPtr(err.Error())) 141 } 142 return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode()) 143 } 144 145 scopes := strings.Split(req.Parameters.Scope, " ") 146 appName := client.Metadata.ClientName 147 148 data := map[string]any{ 149 "Scopes": scopes, 150 "AppName": appName, 151 "RequestUri": input.RequestUri, 152 "QueryParams": e.QueryParams().Encode(), 153 "Handle": repo.Actor.Handle, 154 "Accounts": accounts, 155 "ActiveDid": repo.Repo.Did, 156 } 157 158 return e.Render(200, "authorize.html", data) 159} 160 161type OauthAuthorizePostRequest struct { 162 RequestUri string `form:"request_uri"` 163 AcceptOrRejct string `form:"accept_or_reject"` 164} 165 166func (s *Server) handleOauthAuthorizePost(e echo.Context) error { 167 ctx := e.Request().Context() 168 logger := s.logger.With("name", "handleOauthAuthorizePost") 169 170 repo, _, err := s.getSessionRepoOrErr(e) 171 if err != nil { 172 if !errors.Is(err, ErrSessionUnauthenticated) { 173 return helpers.ServerError(e, to.StringPtr(err.Error())) 174 } 175 return e.Redirect(303, "/account/signin") 176 } 177 178 var req OauthAuthorizePostRequest 179 if err := e.Bind(&req); err != nil { 180 logger.Error("error binding authorize post request", "error", err) 181 return helpers.InputError(e, nil) 182 } 183 184 reqId, err := 
oauth.DecodeRequestUri(req.RequestUri) 185 if err != nil { 186 return helpers.InputError(e, to.StringPtr(err.Error())) 187 } 188 189 var authReq provider.OauthAuthorizationRequest 190 if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { 191 return helpers.ServerError(e, to.StringPtr(err.Error())) 192 } 193 194 client, err := s.oauthProvider.ClientManager.GetClient(e.Request().Context(), authReq.ClientId) 195 if err != nil { 196 return helpers.ServerError(e, to.StringPtr(err.Error())) 197 } 198 199 // TODO: figure out how im supposed to actually redirect 200 if req.AcceptOrRejct == "reject" { 201 return e.Redirect(303, client.Metadata.ClientURI) 202 } 203 204 if time.Now().After(authReq.ExpiresAt) { 205 return helpers.InputError(e, to.StringPtr("the request has expired")) 206 } 207 208 if authReq.Sub != nil || authReq.Code != nil { 209 return helpers.InputError(e, to.StringPtr("this request was already authorized")) 210 } 211 212 code := oauth.GenerateCode() 213 214 if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { 215 logger.Error("error updating authorization request", "error", err) 216 return helpers.ServerError(e, nil) 217 } 218 219 q := url.Values{} 220 q.Set("state", authReq.Parameters.State) 221 q.Set("iss", "https://"+s.config.Hostname) 222 q.Set("code", code) 223 224 hashOrQuestion := "?" 
225 if authReq.Parameters.ResponseMode != nil { 226 switch *authReq.Parameters.ResponseMode { 227 case "fragment": 228 hashOrQuestion = "#" 229 case "query": 230 // do nothing 231 break 232 default: 233 if authReq.Parameters.ResponseType != "code" { 234 hashOrQuestion = "#" 235 } 236 } 237 } else { 238 if authReq.Parameters.ResponseType != "code" { 239 hashOrQuestion = "#" 240 } 241 } 242 243 return e.Redirect(303, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode()) 244}