Forked from willdot.net/cocoon.
A fork of the Cocoon PDS that is being made more distributed.
1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "embed"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "net/smtp"
14 "os"
15 "path/filepath"
16 "sync"
17 "text/template"
18 "time"
19
20 "github.com/aws/aws-sdk-go/aws"
21 "github.com/aws/aws-sdk-go/aws/credentials"
22 "github.com/aws/aws-sdk-go/aws/session"
23 "github.com/aws/aws-sdk-go/service/s3"
24 "github.com/bluesky-social/indigo/api/atproto"
25 "github.com/bluesky-social/indigo/atproto/syntax"
26 "github.com/bluesky-social/indigo/events"
27 "github.com/bluesky-social/indigo/util"
28 "github.com/bluesky-social/indigo/xrpc"
29 "github.com/domodwyer/mailyak/v3"
30 "github.com/go-playground/validator"
31 "github.com/gorilla/sessions"
32 "github.com/haileyok/cocoon/identity"
33 "github.com/haileyok/cocoon/internal/db"
34 "github.com/haileyok/cocoon/internal/helpers"
35 "github.com/haileyok/cocoon/models"
36 "github.com/haileyok/cocoon/oauth/client"
37 "github.com/haileyok/cocoon/oauth/constants"
38 "github.com/haileyok/cocoon/oauth/dpop"
39 "github.com/haileyok/cocoon/oauth/provider"
40 "github.com/haileyok/cocoon/plc"
41 "github.com/ipfs/go-cid"
42 "github.com/labstack/echo-contrib/echoprometheus"
43 echo_session "github.com/labstack/echo-contrib/session"
44 "github.com/labstack/echo/v4"
45 "github.com/labstack/echo/v4/middleware"
46 slogecho "github.com/samber/slog-echo"
47 "gorm.io/driver/postgres"
48 "gorm.io/driver/sqlite"
49 "gorm.io/gorm"
50)
51
const (
	// AccountSessionMaxAge is the maximum age of an account session: 30 days
	// (30 * 24h), not one week.
	AccountSessionMaxAge = 30 * 24 * time.Hour
)
55
// S3Config holds the S3 settings. BackupsEnabled turns on the periodic SQLite
// backup upload; BlobstoreEnabled and CDNUrl are not read in this file.
type S3Config struct {
	// Feature switches.
	BackupsEnabled   bool
	BlobstoreEnabled bool

	// Endpoint, when non-empty, replaces the default AWS endpoint and switches
	// backup uploads to path-style addressing (for S3-compatible stores).
	Endpoint, Region, Bucket string

	// Static credentials used for the backup upload.
	AccessKey, SecretKey string

	// CDNUrl: NOTE(review) presumably a public base URL for served blobs.
	CDNUrl string
}
66
67type Server struct {
68 http *http.Client
69 httpd *http.Server
70 mail *mailyak.MailYak
71 mailLk *sync.Mutex
72 echo *echo.Echo
73 db *db.DB
74 plcClient *plc.Client
75 logger *slog.Logger
76 config *config
77 privateKey *ecdsa.PrivateKey
78 repoman *RepoMan
79 oauthProvider *provider.Provider
80 evtman *events.EventManager
81 passport *identity.Passport
82 fallbackProxy string
83
84 lastRequestCrawl time.Time
85 requestCrawlMu sync.Mutex
86
87 dbName string
88 dbType string
89 s3Config *S3Config
90}
91
92type Args struct {
93 Logger *slog.Logger
94
95 LogLevel slog.Level
96 Addr string
97 DbName string
98 DbType string
99 DatabaseURL string
100 Version string
101 Did string
102 Hostname string
103 RotationKeyPath string
104 JwkPath string
105 ContactEmail string
106 Relays []string
107 AdminPassword string
108 RequireInvite bool
109
110 SmtpUser string
111 SmtpPass string
112 SmtpHost string
113 SmtpPort string
114 SmtpEmail string
115 SmtpName string
116
117 S3Config *S3Config
118
119 SessionSecret string
120 SessionCookieKey string
121
122 BlockstoreVariant BlockstoreVariant
123 FallbackProxy string
124}
125
126type config struct {
127 LogLevel slog.Level
128 Version string
129 Did string
130 Hostname string
131 ContactEmail string
132 EnforcePeering bool
133 Relays []string
134 AdminPassword string
135 RequireInvite bool
136 SmtpEmail string
137 SmtpName string
138 SessionCookieKey string
139 BlockstoreVariant BlockstoreVariant
140 FallbackProxy string
141}
142
143type CustomValidator struct {
144 validator *validator.Validate
145}
146
// ValidationError is returned by CustomValidator.Validate when struct
// validation fails. It embeds the full validator error (so Error() is
// unchanged) and exposes the field name and tag of the first violation.
type ValidationError struct {
	error
	Field, Tag string
}
152
153func (cv *CustomValidator) Validate(i any) error {
154 if err := cv.validator.Struct(i); err != nil {
155 var validateErrors validator.ValidationErrors
156 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
157 first := validateErrors[0]
158 return ValidationError{
159 error: err,
160 Field: first.Field(),
161 Tag: first.Tag(),
162 }
163 }
164
165 return err
166 }
167
168 return nil
169}
170
171//go:embed templates/*
172var templateFS embed.FS
173
174//go:embed static/*
175var staticFS embed.FS
176
// TemplateRenderer is the echo renderer for the server's HTML pages. When
// isDev is set, Render re-parses the templates matched by the glob
// templatePath on each call instead of relying only on templates.
type TemplateRenderer struct {
	templates    *template.Template
	isDev        bool
	templatePath string
}
182
183func (s *Server) loadTemplates() {
184 absPath, _ := filepath.Abs("server/templates/*.html")
185 if s.config.Version == "dev" {
186 tmpl := template.Must(template.ParseGlob(absPath))
187 s.echo.Renderer = &TemplateRenderer{
188 templates: tmpl,
189 isDev: true,
190 templatePath: absPath,
191 }
192 } else {
193 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
194 s.echo.Renderer = &TemplateRenderer{
195 templates: tmpl,
196 isDev: false,
197 }
198 }
199}
200
201func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
202 if t.isDev {
203 tmpl, err := template.ParseGlob(t.templatePath)
204 if err != nil {
205 return err
206 }
207 t.templates = tmpl
208 }
209
210 if viewContext, isMap := data.(map[string]any); isMap {
211 viewContext["reverse"] = c.Echo().Reverse
212 }
213
214 return t.templates.ExecuteTemplate(w, name, data)
215}
216
// filteredHandler is an slog.Handler decorator that drops records below a
// minimum level. It applies this on top of whatever filtering the wrapped
// handler already does.
type filteredHandler struct {
	level   slog.Level
	handler slog.Handler
}

