forked from
willdot.net/cocoon
A fork of the Cocoon PDS that is being made more distributed.
1package server
2
import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"database/sql"
	"embed"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/smtp"
	"net/url"
	"os"
	"path/filepath"
	"sync"
	"text/template"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/bluesky-social/indigo/events"
	"github.com/bluesky-social/indigo/util"
	"github.com/bluesky-social/indigo/xrpc"
	"github.com/domodwyer/mailyak/v3"
	"github.com/go-playground/validator"
	"github.com/gorilla/sessions"
	"github.com/haileyok/cocoon/identity"
	"github.com/haileyok/cocoon/internal/db"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/haileyok/cocoon/oauth/client"
	"github.com/haileyok/cocoon/oauth/constants"
	"github.com/haileyok/cocoon/oauth/dpop"
	"github.com/haileyok/cocoon/oauth/provider"
	"github.com/haileyok/cocoon/plc"
	"github.com/ipfs/go-cid"
	"github.com/labstack/echo-contrib/echoprometheus"
	echo_session "github.com/labstack/echo-contrib/session"
	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	slogecho "github.com/samber/slog-echo"
	_ "github.com/tursodatabase/libsql-client-go/libsql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)
53
const (
	// AccountSessionMaxAge is the maximum lifetime of an account session:
	// 30 days.
	AccountSessionMaxAge = 30 * 24 * time.Hour
)
57
// S3Config configures the S3-compatible object store used by the server.
type S3Config struct {
	// BackupsEnabled turns on the periodic (hourly) SQLite backup upload
	// performed by backupRoutine/doBackup.
	BackupsEnabled bool
	// BlobstoreEnabled is not read in this file; presumably it switches blob
	// storage to S3 — confirm at its call sites.
	BlobstoreEnabled bool
	// Endpoint, when non-empty, overrides the AWS endpoint (for S3-compatible
	// providers); path-style addressing is forced in that case.
	Endpoint string
	// Region, Bucket, AccessKey and SecretKey must all be non-empty for
	// backups to run.
	Region    string
	Bucket    string
	AccessKey string
	SecretKey string
	// CDNUrl is not read in this file; presumably a public base URL for
	// served blobs — confirm at its call sites.
	CDNUrl string
}
68
// Server is the PDS instance: it owns the HTTP server and router, the
// database handle, key material, and the clients/managers request handlers
// use. Construct it with New and run it with Serve.
type Server struct {
	// http is the shared outbound HTTP client (util.RobustHTTPClient in New).
	http *http.Client
	// httpd is the listening HTTP server; its handler is echo.
	httpd *http.Server
	// mail is the SMTP client; nil unless all six SMTP args were given to New.
	mail *mailyak.MailYak
	// mailLk is allocated together with mail (nil otherwise); presumably it
	// serializes use of mail — confirm at call sites.
	mailLk *sync.Mutex
	// echo is the router all routes are registered on (see addRoutes).
	echo *echo.Echo
	// db wraps the gorm connection opened in New.
	db *db.DB
	// plcClient talks to https://plc.directory using the rotation key.
	plcClient *plc.Client
	logger    *slog.Logger
	config    *config
	// privateKey is the ECDSA key parsed from the JWK at Args.JwkPath.
	privateKey    *ecdsa.PrivateKey
	repoman       *RepoMan
	oauthProvider *provider.Provider
	// evtman is backed by a database persister created with a 72h duration
	// argument (presumably event retention — confirm in NewDbPersister).
	evtman   *events.EventManager
	passport *identity.Passport
	// fallbackProxy is not assigned by New; the configured value is stored
	// in config.FallbackProxy.
	fallbackProxy string

	// lastRequestCrawl records when requestCrawl last ran a round of relay
	// requests; it is read and written only while holding requestCrawlMu.
	lastRequestCrawl time.Time
	requestCrawlMu   sync.Mutex

	// dbType is Args.DbType, or "sqlite" when that was empty; doBackup skips
	// "postgres" and "turso".
	dbType   string
	s3Config *S3Config
}
92
// Args carries everything New needs to construct a Server.
type Args struct {
	// Logger is the base logger; slog.Default() is used when nil.
	Logger *slog.Logger

	// LogLevel, when non-zero, wraps Logger so records below this level are
	// dropped. slog.LevelInfo is the zero value, so it counts as "not set".
	LogLevel slog.Level
	// Addr is the HTTP listen address (required).
	Addr string
	// DbName is the SQLite database file path used by the SQLite backend.
	DbName string
	// DbType selects the backend: "postgres", "turso", or SQLite (empty or
	// any other value).
	DbType string
	// DatabaseURL is the connection URL for the "postgres" and "turso" backends.
	DatabaseURL string
	// TursoToken is passed as the authToken query parameter to the libsql driver.
	TursoToken string
	// Version is the server version; "dev" loads templates and static files
	// from disk instead of the embedded copies.
	Version string
	// Did is the PDS service DID (required; must parse as a DID).
	Did string
	// Hostname is this PDS's public hostname (required).
	Hostname string
	// RotationKeyPath is the file whose bytes are passed to the PLC client as
	// its rotation key.
	RotationKeyPath string
	// JwkPath is the file holding the server's ECDSA private key as a JWK.
	JwkPath string
	// ContactEmail is the operator contact address (required).
	ContactEmail string
	// Relays are passed as xrpc.Client.Host to requestCrawl on startup.
	Relays []string
	// AdminPassword is required; presumably checked by handleAdminMiddleware.
	AdminPassword string
	// RequireInvite presumably gates account creation on an invite code —
	// consumed outside this file.
	RequireInvite bool

	// SMTP settings. Mail is only configured when all six are non-empty.
	SmtpUser  string
	SmtpPass  string
	SmtpHost  string
	SmtpPort  string
	SmtpEmail string
	SmtpName  string

	// S3Config configures S3 backups (and, presumably, the S3 blobstore).
	S3Config *S3Config

	// SessionSecret is the key for the cookie session store (required).
	SessionSecret string
	// SessionCookieKey is copied into config; consumed outside this file.
	SessionCookieKey string

	// BlockstoreVariant and FallbackProxy are copied into config.
	BlockstoreVariant BlockstoreVariant
	FallbackProxy     string

	// PushBasedEvents makes Serve start the event-emitting goroutine.
	PushBasedEvents          bool
	SubscribeReposServiceURL string
	// NonceSecret seeds the DPoP nonce manager; when empty the contents of
	// ./nonce.secret are used if that file exists.
	NonceSecret string
}
131
// config is the subset of Args that the Server keeps after construction.
type config struct {
	LogLevel slog.Level
	// Version "dev" switches templates and static files to on-disk copies.
	Version      string
	Did          string
	Hostname     string
	ContactEmail string
	// EnforcePeering is always set to false by New.
	EnforcePeering bool
	// Relays are asked to crawl this PDS by requestCrawl.
	Relays        []string
	AdminPassword string
	RequireInvite bool
	// SmtpEmail and SmtpName are the From address and display name of
	// outgoing mail.
	SmtpEmail         string
	SmtpName          string
	SessionCookieKey  string
	BlockstoreVariant BlockstoreVariant
	FallbackProxy     string

	// PushBasedEvents makes Serve start the event-emitting goroutine.
	PushBasedEvents          bool
	SubscribeReposServiceURL string
}
151
// CustomValidator adapts a go-playground validator for use as echo's
// request validator (New installs it as e.Validator).
type CustomValidator struct {
	validator *validator.Validate
}
155
// ValidationError is returned by CustomValidator.Validate when struct
// validation fails. It embeds the full validator error and exposes the field
// name and failing tag of the first violation.
type ValidationError struct {
	error
	Field string
	Tag   string
}
161
162func (cv *CustomValidator) Validate(i any) error {
163 if err := cv.validator.Struct(i); err != nil {
164 var validateErrors validator.ValidationErrors
165 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
166 first := validateErrors[0]
167 return ValidationError{
168 error: err,
169 Field: first.Field(),
170 Tag: first.Tag(),
171 }
172 }
173
174 return err
175 }
176
177 return nil
178}
179
// templateFS holds the HTML templates compiled into the binary; loadTemplates
// uses it for every Version other than "dev".
//
//go:embed templates/*
var templateFS embed.FS

