Forked from
willdot.net/cocoon
A fork of the Cocoon PDS that is being made more distributed.
1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "embed"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "net/smtp"
14 "os"
15 "path/filepath"
16 "sync"
17 "text/template"
18 "time"
19
20 "github.com/aws/aws-sdk-go/aws"
21 "github.com/aws/aws-sdk-go/aws/credentials"
22 "github.com/aws/aws-sdk-go/aws/session"
23 "github.com/aws/aws-sdk-go/service/s3"
24 "github.com/bluesky-social/indigo/api/atproto"
25 "github.com/bluesky-social/indigo/atproto/syntax"
26 "github.com/bluesky-social/indigo/events"
27 "github.com/bluesky-social/indigo/util"
28 "github.com/bluesky-social/indigo/xrpc"
29 "github.com/domodwyer/mailyak/v3"
30 "github.com/go-playground/validator"
31 "github.com/gorilla/sessions"
32 "github.com/haileyok/cocoon/identity"
33 "github.com/haileyok/cocoon/internal/db"
34 "github.com/haileyok/cocoon/internal/helpers"
35 "github.com/haileyok/cocoon/models"
36 "github.com/haileyok/cocoon/oauth/client"
37 "github.com/haileyok/cocoon/oauth/constants"
38 "github.com/haileyok/cocoon/oauth/dpop"
39 "github.com/haileyok/cocoon/oauth/provider"
40 "github.com/haileyok/cocoon/plc"
41 "github.com/ipfs/go-cid"
42 "github.com/labstack/echo-contrib/echoprometheus"
43 echo_session "github.com/labstack/echo-contrib/session"
44 "github.com/labstack/echo/v4"
45 "github.com/labstack/echo/v4/middleware"
46 slogecho "github.com/samber/slog-echo"
47 "gorm.io/driver/postgres"
48 "gorm.io/driver/sqlite"
49 "gorm.io/gorm"
50)
51
const (
	// AccountSessionMaxAge is the maximum age for account sessions: 30 days
	// (30 × 24h).
	AccountSessionMaxAge = 30 * 24 * time.Hour
)
55
// S3Config holds settings for an S3-compatible object store.
//
// The backup code in this file reads BackupsEnabled, Endpoint, Region,
// Bucket, AccessKey and SecretKey. BlobstoreEnabled and CDNUrl are not read
// here; presumably blob storage elsewhere in the package uses them.
type S3Config struct {
	// Feature switches.
	BackupsEnabled, BlobstoreEnabled bool

	// Endpoint, when non-empty, overrides the default AWS endpoint; the
	// backup uploader then also forces path-style addressing.
	Endpoint string
	Region   string
	Bucket   string

	// Static credentials.
	AccessKey, SecretKey string

	CDNUrl string
}
66
67type Server struct {
68 http *http.Client
69 httpd *http.Server
70 mail *mailyak.MailYak
71 mailLk *sync.Mutex
72 echo *echo.Echo
73 db *db.DB
74 plcClient *plc.Client
75 logger *slog.Logger
76 config *config
77 privateKey *ecdsa.PrivateKey
78 repoman *RepoMan
79 oauthProvider *provider.Provider
80 evtman *events.EventManager
81 passport *identity.Passport
82 fallbackProxy string
83
84 lastRequestCrawl time.Time
85 requestCrawlMu sync.Mutex
86
87 dbName string
88 dbType string
89 s3Config *S3Config
90}
91
92type Args struct {
93 Logger *slog.Logger
94
95 Addr string
96 DbName string
97 DbType string
98 DatabaseURL string
99 Version string
100 Did string
101 Hostname string
102 RotationKeyPath string
103 JwkPath string
104 ContactEmail string
105 Relays []string
106 AdminPassword string
107 RequireInvite bool
108
109 SmtpUser string
110 SmtpPass string
111 SmtpHost string
112 SmtpPort string
113 SmtpEmail string
114 SmtpName string
115
116 S3Config *S3Config
117
118 SessionSecret string
119 SessionCookieKey string
120
121 BlockstoreVariant BlockstoreVariant
122 FallbackProxy string
123}
124
125type config struct {
126 Version string
127 Did string
128 Hostname string
129 ContactEmail string
130 EnforcePeering bool
131 Relays []string
132 AdminPassword string
133 RequireInvite bool
134 SmtpEmail string
135 SmtpName string
136 SessionCookieKey string
137 BlockstoreVariant BlockstoreVariant
138 FallbackProxy string
139}
140
141type CustomValidator struct {
142 validator *validator.Validate
143}
144
145type ValidationError struct {
146 error
147 Field string
148 Tag string
149}
150
151func (cv *CustomValidator) Validate(i any) error {
152 if err := cv.validator.Struct(i); err != nil {
153 var validateErrors validator.ValidationErrors
154 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
155 first := validateErrors[0]
156 return ValidationError{
157 error: err,
158 Field: first.Field(),
159 Tag: first.Tag(),
160 }
161 }
162
163 return err
164 }
165
166 return nil
167}
168
169//go:embed templates/*
170var templateFS embed.FS
171
172//go:embed static/*
173var staticFS embed.FS
174
// TemplateRenderer is the echo renderer built by loadTemplates.
//
// When isDev is true, Render re-parses the templates matched by templatePath
// on every call. Otherwise it executes the pre-parsed set in templates.
type TemplateRenderer struct {
	templates    *template.Template
	isDev        bool
	templatePath string // glob of on-disk templates; only consulted when isDev
}
180
181func (s *Server) loadTemplates() {
182 absPath, _ := filepath.Abs("server/templates/*.html")
183 if s.config.Version == "dev" {
184 tmpl := template.Must(template.ParseGlob(absPath))
185 s.echo.Renderer = &TemplateRenderer{
186 templates: tmpl,
187 isDev: true,
188 templatePath: absPath,
189 }
190 } else {
191 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
192 s.echo.Renderer = &TemplateRenderer{
193 templates: tmpl,
194 isDev: false,
195 }
196 }
197}
198
199func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
200 if t.isDev {
201 tmpl, err := template.ParseGlob(t.templatePath)
202 if err != nil {
203 return err
204 }
205 t.templates = tmpl
206 }
207
208 if viewContext, isMap := data.(map[string]any); isMap {
209 viewContext["reverse"] = c.Echo().Reverse
210 }
211
212 return t.templates.ExecuteTemplate(w, name, data)
213}
214
215func New(args *Args) (*Server, error) {
216 if args.Logger == nil {
217 args.Logger = slog.Default()
218 }
219
220 logger := args.Logger.With("name", "New")
221
222 if args.Addr == "" {
223 return nil, fmt.Errorf("addr must be set")
224 }
225
226 if args.DbName == "" {
227 return nil, fmt.Errorf("db name must be set")
228 }
229
230 if args.Did == "" {
231 return nil, fmt.Errorf("cocoon did must be set")
232 }
233
234 if args.ContactEmail == "" {
235 return nil, fmt.Errorf("cocoon contact email is required")
236 }
237
238 if _, err := syntax.ParseDID(args.Did); err != nil {
239 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
240 }
241
242 if args.Hostname == "" {
243 return nil, fmt.Errorf("cocoon hostname must be set")
244 }
245
246 if args.AdminPassword == "" {
247 return nil, fmt.Errorf("admin password must be set")
248 }
249
250 if args.SessionSecret == "" {
251 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
252 }
253
254 e := echo.New()
255
256 e.Pre(middleware.RemoveTrailingSlash())
257 e.Pre(slogecho.New(args.Logger.With("component", "slogecho")))
258 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
259 e.Use(echoprometheus.NewMiddleware("cocoon"))
260 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
261 AllowOrigins: []string{"*"},
262 AllowHeaders: []string{"*"},
263 AllowMethods: []string{"*"},
264 AllowCredentials: true,
265 MaxAge: 100_000_000,
266 }))
267
268 vdtor := validator.New()
269 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
270 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
271 return false
272 }
273 return true
274 })
275 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
276 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
277 return false
278 }
279 return true
280 })
281 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
282 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
283 return false
284 }
285 return true
286 })
287 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
288 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
289 return false
290 }
291 return true
292 })
293
294 e.Validator = &CustomValidator{validator: vdtor}
295
296 httpd := &http.Server{
297 Addr: args.Addr,
298 Handler: e,
299 // shitty defaults but okay for now, needed for import repo
300 ReadTimeout: 5 * time.Minute,
301 WriteTimeout: 5 * time.Minute,
302 IdleTimeout: 5 * time.Minute,
303 }
304
305 dbType := args.DbType
306 if dbType == "" {
307 dbType = "sqlite"
308 }
309
310 var gdb *gorm.DB
311 var err error
312 switch dbType {
313 case "postgres":
314 if args.DatabaseURL == "" {
315 return nil, fmt.Errorf("database-url must be set when using postgres")
316 }
317 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{})
318 if err != nil {
319 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
320 }
321 logger.Info("connected to PostgreSQL database")
322 default:
323 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
324 if err != nil {
325 return nil, fmt.Errorf("failed to open sqlite database: %w", err)
326 }
327 gdb.Exec("PRAGMA journal_mode=WAL")
328 gdb.Exec("PRAGMA synchronous=NORMAL")
329
330 logger.Info("connected to SQLite database", "path", args.DbName)
331 }
332 dbw := db.NewDB(gdb)
333
334 rkbytes, err := os.ReadFile(args.RotationKeyPath)
335 if err != nil {
336 return nil, err
337 }
338
339 h := util.RobustHTTPClient()
340
341 plcClient, err := plc.NewClient(&plc.ClientArgs{
342 H: h,
343 Service: "https://plc.directory",
344 PdsHostname: args.Hostname,
345 RotationKey: rkbytes,
346 })
347 if err != nil {
348 return nil, err
349 }
350
351 jwkbytes, err := os.ReadFile(args.JwkPath)
352 if err != nil {
353 return nil, err
354 }
355
356 key, err := helpers.ParseJWKFromBytes(jwkbytes)
357 if err != nil {
358 return nil, err
359 }
360
361 var pkey ecdsa.PrivateKey
362 if err := key.Raw(&pkey); err != nil {
363 return nil, err
364 }
365
366 oauthCli := &http.Client{
367 Timeout: 10 * time.Second,
368 }
369
370 var nonceSecret []byte
371 maybeSecret, err := os.ReadFile("nonce.secret")
372 if err != nil && !os.IsNotExist(err) {
373 logger.Error("error attempting to read nonce secret", "error", err)
374 } else {
375 nonceSecret = maybeSecret
376 }
377
378 s := &Server{
379 http: h,
380 httpd: httpd,
381 echo: e,
382 logger: args.Logger,
383 db: dbw,
384 plcClient: plcClient,
385 privateKey: &pkey,
386 config: &config{
387 Version: args.Version,
388 Did: args.Did,
389 Hostname: args.Hostname,
390 ContactEmail: args.ContactEmail,
391 EnforcePeering: false,
392 Relays: args.Relays,
393 AdminPassword: args.AdminPassword,
394 RequireInvite: args.RequireInvite,
395 SmtpName: args.SmtpName,
396 SmtpEmail: args.SmtpEmail,
397 SessionCookieKey: args.SessionCookieKey,
398 BlockstoreVariant: args.BlockstoreVariant,
399 FallbackProxy: args.FallbackProxy,
400 },
401 evtman: events.NewEventManager(events.NewMemPersister()),
402 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
403
404 dbName: args.DbName,
405 dbType: dbType,
406 s3Config: args.S3Config,
407
408 oauthProvider: provider.NewProvider(provider.Args{
409 Hostname: args.Hostname,
410 ClientManagerArgs: client.ManagerArgs{
411 Cli: oauthCli,
412 Logger: args.Logger.With("component", "oauth-client-manager"),
413 },
414 DpopManagerArgs: dpop.ManagerArgs{
415 NonceSecret: nonceSecret,
416 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
417 OnNonceSecretCreated: func(newNonce []byte) {
418 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
419 logger.Error("error writing new nonce secret", "error", err)
420 }
421 },
422 Logger: args.Logger.With("component", "dpop-manager"),
423 Hostname: args.Hostname,
424 },
425 }),
426 }
427
428 s.loadTemplates()
429
430 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
431
432 // TODO: should validate these args
433 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
434 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
435 } else {
436 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
437 mail.From(s.config.SmtpEmail)
438 mail.FromName(s.config.SmtpName)
439
440 s.mail = mail
441 s.mailLk = &sync.Mutex{}
442 }
443
444 return s, nil
445}
446
447func (s *Server) addRoutes() {
448 // static
449 if s.config.Version == "dev" {
450 s.echo.Static("/static", "server/static")
451 } else {
452 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
453 }
454
455 // random stuff
456 s.echo.GET("/", s.handleRoot)
457 s.echo.GET("/xrpc/_health", s.handleHealth)
458 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
459 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid)
460 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
461 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
462 s.echo.GET("/robots.txt", s.handleRobots)
463
464 // public
465 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
466 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
467 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
468 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
469 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)
470
471 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
472 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
473 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
474 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
475 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
476 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
477 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
478 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
479 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
480 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
481 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
482 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
483
484 // labels
485 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)
486
487 // account
488 s.echo.GET("/account", s.handleAccount)
489 s.echo.POST("/account/revoke", s.handleAccountRevoke)
490 s.echo.GET("/account/signin", s.handleAccountSigninGet)
491 s.echo.POST("/account/signin", s.handleAccountSigninPost)
492 s.echo.GET("/account/signout", s.handleAccountSignout)
493
494 // oauth account
495 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
496 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
497 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
498
499 // oauth authorization
500 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
501 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
502
503 // authed
504 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
505 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
506 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
507 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
508 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
509 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
510 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
511 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
512 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
513 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
514 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
515 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
516 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
517 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
518 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
519 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
520 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
521 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
522 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
523 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)
524
525 // repo
526 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
527 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
528 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
529 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
530 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
531 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
532 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
533
534 // stupid silly endpoints
535 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
536 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
537 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
538 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
539 // admin routes
540 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
541 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
542
543 // are there any routes that we should be allowing without auth? i dont think so but idk
544 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
545 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
546}
547
548func (s *Server) Serve(ctx context.Context) error {
549 logger := s.logger.With("name", "Serve")
550
551 s.addRoutes()
552
553 logger.Info("migrating...")
554
555 s.db.AutoMigrate(
556 &models.Actor{},
557 &models.Repo{},
558 &models.InviteCode{},
559 &models.Token{},
560 &models.RefreshToken{},
561 &models.Block{},
562 &models.Record{},
563 &models.Blob{},
564 &models.BlobPart{},
565 &models.ReservedKey{},
566 &provider.OauthToken{},
567 &provider.OauthAuthorizationRequest{},
568 )
569
570 logger.Info("starting cocoon")
571
572 go func() {
573 if err := s.httpd.ListenAndServe(); err != nil {
574 panic(err)
575 }
576 }()
577
578 go s.backupRoutine()
579
580 go func() {
581 if err := s.requestCrawl(ctx); err != nil {
582 logger.Error("error requesting crawls", "err", err)
583 }
584 }()
585
586 <-ctx.Done()
587
588 fmt.Println("shut down")
589
590 return nil
591}
592
593func (s *Server) requestCrawl(ctx context.Context) error {
594 logger := s.logger.With("component", "request-crawl")
595 s.requestCrawlMu.Lock()
596 defer s.requestCrawlMu.Unlock()
597
598 logger.Info("requesting crawl with configured relays")
599
600 if time.Since(s.lastRequestCrawl) <= 1*time.Minute {
601 return fmt.Errorf("a crawl request has already been made within the last minute")
602 }
603
604 for _, relay := range s.config.Relays {
605 logger := logger.With("relay", relay)
606 logger.Info("requesting crawl from relay")
607 cli := xrpc.Client{Host: relay}
608 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
609 Hostname: s.config.Hostname,
610 }); err != nil {
611 logger.Error("error requesting crawl", "err", err)
612 } else {
613 logger.Info("crawl requested successfully")
614 }
615 }
616
617 s.lastRequestCrawl = time.Now()
618
619 return nil
620}
621
622func (s *Server) doBackup() {
623 logger := s.logger.With("name", "doBackup")
624
625 if s.dbType == "postgres" {
626 logger.Info("skipping S3 backup - PostgreSQL backups should be handled externally (pg_dump, managed database backups, etc.)")
627 return
628 }
629
630 start := time.Now()
631
632 logger.Info("beginning backup to s3...")
633
634 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano))
635 defer os.Remove(tmpFile)
636
637 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil {
638 logger.Error("error creating tmp backup file", "err", err)
639 return
640 }
641
642 backupData, err := os.ReadFile(tmpFile)
643 if err != nil {
644 logger.Error("error reading tmp backup file", "err", err)
645 return
646 }
647
648 logger.Info("sending to s3...")
649
650 currTime := time.Now().Format("2006-01-02_15-04-05")
651 key := "cocoon-backup-" + currTime + ".db"
652
653 config := &aws.Config{
654 Region: aws.String(s.s3Config.Region),
655 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
656 }
657
658 if s.s3Config.Endpoint != "" {
659 config.Endpoint = aws.String(s.s3Config.Endpoint)
660 config.S3ForcePathStyle = aws.Bool(true)
661 }
662
663 sess, err := session.NewSession(config)
664 if err != nil {
665 logger.Error("error creating s3 session", "err", err)
666 return
667 }
668
669 svc := s3.New(sess)
670
671 if _, err := svc.PutObject(&s3.PutObjectInput{
672 Bucket: aws.String(s.s3Config.Bucket),
673 Key: aws.String(key),
674 Body: bytes.NewReader(backupData),
675 }); err != nil {
676 logger.Error("error uploading file to s3", "err", err)
677 return
678 }
679
680 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds())
681
682 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644)
683}
684
685func (s *Server) backupRoutine() {
686 logger := s.logger.With("name", "backupRoutine")
687
688 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
689 return
690 }
691
692 if s.s3Config.Region == "" {
693 logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
694 return
695 }
696
697 if s.s3Config.Bucket == "" {
698 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
699 return
700 }
701
702 if s.s3Config.AccessKey == "" {
703 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
704 return
705 }
706
707 if s.s3Config.SecretKey == "" {
708 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
709 return
710 }
711
712 shouldBackupNow := false
713 lastBackupStr, err := os.ReadFile("last-backup.txt")
714 if err != nil {
715 shouldBackupNow = true
716 } else {
717 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr))
718 if err != nil {
719 shouldBackupNow = true
720 } else if time.Since(lastBackup).Seconds() > 3600 {
721 shouldBackupNow = true
722 }
723 }
724
725 if shouldBackupNow {
726 go s.doBackup()
727 }
728
729 ticker := time.NewTicker(time.Hour)
730 for range ticker.C {
731 go s.doBackup()
732 }
733}
734
735func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
736 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
737 return err
738 }
739
740 return nil
741}