Forked from willdot.net/cocoon.
A fork of the Cocoon PDS that is being made more distributed.
1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "database/sql"
8 "embed"
9 "errors"
10 "fmt"
11 "io"
12 "log/slog"
13 "net/http"
14 "net/smtp"
15 "os"
16 "path/filepath"
17 "sync"
18 "text/template"
19 "time"
20
21 "github.com/aws/aws-sdk-go/aws"
22 "github.com/aws/aws-sdk-go/aws/credentials"
23 "github.com/aws/aws-sdk-go/aws/session"
24 "github.com/aws/aws-sdk-go/service/s3"
25 "github.com/bluesky-social/indigo/api/atproto"
26 "github.com/bluesky-social/indigo/atproto/syntax"
27 "github.com/bluesky-social/indigo/events"
28 "github.com/bluesky-social/indigo/util"
29 "github.com/bluesky-social/indigo/xrpc"
30 "github.com/domodwyer/mailyak/v3"
31 "github.com/go-playground/validator"
32 "github.com/gorilla/sessions"
33 "github.com/haileyok/cocoon/identity"
34 "github.com/haileyok/cocoon/internal/db"
35 "github.com/haileyok/cocoon/internal/helpers"
36 "github.com/haileyok/cocoon/models"
37 "github.com/haileyok/cocoon/oauth/client"
38 "github.com/haileyok/cocoon/oauth/constants"
39 "github.com/haileyok/cocoon/oauth/dpop"
40 "github.com/haileyok/cocoon/oauth/provider"
41 "github.com/haileyok/cocoon/plc"
42 "github.com/ipfs/go-cid"
43 "github.com/labstack/echo-contrib/echoprometheus"
44 echo_session "github.com/labstack/echo-contrib/session"
45 "github.com/labstack/echo/v4"
46 "github.com/labstack/echo/v4/middleware"
47 slogecho "github.com/samber/slog-echo"
48 _ "github.com/tursodatabase/libsql-client-go/libsql"
49 "gorm.io/driver/postgres"
50 "gorm.io/driver/sqlite"
51 "gorm.io/gorm"
52)
53
54const (
55 AccountSessionMaxAge = 30 * 24 * time.Hour // one week
56)
57
// S3Config holds the S3-compatible object storage settings.
//
// doBackup uses Endpoint/Region/Bucket/AccessKey/SecretKey to upload database
// snapshots; backupRoutine refuses to run unless BackupsEnabled is true and
// Region, Bucket, AccessKey and SecretKey are all non-empty.
type S3Config struct {
	// Feature toggles. BlobstoreEnabled is not read in this file (presumably
	// switches blob storage to S3 — confirm at its call sites).
	BackupsEnabled, BlobstoreEnabled bool

	// Location of the bucket. When Endpoint is non-empty doBackup targets it
	// with path-style addressing; when empty the SDK default endpoint is used.
	Endpoint, Region, Bucket string

	// Static credentials passed to the AWS SDK.
	AccessKey, SecretKey string

	// CDNUrl is not read in this file (presumably a public base URL for
	// objects — confirm at its call sites).
	CDNUrl string
}
68
69type Server struct {
70 http *http.Client
71 httpd *http.Server
72 mail *mailyak.MailYak
73 mailLk *sync.Mutex
74 echo *echo.Echo
75 db *db.DB
76 plcClient *plc.Client
77 logger *slog.Logger
78 config *config
79 privateKey *ecdsa.PrivateKey
80 repoman *RepoMan
81 oauthProvider *provider.Provider
82 evtman *events.EventManager
83 passport *identity.Passport
84 fallbackProxy string
85
86 lastRequestCrawl time.Time
87 requestCrawlMu sync.Mutex
88
89 dbName string
90 dbType string
91 s3Config *S3Config
92}
93
94type Args struct {
95 Logger *slog.Logger
96
97 LogLevel slog.Level
98 Addr string
99 DbName string
100 DbType string
101 DatabaseURL string
102 TursoToken string
103 Version string
104 Did string
105 Hostname string
106 RotationKeyPath string
107 JwkPath string
108 ContactEmail string
109 Relays []string
110 AdminPassword string
111 RequireInvite bool
112
113 SmtpUser string
114 SmtpPass string
115 SmtpHost string
116 SmtpPort string
117 SmtpEmail string
118 SmtpName string
119
120 S3Config *S3Config
121
122 SessionSecret string
123 SessionCookieKey string
124
125 BlockstoreVariant BlockstoreVariant
126 FallbackProxy string
127
128 PushBasedEvents bool
129 SubscribeReposServiceURL string
130 NonceSecret string
131}
132
133type config struct {
134 LogLevel slog.Level
135 Version string
136 Did string
137 Hostname string
138 ContactEmail string
139 EnforcePeering bool
140 Relays []string
141 AdminPassword string
142 RequireInvite bool
143 SmtpEmail string
144 SmtpName string
145 SessionCookieKey string
146 BlockstoreVariant BlockstoreVariant
147 FallbackProxy string
148
149 PushBasedEvents bool
150 SubscribeReposServiceURL string
151}
152
153type CustomValidator struct {
154 validator *validator.Validate
155}
156
157type ValidationError struct {
158 error
159 Field string
160 Tag string
161}
162
163func (cv *CustomValidator) Validate(i any) error {
164 if err := cv.validator.Struct(i); err != nil {
165 var validateErrors validator.ValidationErrors
166 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
167 first := validateErrors[0]
168 return ValidationError{
169 error: err,
170 Field: first.Field(),
171 Tag: first.Tag(),
172 }
173 }
174
175 return err
176 }
177
178 return nil
179}
180
// templateFS holds the templates/ directory compiled into the binary;
// loadTemplates parses "templates/*.html" from it when Version is not "dev".
//
//go:embed templates/*
var templateFS embed.FS

