Forked from willdot.net/cocoon.
A fork of the Cocoon PDS that is being made more distributed.
1package server
2
3import (
4 "net/http"
5 "net/url"
6 "slices"
7 "strings"
8
9 "github.com/Azure/go-autorest/autorest/to"
10 "github.com/haileyok/cocoon/internal/helpers"
11 "github.com/labstack/echo-contrib/session"
12 "github.com/labstack/echo/v4"
13)
14
// AccountSwitchRequest is the form payload bound by handleAccountSwitchPost.
type AccountSwitchRequest struct {
	// Did must be one of the DIDs already logged in on the current session;
	// it becomes the session's active account.
	Did string `form:"did"`

	// QueryParams is an optional query string (a leading "?" is tolerated)
	// whose pairs are merged into the redirect target's query.
	QueryParams string `form:"query_params"`

	// Next is the requested local path to redirect to after switching; empty
	// or non-local values fall back to "/account".
	Next string `form:"next"`
}
20
// sanitizeLocalRedirectPath returns next (with surrounding whitespace
// trimmed) when it is a safe local path to redirect to, and "/account"
// otherwise.
//
// A value is accepted only if it is non-empty, starts with exactly one "/",
// is not followed by a "/" or "\" in the second position, and parses as a
// relative URL with no scheme and no host. Values that url.Parse rejects
// (for example, ones containing ASCII control characters) also fall back.
func sanitizeLocalRedirectPath(next string) string {
	const fallback = "/account"

	redirect := strings.TrimSpace(next)
	if redirect == "" || redirect[0] != '/' {
		return fallback
	}

	// "//host" is a protocol-relative URL. Browsers also normalize "\" to "/"
	// in http(s) URLs, so "/\host" is equivalent. Either form would send the
	// user to a different host.
	if len(redirect) > 1 && (redirect[1] == '/' || redirect[1] == '\\') {
		return fallback
	}

	parsed, err := url.Parse(redirect)
	if err != nil || parsed.IsAbs() || parsed.Host != "" {
		return fallback
	}
	return redirect
}
37
// mergeRedirectQuery adds every key/value pair from queryParams to the query
// of redirect and returns the re-serialized URL.
//
// queryParams may be empty, surrounded by whitespace, or start with "?".
// Pairs are appended rather than replaced, so a key present in both inputs
// keeps all of its values. Because url.Values.Encode sorts by key, the output
// query is in key order. Any fragment on redirect is preserved. An error is
// returned if redirect does not parse as a URL or queryParams is not a valid
// query string.
func mergeRedirectQuery(redirect string, queryParams string) (string, error) {
	target, err := url.Parse(redirect)
	if err != nil {
		return "", err
	}
	values := target.Query()

	if extra := strings.TrimPrefix(strings.TrimSpace(queryParams), "?"); extra != "" {
		parsed, parseErr := url.ParseQuery(extra)
		if parseErr != nil {
			return "", parseErr
		}
		for key, vals := range parsed {
			for _, v := range vals {
				values.Add(key, v)
			}
		}
	}

	target.RawQuery = values.Encode()
	return target.String(), nil
}
63
64func isSameOriginRequest(e echo.Context) bool {
65 host := e.Request().Host
66
67 origin := strings.TrimSpace(e.Request().Header.Get("Origin"))
68 if origin != "" {
69 parsedOrigin, err := url.Parse(origin)
70 return err == nil && parsedOrigin.Host == host
71 }
72
73 referer := strings.TrimSpace(e.Request().Header.Get("Referer"))
74 if referer != "" {
75 parsedReferer, err := url.Parse(referer)
76 return err == nil && parsedReferer.Host == host
77 }
78
79 return false
80}
81
82func (s *Server) handleAccountSwitchPost(e echo.Context) error {
83 if !isSameOriginRequest(e) {
84 return e.JSON(http.StatusForbidden, map[string]string{"error": "Forbidden"})
85 }
86
87 var req AccountSwitchRequest
88 if err := e.Bind(&req); err != nil {
89 return helpers.InputError(e, to.StringPtr("invalid switch account request"))
90 }
91
92 sess, err := session.Get(s.config.SessionCookieKey, e)
93 if err != nil {
94 return err
95 }
96
97 dids := getSessionDids(sess)
98 if !slices.Contains(dids, req.Did) {
99 return helpers.InputError(e, to.StringPtr("requested account is not logged in"))
100 }
101
102 setActiveSessionDid(sess, req.Did)
103 applyAccountSessionOptions(sess, int(AccountSessionMaxAge.Seconds()))
104
105 if err := sess.Save(e.Request(), e.Response()); err != nil {
106 return err
107 }
108
109 redirect := sanitizeLocalRedirectPath(req.Next)
110 redirect, err = mergeRedirectQuery(redirect, req.QueryParams)
111 if err != nil {
112 return helpers.InputError(e, to.StringPtr("invalid query params"))
113 }
114
115 return e.Redirect(303, redirect)
116}