// staticFS holds the static assets compiled into the binary; addRoutes serves
// it under /static for every Version other than "dev".
//
//go:embed static/*
var staticFS embed.FS
185
// TemplateRenderer is the echo renderer installed by loadTemplates.
//
// NOTE(review): it is built on text/template, which performs no HTML
// escaping; confirm no untrusted data reaches these templates.
type TemplateRenderer struct {
	// templates is the parsed template set.
	templates *template.Template
	// isDev makes Render re-parse templatePath on every call.
	isDev bool
	// templatePath is the glob re-parsed in dev mode (empty otherwise).
	templatePath string
}
191
192func (s *Server) loadTemplates() {
193 absPath, _ := filepath.Abs("server/templates/*.html")
194 if s.config.Version == "dev" {
195 tmpl := template.Must(template.ParseGlob(absPath))
196 s.echo.Renderer = &TemplateRenderer{
197 templates: tmpl,
198 isDev: true,
199 templatePath: absPath,
200 }
201 } else {
202 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
203 s.echo.Renderer = &TemplateRenderer{
204 templates: tmpl,
205 isDev: false,
206 }
207 }
208}
209
210func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
211 if t.isDev {
212 tmpl, err := template.ParseGlob(t.templatePath)
213 if err != nil {
214 return err
215 }
216 t.templates = tmpl
217 }
218
219 if viewContext, isMap := data.(map[string]any); isMap {
220 viewContext["reverse"] = c.Echo().Reverse
221 }
222
223 return t.templates.ExecuteTemplate(w, name, data)
224}
225
// filteredHandler decorates another slog.Handler with a minimum level: a
// record is handled only if its level is at least level AND the wrapped
// handler also reports itself enabled. Derived handlers (WithAttrs /
// WithGroup) keep the same minimum.
type filteredHandler struct {
	level   slog.Level
	handler slog.Handler
}

