Forked from willdot.net/cocoon.
A fork of the Cocoon PDS that is being made more distributed.
1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "embed"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "net/smtp"
14 "os"
15 "path/filepath"
16 "sync"
17 "text/template"
18 "time"
19
20 "github.com/aws/aws-sdk-go/aws"
21 "github.com/aws/aws-sdk-go/aws/credentials"
22 "github.com/aws/aws-sdk-go/aws/session"
23 "github.com/aws/aws-sdk-go/service/s3"
24 "github.com/bluesky-social/indigo/api/atproto"
25 "github.com/bluesky-social/indigo/atproto/syntax"
26 "github.com/bluesky-social/indigo/events"
27 "github.com/bluesky-social/indigo/util"
28 "github.com/bluesky-social/indigo/xrpc"
29 "github.com/domodwyer/mailyak/v3"
30 "github.com/go-playground/validator"
31 "github.com/gorilla/sessions"
32 "github.com/haileyok/cocoon/identity"
33 "github.com/haileyok/cocoon/internal/db"
34 "github.com/haileyok/cocoon/internal/helpers"
35 "github.com/haileyok/cocoon/models"
36 "github.com/haileyok/cocoon/oauth/client"
37 "github.com/haileyok/cocoon/oauth/constants"
38 "github.com/haileyok/cocoon/oauth/dpop"
39 "github.com/haileyok/cocoon/oauth/provider"
40 "github.com/haileyok/cocoon/plc"
41 "github.com/ipfs/go-cid"
42 "github.com/labstack/echo-contrib/echoprometheus"
43 echo_session "github.com/labstack/echo-contrib/session"
44 "github.com/labstack/echo/v4"
45 "github.com/labstack/echo/v4/middleware"
46 slogecho "github.com/samber/slog-echo"
47 "gorm.io/driver/postgres"
48 "gorm.io/driver/sqlite"
49 "gorm.io/gorm"
50)
51
const (
	// AccountSessionMaxAge is the maximum age of an account session: 30 days
	// (the previous comment incorrectly said "one week").
	AccountSessionMaxAge = 30 * 24 * time.Hour
)
55
// S3Config describes the S3 (or S3-compatible) storage used by the server.
// Backups read BackupsEnabled, Region, Bucket, AccessKey, SecretKey and
// Endpoint; BlobstoreEnabled and CDNUrl are presumably consumed by the blob
// store elsewhere in the package.
type S3Config struct {
	BackupsEnabled   bool   // upload periodic SQLite snapshots
	BlobstoreEnabled bool   // NOTE(review): not read in this file; presumably stores blobs in S3
	Endpoint         string // custom endpoint; when set, path-style addressing is used
	Region           string
	Bucket           string
	AccessKey        string
	SecretKey        string
	CDNUrl           string // NOTE(review): not read in this file
}
66
67type Server struct {
68 http *http.Client
69 httpd *http.Server
70 mail *mailyak.MailYak
71 mailLk *sync.Mutex
72 echo *echo.Echo
73 db *db.DB
74 plcClient *plc.Client
75 logger *slog.Logger
76 config *config
77 privateKey *ecdsa.PrivateKey
78 repoman *RepoMan
79 oauthProvider *provider.Provider
80 evtman *events.EventManager
81 passport *identity.Passport
82 fallbackProxy string
83
84 lastRequestCrawl time.Time
85 requestCrawlMu sync.Mutex
86
87 dbName string
88 dbType string
89 s3Config *S3Config
90}
91
92type Args struct {
93 Logger *slog.Logger
94
95 LogLevel slog.Level
96 Addr string
97 DbName string
98 DbType string
99 DatabaseURL string
100 Version string
101 Did string
102 Hostname string
103 RotationKeyPath string
104 JwkPath string
105 ContactEmail string
106 Relays []string
107 AdminPassword string
108 RequireInvite bool
109
110 SmtpUser string
111 SmtpPass string
112 SmtpHost string
113 SmtpPort string
114 SmtpEmail string
115 SmtpName string
116
117 S3Config *S3Config
118
119 SessionSecret string
120 SessionCookieKey string
121
122 BlockstoreVariant BlockstoreVariant
123 FallbackProxy string
124}
125
126type config struct {
127 LogLevel slog.Level
128 Version string
129 Did string
130 Hostname string
131 ContactEmail string
132 EnforcePeering bool
133 Relays []string
134 AdminPassword string
135 RequireInvite bool
136 SmtpEmail string
137 SmtpName string
138 SessionCookieKey string
139 BlockstoreVariant BlockstoreVariant
140 FallbackProxy string
141}
142
143type CustomValidator struct {
144 validator *validator.Validate
145}
146
147type ValidationError struct {
148 error
149 Field string
150 Tag string
151}
152
153func (cv *CustomValidator) Validate(i any) error {
154 if err := cv.validator.Struct(i); err != nil {
155 var validateErrors validator.ValidationErrors
156 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
157 first := validateErrors[0]
158 return ValidationError{
159 error: err,
160 Field: first.Field(),
161 Tag: first.Tag(),
162 }
163 }
164
165 return err
166 }
167
168 return nil
169}
170
// templateFS embeds the page templates; loadTemplates parses them in
// non-dev builds.
//
//go:embed templates/*
var templateFS embed.FS