// Enabled reports true only when level clears this handler's minimum and the
// wrapped handler also accepts it.
func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool {
	if level < h.level {
		return false
	}
	return h.handler.Enabled(ctx, level)
}

// Handle passes the record through unchanged; level filtering is done in
// Enabled, which slog consults before calling Handle.
func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error {
	return h.handler.Handle(ctx, r)
}

// WithAttrs returns a filteredHandler with the same minimum level around the
// wrapped handler extended with attrs.
func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	return h.rewrap(h.handler.WithAttrs(attrs))
}

// WithGroup returns a filteredHandler with the same minimum level around the
// wrapped handler opened into group name.
func (h *filteredHandler) WithGroup(name string) slog.Handler {
	return h.rewrap(h.handler.WithGroup(name))
}

// rewrap builds a new filteredHandler that keeps h's level but delegates to inner.
func (h *filteredHandler) rewrap(inner slog.Handler) slog.Handler {
	return &filteredHandler{level: h.level, handler: inner}
}
237
238func New(args *Args) (*Server, error) {
239 if args.Logger == nil {
240 args.Logger = slog.Default()
241 }
242
243 if args.LogLevel != 0 {
244 args.Logger = slog.New(&filteredHandler{
245 level: args.LogLevel,
246 handler: args.Logger.Handler(),
247 })
248 }
249
250 logger := args.Logger.With("name", "New")
251
252 if args.Addr == "" {
253 return nil, fmt.Errorf("addr must be set")
254 }
255
256 if args.DbName == "" {
257 return nil, fmt.Errorf("db name must be set")
258 }
259
260 if args.Did == "" {
261 return nil, fmt.Errorf("cocoon did must be set")
262 }
263
264 if args.ContactEmail == "" {
265 return nil, fmt.Errorf("cocoon contact email is required")
266 }
267
268 if _, err := syntax.ParseDID(args.Did); err != nil {
269 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
270 }
271
272 if args.Hostname == "" {
273 return nil, fmt.Errorf("cocoon hostname must be set")
274 }
275
276 if args.AdminPassword == "" {
277 return nil, fmt.Errorf("admin password must be set")
278 }
279
280 if args.SessionSecret == "" {
281 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
282 }
283
284 e := echo.New()
285
286 e.Pre(middleware.RemoveTrailingSlash())
287 e.Pre(slogecho.New(args.Logger.With("component", "slogecho")))
288 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
289 e.Use(echoprometheus.NewMiddleware("cocoon"))
290 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
291 AllowOrigins: []string{"*"},
292 AllowHeaders: []string{"*"},
293 AllowMethods: []string{"*"},
294 AllowCredentials: true,
295 MaxAge: 100_000_000,
296 }))
297
298 vdtor := validator.New()
299 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
300 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
301 return false
302 }
303 return true
304 })
305 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
306 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
307 return false
308 }
309 return true
310 })
311 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
312 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
313 return false
314 }
315 return true
316 })
317 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
318 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
319 return false
320 }
321 return true
322 })
323
324 e.Validator = &CustomValidator{validator: vdtor}
325
326 httpd := &http.Server{
327 Addr: args.Addr,
328 Handler: e,
329 // shitty defaults but okay for now, needed for import repo
330 ReadTimeout: 5 * time.Minute,
331 WriteTimeout: 5 * time.Minute,
332 IdleTimeout: 5 * time.Minute,
333 }
334
335 dbType := args.DbType
336 if dbType == "" {
337 dbType = "sqlite"
338 }
339
340 var gdb *gorm.DB
341 var err error
342 switch dbType {
343 case "postgres":
344 if args.DatabaseURL == "" {
345 return nil, fmt.Errorf("database-url must be set when using postgres")
346 }
347 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{})
348 if err != nil {
349 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
350 }
351 logger.Info("connected to PostgreSQL database")
352 default:
353 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
354 if err != nil {
355 return nil, fmt.Errorf("failed to open sqlite database: %w", err)
356 }
357 gdb.Exec("PRAGMA journal_mode=WAL")
358 gdb.Exec("PRAGMA synchronous=NORMAL")
359
360 logger.Info("connected to SQLite database", "path", args.DbName)
361 }
362 dbw := db.NewDB(gdb)
363
364 rkbytes, err := os.ReadFile(args.RotationKeyPath)
365 if err != nil {
366 return nil, err
367 }
368
369 h := util.RobustHTTPClient()
370
371 plcClient, err := plc.NewClient(&plc.ClientArgs{
372 H: h,
373 Service: "https://plc.directory",
374 PdsHostname: args.Hostname,
375 RotationKey: rkbytes,
376 })
377 if err != nil {
378 return nil, err
379 }
380
381 jwkbytes, err := os.ReadFile(args.JwkPath)
382 if err != nil {
383 return nil, err
384 }
385
386 key, err := helpers.ParseJWKFromBytes(jwkbytes)
387 if err != nil {
388 return nil, err
389 }
390
391 var pkey ecdsa.PrivateKey
392 if err := key.Raw(&pkey); err != nil {
393 return nil, err
394 }
395
396 oauthCli := &http.Client{
397 Timeout: 10 * time.Second,
398 }
399
400 var nonceSecret []byte
401 maybeSecret, err := os.ReadFile("nonce.secret")
402 if err != nil && !os.IsNotExist(err) {
403 logger.Error("error attempting to read nonce secret", "error", err)
404 } else {
405 nonceSecret = maybeSecret
406 }
407
408 evtPersister, err := NewDbPersister(gdb, 72*time.Hour)
409 if err != nil {
410 return nil, fmt.Errorf("failed to create event persister: %w", err)
411 }
412
413 s := &Server{
414 http: h,
415 httpd: httpd,
416 echo: e,
417 logger: args.Logger,
418 db: dbw,
419 plcClient: plcClient,
420 privateKey: &pkey,
421 config: &config{
422 LogLevel: args.LogLevel,
423 Version: args.Version,
424 Did: args.Did,
425 Hostname: args.Hostname,
426 ContactEmail: args.ContactEmail,
427 EnforcePeering: false,
428 Relays: args.Relays,
429 AdminPassword: args.AdminPassword,
430 RequireInvite: args.RequireInvite,
431 SmtpName: args.SmtpName,
432 SmtpEmail: args.SmtpEmail,
433 SessionCookieKey: args.SessionCookieKey,
434 BlockstoreVariant: args.BlockstoreVariant,
435 FallbackProxy: args.FallbackProxy,
436 },
437 evtman: events.NewEventManager(evtPersister),
438 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
439
440 dbName: args.DbName,
441 dbType: dbType,
442 s3Config: args.S3Config,
443
444 oauthProvider: provider.NewProvider(provider.Args{
445 Hostname: args.Hostname,
446 ClientManagerArgs: client.ManagerArgs{
447 Cli: oauthCli,
448 Logger: args.Logger.With("component", "oauth-client-manager"),
449 },
450 DpopManagerArgs: dpop.ManagerArgs{
451 NonceSecret: nonceSecret,
452 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
453 OnNonceSecretCreated: func(newNonce []byte) {
454 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
455 logger.Error("error writing new nonce secret", "error", err)
456 }
457 },
458 Logger: args.Logger.With("component", "dpop-manager"),
459 Hostname: args.Hostname,
460 },
461 }),
462 }
463
464 s.loadTemplates()
465
466 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
467
468 // TODO: should validate these args
469 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
470 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
471 } else {
472 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
473 mail.From(s.config.SmtpEmail)
474 mail.FromName(s.config.SmtpName)
475
476 s.mail = mail
477 s.mailLk = &sync.Mutex{}
478 }
479
480 return s, nil
481}
482
483func (s *Server) addRoutes() {
484 // static
485 if s.config.Version == "dev" {
486 s.echo.Static("/static", "server/static")
487 } else {
488 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
489 }
490
491 // random stuff
492 s.echo.GET("/", s.handleRoot)
493 s.echo.GET("/xrpc/_health", s.handleHealth)
494 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
495 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid)
496 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
497 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
498 s.echo.GET("/robots.txt", s.handleRobots)
499
500 // public
501 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
502 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
503 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
504 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
505 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)
506
507 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
508 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
509 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
510 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
511 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
512 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
513 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
514 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
515 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
516 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
517 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
518 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
519
520 // labels
521 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)
522
523 // account
524 s.echo.GET("/account", s.handleAccount)
525 s.echo.POST("/account/revoke", s.handleAccountRevoke)
526 s.echo.POST("/account/switch", s.handleAccountSwitchPost)
527 s.echo.GET("/account/signin", s.handleAccountSigninGet)
528 s.echo.POST("/account/signin", s.handleAccountSigninPost)
529 s.echo.GET("/account/signout", s.handleAccountSignout)
530
531 // oauth account
532 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
533 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
534 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
535
536 // oauth authorization
537 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
538 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
539
540 // authed
541 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
542 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
543 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
544 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
545 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
546 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
547 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
548 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
549 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
550 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
551 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
552 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
553 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
554 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
555 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
556 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
557 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
558 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
559 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
560 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)
561
562 // repo
563 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
564 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
565 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
566 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
567 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
568 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
569 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
570
571 // stupid silly endpoints
572 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
573 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
574 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
575 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
576 // admin routes
577 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
578 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
579
580 // are there any routes that we should be allowing without auth? i dont think so but idk
581 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
582 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
583}
584
585func (s *Server) Serve(ctx context.Context) error {
586 logger := s.logger.With("name", "Serve")
587
588 s.addRoutes()
589
590 logger.Info("migrating...")
591
592 s.db.AutoMigrate(
593 &models.Actor{},
594 &models.Repo{},
595 &models.InviteCode{},
596 &models.Token{},
597 &models.RefreshToken{},
598 &models.Block{},
599 &models.Record{},
600 &models.Blob{},
601 &models.BlobPart{},
602 &models.ReservedKey{},
603 &provider.OauthToken{},
604 &provider.OauthAuthorizationRequest{},
605 )
606
607 logger.Info("starting cocoon")
608
609 go func() {
610 if err := s.httpd.ListenAndServe(); err != nil {
611 panic(err)
612 }
613 }()
614
615 go s.backupRoutine()
616
617 go func() {
618 if err := s.requestCrawl(ctx); err != nil {
619 logger.Error("error requesting crawls", "err", err)
620 }
621 }()
622
623 <-ctx.Done()
624
625 fmt.Println("shut down")
626
627 return nil
628}
629
630func (s *Server) requestCrawl(ctx context.Context) error {
631 logger := s.logger.With("component", "request-crawl")
632 s.requestCrawlMu.Lock()
633 defer s.requestCrawlMu.Unlock()
634
635 logger.Info("requesting crawl with configured relays")
636
637 if time.Since(s.lastRequestCrawl) <= 1*time.Minute {
638 return fmt.Errorf("a crawl request has already been made within the last minute")
639 }
640
641 for _, relay := range s.config.Relays {
642 logger := logger.With("relay", relay)
643 logger.Info("requesting crawl from relay")
644 cli := xrpc.Client{Host: relay}
645 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
646 Hostname: s.config.Hostname,
647 }); err != nil {
648 logger.Error("error requesting crawl", "err", err)
649 } else {
650 logger.Info("crawl requested successfully")
651 }
652 }
653
654 s.lastRequestCrawl = time.Now()
655
656 return nil
657}
658
659func (s *Server) doBackup() {
660 logger := s.logger.With("name", "doBackup")
661
662 if s.dbType == "postgres" {
663 logger.Info("skipping S3 backup - PostgreSQL backups should be handled externally (pg_dump, managed database backups, etc.)")
664 return
665 }
666
667 start := time.Now()
668
669 logger.Info("beginning backup to s3...")
670
671 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano))
672 defer os.Remove(tmpFile)
673
674 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil {
675 logger.Error("error creating tmp backup file", "err", err)
676 return
677 }
678
679 backupData, err := os.ReadFile(tmpFile)
680 if err != nil {
681 logger.Error("error reading tmp backup file", "err", err)
682 return
683 }
684
685 logger.Info("sending to s3...")
686
687 currTime := time.Now().Format("2006-01-02_15-04-05")
688 key := "cocoon-backup-" + currTime + ".db"
689
690 config := &aws.Config{
691 Region: aws.String(s.s3Config.Region),
692 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
693 }
694
695 if s.s3Config.Endpoint != "" {
696 config.Endpoint = aws.String(s.s3Config.Endpoint)
697 config.S3ForcePathStyle = aws.Bool(true)
698 }
699
700 sess, err := session.NewSession(config)
701 if err != nil {
702 logger.Error("error creating s3 session", "err", err)
703 return
704 }
705
706 svc := s3.New(sess)
707
708 if _, err := svc.PutObject(&s3.PutObjectInput{
709 Bucket: aws.String(s.s3Config.Bucket),
710 Key: aws.String(key),
711 Body: bytes.NewReader(backupData),
712 }); err != nil {
713 logger.Error("error uploading file to s3", "err", err)
714 return
715 }
716
717 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds())
718
719 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644)
720}
721
722func (s *Server) backupRoutine() {
723 logger := s.logger.With("name", "backupRoutine")
724
725 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
726 return
727 }
728
729 if s.s3Config.Region == "" {
730 logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
731 return
732 }
733
734 if s.s3Config.Bucket == "" {
735 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
736 return
737 }
738
739 if s.s3Config.AccessKey == "" {
740 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
741 return
742 }
743
744 if s.s3Config.SecretKey == "" {
745 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
746 return
747 }
748
749 shouldBackupNow := false
750 lastBackupStr, err := os.ReadFile("last-backup.txt")
751 if err != nil {
752 shouldBackupNow = true
753 } else {
754 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr))
755 if err != nil {
756 shouldBackupNow = true
757 } else if time.Since(lastBackup).Seconds() > 3600 {
758 shouldBackupNow = true
759 }
760 }
761
762 if shouldBackupNow {
763 go s.doBackup()
764 }
765
766 ticker := time.NewTicker(time.Hour)
767 for range ticker.C {
768 go s.doBackup()
769 }
770}
771
772func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
773 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
774 return err
775 }
776
777 return nil
778}