// staticFS holds the static/ directory compiled into the binary; addRoutes
// serves it under /static/* when Version is not "dev".
//
//go:embed static/*
var staticFS embed.FS
186
187type TemplateRenderer struct {
188 templates *template.Template
189 isDev bool
190 templatePath string
191}
192
193func (s *Server) loadTemplates() {
194 absPath, _ := filepath.Abs("server/templates/*.html")
195 if s.config.Version == "dev" {
196 tmpl := template.Must(template.ParseGlob(absPath))
197 s.echo.Renderer = &TemplateRenderer{
198 templates: tmpl,
199 isDev: true,
200 templatePath: absPath,
201 }
202 } else {
203 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
204 s.echo.Renderer = &TemplateRenderer{
205 templates: tmpl,
206 isDev: false,
207 }
208 }
209}
210
211func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
212 if t.isDev {
213 tmpl, err := template.ParseGlob(t.templatePath)
214 if err != nil {
215 return err
216 }
217 t.templates = tmpl
218 }
219
220 if viewContext, isMap := data.(map[string]any); isMap {
221 viewContext["reverse"] = c.Echo().Reverse
222 }
223
224 return t.templates.ExecuteTemplate(w, name, data)
225}
226
227type filteredHandler struct {
228 level slog.Level
229 handler slog.Handler
230}
231
232func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool {
233 return level >= h.level && h.handler.Enabled(ctx, level)
234}
235
236func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error {
237 return h.handler.Handle(ctx, r)
238}
239
240func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
241 return &filteredHandler{level: h.level, handler: h.handler.WithAttrs(attrs)}
242}
243
244func (h *filteredHandler) WithGroup(name string) slog.Handler {
245 return &filteredHandler{level: h.level, handler: h.handler.WithGroup(name)}
246}
247
// New validates args and assembles a Server. This covers the echo router and
// middleware, the HTTP server, the database connection, the PLC client, the
// signing key, the OAuth provider, the event manager and (optionally) the SMTP
// mailer. It does not start listening; call Serve for that.
//
// Addr, DbName, Did (a valid DID), ContactEmail, Hostname, AdminPassword and
// SessionSecret are required. DbType defaults to "sqlite"; "postgres" and
// "turso" also require DatabaseURL. RotationKeyPath and JwkPath must name
// readable files.
//
// Note that args.Logger is modified in place: it is defaulted when nil and
// wrapped with a level filter when LogLevel is non-zero.
func New(args *Args) (*Server, error) {
	if args.Logger == nil {
		args.Logger = slog.Default()
	}

	// A zero LogLevel (slog.LevelInfo) means "no extra filtering".
	if args.LogLevel != 0 {
		args.Logger = slog.New(&filteredHandler{
			level:   args.LogLevel,
			handler: args.Logger.Handler(),
		})
	}

	logger := args.Logger.With("name", "New")

	if args.Addr == "" {
		return nil, fmt.Errorf("addr must be set")
	}

	if args.DbName == "" {
		return nil, fmt.Errorf("db name must be set")
	}

	if args.Did == "" {
		return nil, fmt.Errorf("cocoon did must be set")
	}

	if args.ContactEmail == "" {
		return nil, fmt.Errorf("cocoon contact email is required")
	}

	if _, err := syntax.ParseDID(args.Did); err != nil {
		return nil, fmt.Errorf("error parsing cocoon did: %w", err)
	}

	if args.Hostname == "" {
		return nil, fmt.Errorf("cocoon hostname must be set")
	}

	if args.AdminPassword == "" {
		return nil, fmt.Errorf("admin password must be set")
	}

	// Reported like every other missing required argument instead of
	// panicking inside a constructor.
	if args.SessionSecret == "" {
		return nil, fmt.Errorf("session secret must be set")
	}

	e := echo.New()

	e.Pre(middleware.RemoveTrailingSlash())
	e.Pre(slogecho.New(args.Logger.With("component", "slogecho")))
	e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
	e.Use(echoprometheus.NewMiddleware("cocoon"))
	e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
		AllowOrigins:     []string{"*"},
		AllowHeaders:     []string{"*"},
		AllowMethods:     []string{"*"},
		AllowCredentials: true,
		MaxAge:           100_000_000,
	}))

	vdtor := validator.New()
	vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
		if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
			return false
		}
		return true
	})
	vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
		if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
			return false
		}
		return true
	})
	vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
		if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
			return false
		}
		return true
	})
	vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
		if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
			return false
		}
		return true
	})

	e.Validator = &CustomValidator{validator: vdtor}

	httpd := &http.Server{
		Addr:    args.Addr,
		Handler: e,
		// shitty defaults but okay for now, needed for import repo
		ReadTimeout:  5 * time.Minute,
		WriteTimeout: 5 * time.Minute,
		IdleTimeout:  5 * time.Minute,
	}

	dbType := args.DbType
	if dbType == "" {
		dbType = "sqlite"
	}

	gormCfg := gorm.Config{
		TranslateError: true,
	}

	var (
		gdb *gorm.DB
		err error
	)
	switch dbType {
	case "postgres":
		if args.DatabaseURL == "" {
			return nil, fmt.Errorf("database-url must be set when using postgres")
		}
		gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gormCfg)
		if err != nil {
			return nil, fmt.Errorf("failed to connect to postgres: %w", err)
		}
		logger.Info("connected to PostgreSQL database")
	case "turso":
		if args.DatabaseURL == "" {
			return nil, fmt.Errorf("database-url must be set when using turso")
		}
		// sqlDB (not db) so the imported db package is not shadowed, and its
		// own error variable so the sql.Open failure is actually checked.
		sqlDB, openErr := sql.Open("libsql", fmt.Sprintf("%s?authToken=%s", args.DatabaseURL, args.TursoToken))
		if openErr != nil {
			return nil, fmt.Errorf("failed to open turso database: %w", openErr)
		}
		gdb, err = gorm.Open(sqlite.New(sqlite.Config{
			Conn: sqlDB,
		}), &gormCfg)
		if err != nil {
			_ = sqlDB.Close() // best effort; the gorm error is the one worth reporting
			return nil, fmt.Errorf("failed to connect to turso: %w", err)
		}
		logger.Info("connected to Turso database")
	default:
		gdb, err = gorm.Open(sqlite.Open(args.DbName), &gormCfg)
		if err != nil {
			return nil, fmt.Errorf("failed to open sqlite database: %w", err)
		}
		// Best-effort tuning: log failures rather than aborting startup.
		// NOTE(review): PRAGMA synchronous is per-connection, so with a pooled
		// *sql.DB only the connection that ran it is affected; consider
		// setting it through the DSN instead.
		if err := gdb.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
			logger.Warn("failed to set sqlite journal_mode=WAL", "err", err)
		}
		if err := gdb.Exec("PRAGMA synchronous=NORMAL").Error; err != nil {
			logger.Warn("failed to set sqlite synchronous=NORMAL", "err", err)
		}

		logger.Info("connected to SQLite database", "path", args.DbName)
	}
	dbw := db.NewDB(gdb)

	rkbytes, err := os.ReadFile(args.RotationKeyPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read rotation key %q: %w", args.RotationKeyPath, err)
	}

	h := util.RobustHTTPClient()

	plcClient, err := plc.NewClient(&plc.ClientArgs{
		H:           h,
		Service:     "https://plc.directory",
		PdsHostname: args.Hostname,
		RotationKey: rkbytes,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create plc client: %w", err)
	}

	jwkbytes, err := os.ReadFile(args.JwkPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read jwk %q: %w", args.JwkPath, err)
	}

	key, err := helpers.ParseJWKFromBytes(jwkbytes)
	if err != nil {
		return nil, fmt.Errorf("failed to parse jwk: %w", err)
	}

	var pkey ecdsa.PrivateKey
	if err := key.Raw(&pkey); err != nil {
		return nil, fmt.Errorf("failed to extract ecdsa private key from jwk: %w", err)
	}

	oauthCli := &http.Client{
		Timeout: 10 * time.Second,
	}

	// An explicit NonceSecret wins; otherwise reuse the one persisted by
	// OnNonceSecretCreated below. A missing file is not an error (nil secret).
	var nonceSecret []byte
	if args.NonceSecret != "" {
		nonceSecret = []byte(args.NonceSecret)
	} else {
		maybeSecret, err := os.ReadFile("nonce.secret")
		if err != nil && !os.IsNotExist(err) {
			logger.Error("error attempting to read nonce secret", "error", err)
		} else {
			nonceSecret = maybeSecret
		}
	}

	evtPersister, err := NewDbPersister(gdb, 72*time.Hour)
	if err != nil {
		return nil, fmt.Errorf("failed to create event persister: %w", err)
	}

	s := &Server{
		http:       h,
		httpd:      httpd,
		echo:       e,
		logger:     args.Logger,
		db:         dbw,
		plcClient:  plcClient,
		privateKey: &pkey,
		config: &config{
			LogLevel:                 args.LogLevel,
			Version:                  args.Version,
			Did:                      args.Did,
			Hostname:                 args.Hostname,
			ContactEmail:             args.ContactEmail,
			EnforcePeering:           false,
			Relays:                   args.Relays,
			AdminPassword:            args.AdminPassword,
			RequireInvite:            args.RequireInvite,
			SmtpName:                 args.SmtpName,
			SmtpEmail:                args.SmtpEmail,
			SessionCookieKey:         args.SessionCookieKey,
			BlockstoreVariant:        args.BlockstoreVariant,
			FallbackProxy:            args.FallbackProxy,
			PushBasedEvents:          args.PushBasedEvents,
			SubscribeReposServiceURL: args.SubscribeReposServiceURL,
		},
		evtman:   events.NewEventManager(evtPersister),
		passport: identity.NewPassport(h, identity.NewMemCache(10_000)),

		dbName:   args.DbName,
		dbType:   dbType,
		s3Config: args.S3Config,

		oauthProvider: provider.NewProvider(provider.Args{
			Hostname: args.Hostname,
			ClientManagerArgs: client.ManagerArgs{
				Cli:    oauthCli,
				Logger: args.Logger.With("component", "oauth-client-manager"),
			},
			DpopManagerArgs: dpop.ManagerArgs{
				NonceSecret:           nonceSecret,
				NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
				OnNonceSecretCreated: func(newNonce []byte) {
					// 0600: this is secret key material and must not be
					// world-readable.
					if err := os.WriteFile("nonce.secret", newNonce, 0600); err != nil {
						logger.Error("error writing new nonce secret", "error", err)
					}
				},
				Logger:   args.Logger.With("component", "dpop-manager"),
				Hostname: args.Hostname,
			},
		}),
	}

	s.loadTemplates()

	s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it

	// TODO: should validate these args
	if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
		args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
	} else {
		mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
		mail.From(s.config.SmtpEmail)
		mail.FromName(s.config.SmtpName)

		s.mail = mail
		s.mailLk = &sync.Mutex{}
	}

	return s, nil
}
514
// addRoutes registers every HTTP route on s.echo. Authenticated endpoints run
// the legacy session middleware followed by the OAuth session middleware; any
// /xrpc/* path not matched explicitly falls through to handleProxy.
func (s *Server) addRoutes() {
	e := s.echo
	authed := []echo.MiddlewareFunc{s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware}

	// Static assets: served from disk in dev, from the embedded FS otherwise.
	if s.config.Version == "dev" {
		e.Static("/static", "server/static")
	} else {
		e.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
	}

	// Root, health and well-known documents.
	e.GET("/", s.handleRoot)
	e.GET("/xrpc/_health", s.handleHealth)
	e.GET("/.well-known/did.json", s.handleWellKnown)
	e.GET("/.well-known/atproto-did", s.handleAtprotoDid)
	e.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
	e.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
	e.GET("/robots.txt", s.handleRobots)

	// Public (unauthenticated) XRPC.
	e.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
	e.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
	e.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
	e.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
	e.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)

	e.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
	e.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
	e.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
	e.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
	e.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
	e.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
	e.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
	e.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
	e.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
	e.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
	e.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
	e.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)

	// Labels.
	e.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)

	// Account pages.
	e.GET("/account", s.handleAccount)
	e.POST("/account/revoke", s.handleAccountRevoke)
	e.POST("/account/switch", s.handleAccountSwitchPost)
	e.GET("/account/signin", s.handleAccountSigninGet)
	e.POST("/account/signin", s.handleAccountSigninPost)
	e.GET("/account/signout", s.handleAccountSignout)

	// OAuth account pages.
	e.GET("/oauth/jwks", s.handleOauthJwks)
	e.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
	e.POST("/oauth/authorize", s.handleOauthAuthorizePost)

	// OAuth authorization endpoints.
	e.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
	e.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)

	// Authenticated server/identity endpoints.
	e.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, authed...)
	e.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, authed...)
	e.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, authed...)
	e.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, authed...)
	e.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, authed...)
	e.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, authed...)
	e.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, authed...)
	e.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, authed...)
	e.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, authed...)
	e.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, authed...)
	e.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // no auth: the user cannot log in yet
	e.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, authed...)
	// NOTE(review): the com.atproto.server.resetPassword lexicon takes a reset
	// token and a new password, so requiring a session here looks suspicious.
	// Confirm whether handleServerResetPassword depends on the session.
	e.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, authed...)
	e.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, authed...)
	e.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, authed...)
	e.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, authed...)
	e.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, authed...)
	e.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, authed...)
	e.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, authed...)
	e.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)

	// Authenticated repo writes.
	e.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, authed...)
	e.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, authed...)
	e.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, authed...)
	e.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, authed...)
	e.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, authed...)
	e.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, authed...)
	e.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, authed...)

	// app.bsky endpoints handled locally instead of via the generic proxy.
	e.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, authed...)
	e.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, authed...)
	e.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, authed...)
	e.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, authed...)

	// Admin-only.
	e.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
	e.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)

	// Catch-all proxy for every other XRPC method (authenticated).
	e.GET("/xrpc/*", s.handleProxy, authed...)
	e.POST("/xrpc/*", s.handleProxy, authed...)
}
616
// Serve registers routes and runs the database migrations. It then starts the
// HTTP server and the background workers: S3 backups, relay crawl requests
// and, when PushBasedEvents is set, event emission. It blocks until ctx is
// cancelled or the HTTP server fails.
//
// A migration failure or an HTTP listener error (for example, the address is
// already in use) is returned. On cancellation the HTTP server is shut down
// gracefully, with a 10 second limit, and nil is returned.
func (s *Server) Serve(ctx context.Context) error {
	logger := s.logger.With("name", "Serve")

	s.addRoutes()

	logger.Info("migrating...")

	if err := s.db.AutoMigrate(
		&models.Actor{},
		&models.Repo{},
		&models.InviteCode{},
		&models.Token{},
		&models.RefreshToken{},
		&models.Block{},
		&models.Record{},
		&models.Blob{},
		&models.BlobPart{},
		&models.ReservedKey{},
		&provider.OauthToken{},
		&provider.OauthAuthorizationRequest{},
	); err != nil {
		return fmt.Errorf("failed to migrate database: %w", err)
	}

	logger.Info("starting cocoon")

	// Buffered so the listener goroutine can always deliver its result and
	// exit, even if Serve has already returned.
	serveErr := make(chan error, 1)
	go func() {
		// ErrServerClosed is the expected result of Shutdown below.
		if err := s.httpd.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			serveErr <- err
		}
	}()

	go s.backupRoutine()

	go func() {
		if err := s.requestCrawl(ctx); err != nil {
			logger.Error("error requesting crawls", "err", err)
		}
	}()

	if s.config.PushBasedEvents {
		logger.Info("pushed based events enabled")
		go func() {
			if err := s.emmitEvents(ctx); err != nil {
				logger.Error("error emitting events", "err", err)
			}
		}()
	}

	select {
	case err := <-serveErr:
		return fmt.Errorf("http server failed: %w", err)
	case <-ctx.Done():
	}

	// ctx is already done, so give shutdown its own deadline.
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := s.httpd.Shutdown(shutdownCtx); err != nil {
		logger.Error("error shutting down http server", "err", err)
	}

	logger.Info("shut down")

	return nil
}
670
671func (s *Server) requestCrawl(ctx context.Context) error {
672 logger := s.logger.With("component", "request-crawl")
673 s.requestCrawlMu.Lock()
674 defer s.requestCrawlMu.Unlock()
675
676 logger.Info("requesting crawl with configured relays")
677
678 if time.Since(s.lastRequestCrawl) <= 1*time.Minute {
679 return fmt.Errorf("a crawl request has already been made within the last minute")
680 }
681
682 for _, relay := range s.config.Relays {
683 logger := logger.With("relay", relay)
684 logger.Info("requesting crawl from relay")
685 cli := xrpc.Client{Host: relay}
686 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
687 Hostname: s.config.Hostname,
688 }); err != nil {
689 logger.Error("error requesting crawl", "err", err)
690 } else {
691 logger.Info("crawl requested successfully")
692 }
693 }
694
695 s.lastRequestCrawl = time.Now()
696
697 return nil
698}
699
// doBackup snapshots the SQLite database with VACUUM INTO and uploads the
// snapshot to S3 as cocoon-backup-<timestamp>.db. On success it writes the
// completion time to last-backup.txt, which backupRoutine reads. Postgres and
// Turso deployments are skipped because they need external backups. Every
// failure is logged and the backup abandoned; nothing is returned.
func (s *Server) doBackup() {
	logger := s.logger.With("name", "doBackup")

	if s.dbType == "postgres" || s.dbType == "turso" {
		logger.Info("skipping S3 backup - PostgreSQL or Turso backups should be handled externally (pg_dump, managed database backups, etc.)")
		return
	}

	if s.s3Config == nil {
		logger.Error("no s3 config available, cannot run backup")
		return
	}

	start := time.Now()

	logger.Info("beginning backup to s3...")

	tmpFile := filepath.Join(os.TempDir(), fmt.Sprintf("cocoon-backup-%s.db", start.Format(time.RFC3339Nano)))
	defer os.Remove(tmpFile)

	// Bind the path as a parameter rather than splicing it into the SQL text
	// (VACUUM INTO accepts any string expression), so a quote in the temp dir
	// path cannot break the statement.
	if err := s.db.Client().Exec("VACUUM INTO ?", tmpFile).Error; err != nil {
		logger.Error("error creating tmp backup file", "err", err)
		return
	}

	backupData, err := os.ReadFile(tmpFile)
	if err != nil {
		logger.Error("error reading tmp backup file", "err", err)
		return
	}

	logger.Info("sending to s3...")

	key := "cocoon-backup-" + time.Now().Format("2006-01-02_15-04-05") + ".db"

	// awsCfg, not config, to avoid shadowing this package's config type.
	awsCfg := &aws.Config{
		Region:      aws.String(s.s3Config.Region),
		Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
	}

	if s.s3Config.Endpoint != "" {
		awsCfg.Endpoint = aws.String(s.s3Config.Endpoint)
		awsCfg.S3ForcePathStyle = aws.Bool(true)
	}

	sess, err := session.NewSession(awsCfg)
	if err != nil {
		logger.Error("error creating s3 session", "err", err)
		return
	}

	svc := s3.New(sess)

	if _, err := svc.PutObject(&s3.PutObjectInput{
		Bucket: aws.String(s.s3Config.Bucket),
		Key:    aws.String(key),
		Body:   bytes.NewReader(backupData),
	}); err != nil {
		logger.Error("error uploading file to s3", "err", err)
		return
	}

	logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds())

	// If this fails, backupRoutine will treat the next startup as overdue.
	if err := os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644); err != nil {
		logger.Error("error recording last backup time", "err", err)
	}
}
762
// backupRoutine drives periodic S3 backups and never returns once it starts
// ticking. It exits immediately when backups are disabled or any required S3
// setting is missing, logging a warning for the first missing setting. Otherwise
// it starts a backup right away if last-backup.txt is missing, unparsable or
// more than an hour old, and then starts another every hour. Each backup runs
// in its own goroutine.
func (s *Server) backupRoutine() {
	logger := s.logger.With("name", "backupRoutine")

	cfg := s.s3Config
	if cfg == nil || !cfg.BackupsEnabled {
		return
	}

	checks := []struct {
		missing bool
		msg     string
	}{
		{cfg.Region == "", "no s3 region configured but backups are enabled. backups will not run."},
		{cfg.Bucket == "", "no s3 bucket configured but backups are enabled. backups will not run."},
		{cfg.AccessKey == "", "no s3 access key configured but backups are enabled. backups will not run."},
		{cfg.SecretKey == "", "no s3 secret key configured but backups are enabled. backups will not run."},
	}
	for _, c := range checks {
		if c.missing {
			logger.Warn(c.msg)
			return
		}
	}

	overdue := func() bool {
		raw, err := os.ReadFile("last-backup.txt")
		if err != nil {
			return true
		}
		last, err := time.Parse(time.RFC3339Nano, string(raw))
		if err != nil {
			return true
		}
		return time.Since(last).Seconds() > 3600
	}

	if overdue() {
		go s.doBackup()
	}

	ticker := time.NewTicker(time.Hour)
	defer ticker.Stop()
	for range ticker.C {
		go s.doBackup()
	}
}
812
813func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
814 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
815 return err
816 }
817
818 return nil
819}