// staticFS embeds the static assets that addRoutes serves under /static in
// non-dev builds.
//
//go:embed static/*
var staticFS embed.FS
176
177type TemplateRenderer struct {
178 templates *template.Template
179 isDev bool
180 templatePath string
181}
182
183func (s *Server) loadTemplates() {
184 absPath, _ := filepath.Abs("server/templates/*.html")
185 if s.config.Version == "dev" {
186 tmpl := template.Must(template.ParseGlob(absPath))
187 s.echo.Renderer = &TemplateRenderer{
188 templates: tmpl,
189 isDev: true,
190 templatePath: absPath,
191 }
192 } else {
193 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
194 s.echo.Renderer = &TemplateRenderer{
195 templates: tmpl,
196 isDev: false,
197 }
198 }
199}
200
201func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
202 if t.isDev {
203 tmpl, err := template.ParseGlob(t.templatePath)
204 if err != nil {
205 return err
206 }
207 t.templates = tmpl
208 }
209
210 if viewContext, isMap := data.(map[string]any); isMap {
211 viewContext["reverse"] = c.Echo().Reverse
212 }
213
214 return t.templates.ExecuteTemplate(w, name, data)
215}
216
217type filteredHandler struct {
218 level slog.Level
219 handler slog.Handler
220}
221
222func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool {
223 return level >= h.level && h.handler.Enabled(ctx, level)
224}
225
226func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error {
227 return h.handler.Handle(ctx, r)
228}
229
230func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
231 return &filteredHandler{level: h.level, handler: h.handler.WithAttrs(attrs)}
232}
233
234func (h *filteredHandler) WithGroup(name string) slog.Handler {
235 return &filteredHandler{level: h.level, handler: h.handler.WithGroup(name)}
236}
237
238func New(args *Args) (*Server, error) {
239 if args.Logger == nil {
240 args.Logger = slog.Default()
241 }
242
243 if args.LogLevel != 0 {
244 args.Logger = slog.New(&filteredHandler{
245 level: args.LogLevel,
246 handler: args.Logger.Handler(),
247 })
248 }
249
250 logger := args.Logger.With("name", "New")
251
252 if args.Addr == "" {
253 return nil, fmt.Errorf("addr must be set")
254 }
255
256 if args.DbName == "" {
257 return nil, fmt.Errorf("db name must be set")
258 }
259
260 if args.Did == "" {
261 return nil, fmt.Errorf("cocoon did must be set")
262 }
263
264 if args.ContactEmail == "" {
265 return nil, fmt.Errorf("cocoon contact email is required")
266 }
267
268 if _, err := syntax.ParseDID(args.Did); err != nil {
269 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
270 }
271
272 if args.Hostname == "" {
273 return nil, fmt.Errorf("cocoon hostname must be set")
274 }
275
276 if args.AdminPassword == "" {
277 return nil, fmt.Errorf("admin password must be set")
278 }
279
280 if args.SessionSecret == "" {
281 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
282 }
283
284 e := echo.New()
285
286 e.Pre(middleware.RemoveTrailingSlash())
287 e.Pre(slogecho.New(args.Logger.With("component", "slogecho")))
288 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
289 e.Use(echoprometheus.NewMiddleware("cocoon"))
290 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
291 AllowOrigins: []string{"*"},
292 AllowHeaders: []string{"*"},
293 AllowMethods: []string{"*"},
294 AllowCredentials: true,
295 MaxAge: 100_000_000,
296 }))
297
298 vdtor := validator.New()
299 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
300 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
301 return false
302 }
303 return true
304 })
305 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
306 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
307 return false
308 }
309 return true
310 })
311 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
312 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
313 return false
314 }
315 return true
316 })
317 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
318 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
319 return false
320 }
321 return true
322 })
323
324 e.Validator = &CustomValidator{validator: vdtor}
325
326 httpd := &http.Server{
327 Addr: args.Addr,
328 Handler: e,
329 // shitty defaults but okay for now, needed for import repo
330 ReadTimeout: 5 * time.Minute,
331 WriteTimeout: 5 * time.Minute,
332 IdleTimeout: 5 * time.Minute,
333 }
334
335 dbType := args.DbType
336 if dbType == "" {
337 dbType = "sqlite"
338 }
339
340 var gdb *gorm.DB
341 var err error
342 switch dbType {
343 case "postgres":
344 if args.DatabaseURL == "" {
345 return nil, fmt.Errorf("database-url must be set when using postgres")
346 }
347 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{})
348 if err != nil {
349 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
350 }
351 logger.Info("connected to PostgreSQL database")
352 default:
353 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
354 if err != nil {
355 return nil, fmt.Errorf("failed to open sqlite database: %w", err)
356 }
357 gdb.Exec("PRAGMA journal_mode=WAL")
358 gdb.Exec("PRAGMA synchronous=NORMAL")
359
360 logger.Info("connected to SQLite database", "path", args.DbName)
361 }
362 dbw := db.NewDB(gdb)
363
364 rkbytes, err := os.ReadFile(args.RotationKeyPath)
365 if err != nil {
366 return nil, err
367 }
368
369 h := util.RobustHTTPClient()
370
371 plcClient, err := plc.NewClient(&plc.ClientArgs{
372 H: h,
373 Service: "https://plc.directory",
374 PdsHostname: args.Hostname,
375 RotationKey: rkbytes,
376 })
377 if err != nil {
378 return nil, err
379 }
380
381 jwkbytes, err := os.ReadFile(args.JwkPath)
382 if err != nil {
383 return nil, err
384 }
385
386 key, err := helpers.ParseJWKFromBytes(jwkbytes)
387 if err != nil {
388 return nil, err
389 }
390
391 var pkey ecdsa.PrivateKey
392 if err := key.Raw(&pkey); err != nil {
393 return nil, err
394 }
395
396 oauthCli := &http.Client{
397 Timeout: 10 * time.Second,
398 }
399
400 var nonceSecret []byte
401 maybeSecret, err := os.ReadFile("nonce.secret")
402 if err != nil && !os.IsNotExist(err) {
403 logger.Error("error attempting to read nonce secret", "error", err)
404 } else {
405 nonceSecret = maybeSecret
406 }
407
408 evtPersister, err := NewDbPersister(gdb, 72*time.Hour)
409 if err != nil {
410 return nil, fmt.Errorf("failed to create event persister: %w", err)
411 }
412
413 s := &Server{
414 http: h,
415 httpd: httpd,
416 echo: e,
417 logger: args.Logger,
418 db: dbw,
419 plcClient: plcClient,
420 privateKey: &pkey,
421 config: &config{
422 LogLevel: args.LogLevel,
423 Version: args.Version,
424 Did: args.Did,
425 Hostname: args.Hostname,
426 ContactEmail: args.ContactEmail,
427 EnforcePeering: false,
428 Relays: args.Relays,
429 AdminPassword: args.AdminPassword,
430 RequireInvite: args.RequireInvite,
431 SmtpName: args.SmtpName,
432 SmtpEmail: args.SmtpEmail,
433 SessionCookieKey: args.SessionCookieKey,
434 BlockstoreVariant: args.BlockstoreVariant,
435 FallbackProxy: args.FallbackProxy,
436 },
437 evtman: events.NewEventManager(evtPersister),
438 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
439
440 dbName: args.DbName,
441 dbType: dbType,
442 s3Config: args.S3Config,
443
444 oauthProvider: provider.NewProvider(provider.Args{
445 Hostname: args.Hostname,
446 ClientManagerArgs: client.ManagerArgs{
447 Cli: oauthCli,
448 Logger: args.Logger.With("component", "oauth-client-manager"),
449 },
450 DpopManagerArgs: dpop.ManagerArgs{
451 NonceSecret: nonceSecret,
452 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
453 OnNonceSecretCreated: func(newNonce []byte) {
454 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
455 logger.Error("error writing new nonce secret", "error", err)
456 }
457 },
458 Logger: args.Logger.With("component", "dpop-manager"),
459 Hostname: args.Hostname,
460 },
461 }),
462 }
463
464 s.loadTemplates()
465
466 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
467
468 // TODO: should validate these args
469 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
470 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
471 } else {
472 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
473 mail.From(s.config.SmtpEmail)
474 mail.FromName(s.config.SmtpName)
475
476 s.mail = mail
477 s.mailLk = &sync.Mutex{}
478 }
479
480 return s, nil
481}
482
483func (s *Server) addRoutes() {
484 // static
485 if s.config.Version == "dev" {
486 s.echo.Static("/static", "server/static")
487 } else {
488 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
489 }
490
491 // random stuff
492 s.echo.GET("/", s.handleRoot)
493 s.echo.GET("/xrpc/_health", s.handleHealth)
494 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
495 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid)
496 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
497 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
498 s.echo.GET("/robots.txt", s.handleRobots)
499
500 // public
501 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
502 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
503 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
504 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
505 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)
506
507 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
508 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
509 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
510 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
511 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
512 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
513 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
514 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
515 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
516 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
517 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
518 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
519
520 // labels
521 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)
522
523 // account
524 s.echo.GET("/account", s.handleAccount)
525 s.echo.POST("/account/revoke", s.handleAccountRevoke)
526 s.echo.GET("/account/signin", s.handleAccountSigninGet)
527 s.echo.POST("/account/signin", s.handleAccountSigninPost)
528 s.echo.GET("/account/signout", s.handleAccountSignout)
529
530 // oauth account
531 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
532 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
533 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
534
535 // oauth authorization
536 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
537 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
538
539 // authed
540 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
541 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
542 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
543 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
544 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
545 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
546 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
547 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
548 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
549 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
550 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
551 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
552 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
553 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
554 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
555 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
556 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
557 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
558 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
559 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)
560
561 // repo
562 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
563 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
564 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
565 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
566 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
567 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
568 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
569
570 // stupid silly endpoints
571 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
572 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
573 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
574 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
575 // admin routes
576 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
577 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
578
579 // are there any routes that we should be allowing without auth? i dont think so but idk
580 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
581 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
582}
583
584func (s *Server) Serve(ctx context.Context) error {
585 logger := s.logger.With("name", "Serve")
586
587 s.addRoutes()
588
589 logger.Info("migrating...")
590
591 s.db.AutoMigrate(
592 &models.Actor{},
593 &models.Repo{},
594 &models.InviteCode{},
595 &models.Token{},
596 &models.RefreshToken{},
597 &models.Block{},
598 &models.Record{},
599 &models.Blob{},
600 &models.BlobPart{},
601 &models.ReservedKey{},
602 &provider.OauthToken{},
603 &provider.OauthAuthorizationRequest{},
604 )
605
606 logger.Info("starting cocoon")
607
608 go func() {
609 if err := s.httpd.ListenAndServe(); err != nil {
610 panic(err)
611 }
612 }()
613
614 go s.backupRoutine()
615
616 go func() {
617 if err := s.requestCrawl(ctx); err != nil {
618 logger.Error("error requesting crawls", "err", err)
619 }
620 }()
621
622 <-ctx.Done()
623
624 fmt.Println("shut down")
625
626 return nil
627}
628
629func (s *Server) requestCrawl(ctx context.Context) error {
630 logger := s.logger.With("component", "request-crawl")
631 s.requestCrawlMu.Lock()
632 defer s.requestCrawlMu.Unlock()
633
634 logger.Info("requesting crawl with configured relays")
635
636 if time.Since(s.lastRequestCrawl) <= 1*time.Minute {
637 return fmt.Errorf("a crawl request has already been made within the last minute")
638 }
639
640 for _, relay := range s.config.Relays {
641 logger := logger.With("relay", relay)
642 logger.Info("requesting crawl from relay")
643 cli := xrpc.Client{Host: relay}
644 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
645 Hostname: s.config.Hostname,
646 }); err != nil {
647 logger.Error("error requesting crawl", "err", err)
648 } else {
649 logger.Info("crawl requested successfully")
650 }
651 }
652
653 s.lastRequestCrawl = time.Now()
654
655 return nil
656}
657
658func (s *Server) doBackup() {
659 logger := s.logger.With("name", "doBackup")
660
661 if s.dbType == "postgres" {
662 logger.Info("skipping S3 backup - PostgreSQL backups should be handled externally (pg_dump, managed database backups, etc.)")
663 return
664 }
665
666 start := time.Now()
667
668 logger.Info("beginning backup to s3...")
669
670 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano))
671 defer os.Remove(tmpFile)
672
673 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil {
674 logger.Error("error creating tmp backup file", "err", err)
675 return
676 }
677
678 backupData, err := os.ReadFile(tmpFile)
679 if err != nil {
680 logger.Error("error reading tmp backup file", "err", err)
681 return
682 }
683
684 logger.Info("sending to s3...")
685
686 currTime := time.Now().Format("2006-01-02_15-04-05")
687 key := "cocoon-backup-" + currTime + ".db"
688
689 config := &aws.Config{
690 Region: aws.String(s.s3Config.Region),
691 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
692 }
693
694 if s.s3Config.Endpoint != "" {
695 config.Endpoint = aws.String(s.s3Config.Endpoint)
696 config.S3ForcePathStyle = aws.Bool(true)
697 }
698
699 sess, err := session.NewSession(config)
700 if err != nil {
701 logger.Error("error creating s3 session", "err", err)
702 return
703 }
704
705 svc := s3.New(sess)
706
707 if _, err := svc.PutObject(&s3.PutObjectInput{
708 Bucket: aws.String(s.s3Config.Bucket),
709 Key: aws.String(key),
710 Body: bytes.NewReader(backupData),
711 }); err != nil {
712 logger.Error("error uploading file to s3", "err", err)
713 return
714 }
715
716 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds())
717
718 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644)
719}
720
721func (s *Server) backupRoutine() {
722 logger := s.logger.With("name", "backupRoutine")
723
724 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
725 return
726 }
727
728 if s.s3Config.Region == "" {
729 logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
730 return
731 }
732
733 if s.s3Config.Bucket == "" {
734 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
735 return
736 }
737
738 if s.s3Config.AccessKey == "" {
739 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
740 return
741 }
742
743 if s.s3Config.SecretKey == "" {
744 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
745 return
746 }
747
748 shouldBackupNow := false
749 lastBackupStr, err := os.ReadFile("last-backup.txt")
750 if err != nil {
751 shouldBackupNow = true
752 } else {
753 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr))
754 if err != nil {
755 shouldBackupNow = true
756 } else if time.Since(lastBackup).Seconds() > 3600 {
757 shouldBackupNow = true
758 }
759 }
760
761 if shouldBackupNow {
762 go s.doBackup()
763 }
764
765 ticker := time.NewTicker(time.Hour)
766 for range ticker.C {
767 go s.doBackup()
768 }
769}
770
771func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
772 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
773 return err
774 }
775
776 return nil
777}