// Enabled reports whether a record at level should be handled.
func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool {
	if level < h.level {
		return false
	}
	return h.handler.Enabled(ctx, level)
}

// Handle forwards the record unchanged; level filtering happens in Enabled,
// which slog consults before calling Handle.
func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error {
	inner := h.handler
	return inner.Handle(ctx, r)
}

// WithAttrs returns a filtered handler around the inner handler's WithAttrs.
func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	return h.rewrap(h.handler.WithAttrs(attrs))
}

// WithGroup returns a filtered handler around the inner handler's WithGroup.
func (h *filteredHandler) WithGroup(name string) slog.Handler {
	return h.rewrap(h.handler.WithGroup(name))
}

// rewrap applies h's minimum level to a derived inner handler.
func (h *filteredHandler) rewrap(inner slog.Handler) slog.Handler {
	return &filteredHandler{level: h.level, handler: inner}
}
246
247func New(args *Args) (*Server, error) {
248 if args.Logger == nil {
249 args.Logger = slog.Default()
250 }
251
252 if args.LogLevel != 0 {
253 args.Logger = slog.New(&filteredHandler{
254 level: args.LogLevel,
255 handler: args.Logger.Handler(),
256 })
257 }
258
259 logger := args.Logger.With("name", "New")
260
261 if args.Addr == "" {
262 return nil, fmt.Errorf("addr must be set")
263 }
264
265 if args.DbType == "sqlite" && args.DbName == "" {
266 return nil, fmt.Errorf("db name must be set for a sqlite app")
267 }
268
269 if args.Did == "" {
270 return nil, fmt.Errorf("cocoon did must be set")
271 }
272
273 if args.ContactEmail == "" {
274 return nil, fmt.Errorf("cocoon contact email is required")
275 }
276
277 if _, err := syntax.ParseDID(args.Did); err != nil {
278 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
279 }
280
281 if args.Hostname == "" {
282 return nil, fmt.Errorf("cocoon hostname must be set")
283 }
284
285 if args.AdminPassword == "" {
286 return nil, fmt.Errorf("admin password must be set")
287 }
288
289 if args.SessionSecret == "" {
290 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
291 }
292
293 e := echo.New()
294
295 e.Pre(middleware.RemoveTrailingSlash())
296 e.Pre(slogecho.New(args.Logger.With("component", "slogecho")))
297 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
298 e.Use(echoprometheus.NewMiddleware("cocoon"))
299 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
300 AllowOrigins: []string{"*"},
301 AllowHeaders: []string{"*"},
302 AllowMethods: []string{"*"},
303 AllowCredentials: true,
304 MaxAge: 100_000_000,
305 }))
306
307 vdtor := validator.New()
308 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
309 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
310 return false
311 }
312 return true
313 })
314 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
315 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
316 return false
317 }
318 return true
319 })
320 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
321 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
322 return false
323 }
324 return true
325 })
326 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
327 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
328 return false
329 }
330 return true
331 })
332
333 e.Validator = &CustomValidator{validator: vdtor}
334
335 httpd := &http.Server{
336 Addr: args.Addr,
337 Handler: e,
338 // shitty defaults but okay for now, needed for import repo
339 ReadTimeout: 5 * time.Minute,
340 WriteTimeout: 5 * time.Minute,
341 IdleTimeout: 5 * time.Minute,
342 }
343
344 dbType := args.DbType
345 if dbType == "" {
346 dbType = "sqlite"
347 }
348
349 var gdb *gorm.DB
350 gormCfg := gorm.Config{
351 TranslateError: true,
352 }
353 var err error
354 switch dbType {
355 case "postgres":
356 if args.DatabaseURL == "" {
357 return nil, fmt.Errorf("database-url must be set when using postgres")
358 }
359 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gormCfg)
360 if err != nil {
361 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
362 }
363 logger.Info("connected to PostgreSQL database")
364 case "turso":
365 primaryUrl := args.DatabaseURL
366 authToken := args.TursoToken
367
368 db, err := sql.Open("libsql", fmt.Sprintf("%s?authToken=%s", primaryUrl, authToken))
369 gdb, err = gorm.Open(sqlite.New(sqlite.Config{
370 Conn: db,
371 }), &gormCfg)
372 if err != nil {
373 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
374 }
375 logger.Info("connected to PostgreSQL database")
376
377 default:
378 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gormCfg)
379 if err != nil {
380 return nil, fmt.Errorf("failed to open sqlite database: %w", err)
381 }
382 gdb.Exec("PRAGMA journal_mode=WAL")
383 gdb.Exec("PRAGMA synchronous=NORMAL")
384
385 logger.Info("connected to SQLite database", "path", args.DbName)
386 }
387 dbw := db.NewDB(gdb)
388
389 rkbytes, err := os.ReadFile(args.RotationKeyPath)
390 if err != nil {
391 return nil, err
392 }
393
394 h := util.RobustHTTPClient()
395
396 plcClient, err := plc.NewClient(&plc.ClientArgs{
397 H: h,
398 Service: "https://plc.directory",
399 PdsHostname: args.Hostname,
400 RotationKey: rkbytes,
401 })
402 if err != nil {
403 return nil, err
404 }
405
406 jwkbytes, err := os.ReadFile(args.JwkPath)
407 if err != nil {
408 return nil, err
409 }
410
411 key, err := helpers.ParseJWKFromBytes(jwkbytes)
412 if err != nil {
413 return nil, err
414 }
415
416 var pkey ecdsa.PrivateKey
417 if err := key.Raw(&pkey); err != nil {
418 return nil, err
419 }
420
421 oauthCli := &http.Client{
422 Timeout: 10 * time.Second,
423 }
424
425 var nonceSecret []byte
426 if args.NonceSecret != "" {
427 nonceSecret = []byte(args.NonceSecret)
428 } else {
429 maybeSecret, err := os.ReadFile("nonce.secret")
430 if err != nil && !os.IsNotExist(err) {
431 logger.Error("error attempting to read nonce secret", "error", err)
432 } else {
433 nonceSecret = maybeSecret
434 }
435 }
436
437 evtPersister, err := NewDbPersister(gdb, 72*time.Hour)
438 if err != nil {
439 return nil, fmt.Errorf("failed to create event persister: %w", err)
440 }
441
442 s := &Server{
443 http: h,
444 httpd: httpd,
445 echo: e,
446 logger: args.Logger,
447 db: dbw,
448 plcClient: plcClient,
449 privateKey: &pkey,
450 config: &config{
451 LogLevel: args.LogLevel,
452 Version: args.Version,
453 Did: args.Did,
454 Hostname: args.Hostname,
455 ContactEmail: args.ContactEmail,
456 EnforcePeering: false,
457 Relays: args.Relays,
458 AdminPassword: args.AdminPassword,
459 RequireInvite: args.RequireInvite,
460 SmtpName: args.SmtpName,
461 SmtpEmail: args.SmtpEmail,
462 SessionCookieKey: args.SessionCookieKey,
463 BlockstoreVariant: args.BlockstoreVariant,
464 FallbackProxy: args.FallbackProxy,
465 PushBasedEvents: args.PushBasedEvents,
466 SubscribeReposServiceURL: args.SubscribeReposServiceURL,
467 },
468 evtman: events.NewEventManager(evtPersister),
469 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
470
471 dbType: dbType,
472 s3Config: args.S3Config,
473
474 oauthProvider: provider.NewProvider(provider.Args{
475 Hostname: args.Hostname,
476 ClientManagerArgs: client.ManagerArgs{
477 Cli: oauthCli,
478 Logger: args.Logger.With("component", "oauth-client-manager"),
479 },
480 DpopManagerArgs: dpop.ManagerArgs{
481 NonceSecret: nonceSecret,
482 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
483 OnNonceSecretCreated: func(newNonce []byte) {
484 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
485 logger.Error("error writing new nonce secret", "error", err)
486 }
487 },
488 Logger: args.Logger.With("component", "dpop-manager"),
489 Hostname: args.Hostname,
490 },
491 }),
492 }
493
494 s.loadTemplates()
495
496 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
497
498 // TODO: should validate these args
499 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
500 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
501 } else {
502 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
503 mail.From(s.config.SmtpEmail)
504 mail.FromName(s.config.SmtpName)
505
506 s.mail = mail
507 s.mailLk = &sync.Mutex{}
508 }
509
510 return s, nil
511}
512
513func (s *Server) addRoutes() {
514 // static
515 if s.config.Version == "dev" {
516 s.echo.Static("/static", "server/static")
517 } else {
518 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
519 }
520
521 // random stuff
522 s.echo.GET("/", s.handleRoot)
523 s.echo.GET("/xrpc/_health", s.handleHealth)
524 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
525 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid)
526 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
527 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
528 s.echo.GET("/robots.txt", s.handleRobots)
529
530 // public
531 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
532 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
533 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
534 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
535 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)
536
537 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
538 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
539 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
540 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
541 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
542 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
543 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
544 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
545 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
546 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
547 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
548 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
549
550 // labels
551 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)
552
553 // account
554 s.echo.GET("/account", s.handleAccount)
555 s.echo.POST("/account/revoke", s.handleAccountRevoke)
556 s.echo.POST("/account/switch", s.handleAccountSwitchPost)
557 s.echo.GET("/account/signin", s.handleAccountSigninGet)
558 s.echo.POST("/account/signin", s.handleAccountSigninPost)
559 s.echo.GET("/account/signout", s.handleAccountSignout)
560
561 // oauth account
562 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
563 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
564 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
565
566 // oauth authorization
567 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
568 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
569
570 // authed
571 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
572 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
573 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
574 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
575 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
576 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
577 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
578 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
579 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
580 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
581 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
582 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
583 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
584 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
585 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
586 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
587 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
588 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
589 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
590 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)
591
592 // repo
593 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
594 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
595 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
596 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
597 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
598 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
599 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
600
601 // stupid silly endpoints
602 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
603 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
604 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
605 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
606 // admin routes
607 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
608 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
609
610 // are there any routes that we should be allowing without auth? i dont think so but idk
611 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
612 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
613}
614
615func (s *Server) Serve(ctx context.Context) error {
616 logger := s.logger.With("name", "Serve")
617
618 s.addRoutes()
619
620 logger.Info("migrating...")
621
622 s.db.AutoMigrate(
623 &models.Actor{},
624 &models.Repo{},
625 &models.InviteCode{},
626 &models.Token{},
627 &models.RefreshToken{},
628 &models.Block{},
629 &models.Record{},
630 &models.Blob{},
631 &models.BlobPart{},
632 &models.ReservedKey{},
633 &provider.OauthToken{},
634 &provider.OauthAuthorizationRequest{},
635 )
636
637 logger.Info("starting cocoon")
638
639 go func() {
640 if err := s.httpd.ListenAndServe(); err != nil {
641 panic(err)
642 }
643 }()
644
645 go s.backupRoutine()
646
647 go func() {
648 if err := s.requestCrawl(ctx); err != nil {
649 logger.Error("error requesting crawls", "err", err)
650 }
651 }()
652
653 if s.config.PushBasedEvents {
654 slog.Info("pushed based events enabled")
655 go func() {
656 if err := s.emmitEvents(ctx); err != nil {
657 logger.Error("error emitting events", "err", err)
658 }
659 }()
660 }
661
662 <-ctx.Done()
663
664 fmt.Println("shut down")
665
666 return nil
667}
668
669func (s *Server) requestCrawl(ctx context.Context) error {
670 logger := s.logger.With("component", "request-crawl")
671 s.requestCrawlMu.Lock()
672 defer s.requestCrawlMu.Unlock()
673
674 logger.Info("requesting crawl with configured relays")
675
676 if time.Since(s.lastRequestCrawl) <= 1*time.Minute {
677 return fmt.Errorf("a crawl request has already been made within the last minute")
678 }
679
680 for _, relay := range s.config.Relays {
681 logger := logger.With("relay", relay)
682 logger.Info("requesting crawl from relay")
683 cli := xrpc.Client{Host: relay}
684 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
685 Hostname: s.config.Hostname,
686 }); err != nil {
687 logger.Error("error requesting crawl", "err", err)
688 } else {
689 logger.Info("crawl requested successfully")
690 }
691 }
692
693 s.lastRequestCrawl = time.Now()
694
695 return nil
696}
697
698func (s *Server) doBackup() {
699 logger := s.logger.With("name", "doBackup")
700
701 if s.dbType == "postgres" || s.dbType == "turso" {
702 logger.Info("skipping S3 backup - PostgreSQL or Turso backups should be handled externally (pg_dump, managed database backups, etc.)")
703 return
704 }
705
706 start := time.Now()
707
708 logger.Info("beginning backup to s3...")
709
710 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano))
711 defer os.Remove(tmpFile)
712
713 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil {
714 logger.Error("error creating tmp backup file", "err", err)
715 return
716 }
717
718 backupData, err := os.ReadFile(tmpFile)
719 if err != nil {
720 logger.Error("error reading tmp backup file", "err", err)
721 return
722 }
723
724 logger.Info("sending to s3...")
725
726 currTime := time.Now().Format("2006-01-02_15-04-05")
727 key := "cocoon-backup-" + currTime + ".db"
728
729 config := &aws.Config{
730 Region: aws.String(s.s3Config.Region),
731 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
732 }
733
734 if s.s3Config.Endpoint != "" {
735 config.Endpoint = aws.String(s.s3Config.Endpoint)
736 config.S3ForcePathStyle = aws.Bool(true)
737 }
738
739 sess, err := session.NewSession(config)
740 if err != nil {
741 logger.Error("error creating s3 session", "err", err)
742 return
743 }
744
745 svc := s3.New(sess)
746
747 if _, err := svc.PutObject(&s3.PutObjectInput{
748 Bucket: aws.String(s.s3Config.Bucket),
749 Key: aws.String(key),
750 Body: bytes.NewReader(backupData),
751 }); err != nil {
752 logger.Error("error uploading file to s3", "err", err)
753 return
754 }
755
756 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds())
757
758 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644)
759}
760
761func (s *Server) backupRoutine() {
762 logger := s.logger.With("name", "backupRoutine")
763
764 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
765 return
766 }
767
768 if s.s3Config.Region == "" {
769 logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
770 return
771 }
772
773 if s.s3Config.Bucket == "" {
774 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
775 return
776 }
777
778 if s.s3Config.AccessKey == "" {
779 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
780 return
781 }
782
783 if s.s3Config.SecretKey == "" {
784 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
785 return
786 }
787
788 shouldBackupNow := false
789 lastBackupStr, err := os.ReadFile("last-backup.txt")
790 if err != nil {
791 shouldBackupNow = true
792 } else {
793 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr))
794 if err != nil {
795 shouldBackupNow = true
796 } else if time.Since(lastBackup).Seconds() > 3600 {
797 shouldBackupNow = true
798 }
799 }
800
801 if shouldBackupNow {
802 go s.doBackup()
803 }
804
805 ticker := time.NewTicker(time.Hour)
806 for range ticker.C {
807 go s.doBackup()
808 }
809}
810
811func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
812 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
813 return err
814 }
815
816 return nil
817}