Forked from willdot.net/cocoon.
A fork of the Cocoon PDS that is being made more distributed.
1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "database/sql"
8 "embed"
9 "errors"
10 "fmt"
11 "io"
12 "log/slog"
13 "net/http"
14 "net/smtp"
15 "os"
16 "path/filepath"
17 "sync"
18 "text/template"
19 "time"
20
21 "github.com/aws/aws-sdk-go/aws"
22 "github.com/aws/aws-sdk-go/aws/credentials"
23 "github.com/aws/aws-sdk-go/aws/session"
24 "github.com/aws/aws-sdk-go/service/s3"
25 "github.com/bluesky-social/indigo/api/atproto"
26 "github.com/bluesky-social/indigo/atproto/syntax"
27 "github.com/bluesky-social/indigo/events"
28 "github.com/bluesky-social/indigo/util"
29 "github.com/bluesky-social/indigo/xrpc"
30 "github.com/domodwyer/mailyak/v3"
31 "github.com/go-playground/validator"
32 "github.com/gorilla/sessions"
33 "github.com/haileyok/cocoon/identity"
34 "github.com/haileyok/cocoon/internal/db"
35 "github.com/haileyok/cocoon/internal/helpers"
36 "github.com/haileyok/cocoon/models"
37 "github.com/haileyok/cocoon/oauth/client"
38 "github.com/haileyok/cocoon/oauth/constants"
39 "github.com/haileyok/cocoon/oauth/dpop"
40 "github.com/haileyok/cocoon/oauth/provider"
41 "github.com/haileyok/cocoon/plc"
42 "github.com/ipfs/go-cid"
43 "github.com/labstack/echo-contrib/echoprometheus"
44 echo_session "github.com/labstack/echo-contrib/session"
45 "github.com/labstack/echo/v4"
46 "github.com/labstack/echo/v4/middleware"
47 slogecho "github.com/samber/slog-echo"
48 _ "github.com/tursodatabase/libsql-client-go/libsql"
49 "gorm.io/driver/postgres"
50 "gorm.io/driver/sqlite"
51 "gorm.io/gorm"
52)
53
const (
	// AccountSessionMaxAge is the maximum age of an account session: 30 days.
	// (An earlier comment claimed "one week", which did not match the value.)
	AccountSessionMaxAge = 30 * 24 * time.Hour
)
57
// S3Config configures the S3-compatible object storage settings used by the
// server (database backups in this file).
type S3Config struct {
	// BackupsEnabled turns on the hourly database backup upload run by
	// backupRoutine.
	BackupsEnabled bool
	// BlobstoreEnabled is not read in this file; presumably it selects S3 as
	// the blob store — NOTE(review): confirm against the blob handlers.
	BlobstoreEnabled bool
	// Endpoint, when non-empty, overrides the S3 endpoint; doBackup then also
	// forces path-style addressing.
	Endpoint string
	// Region is required for backups to run.
	Region string
	// Bucket receives the backup objects; required for backups to run.
	Bucket string
	// AccessKey and SecretKey are used as static credentials; both are
	// required for backups to run.
	AccessKey string
	SecretKey string
	// CDNUrl is not read in this file — NOTE(review): presumably a public base
	// URL for blobs; confirm.
	CDNUrl string
}
68
// Server is the cocoon PDS. It owns the HTTP listener, the echo router, the
// database handle, signing keys, the OAuth provider, the event manager and the
// identity resolver. Build one with New and run it with Serve.
type Server struct {
	// http is the outbound client (util.RobustHTTPClient in New); it is also
	// handed to the PLC client and the identity resolver.
	http *http.Client
	// httpd listens on Args.Addr with echo as its Handler.
	httpd *http.Server
	// mail is nil unless every SMTP argument was provided to New.
	mail *mailyak.MailYak
	// mailLk is assigned together with mail, so it is nil whenever mail is nil.
	mailLk        *sync.Mutex
	echo          *echo.Echo
	db            *db.DB
	plcClient     *plc.Client
	logger        *slog.Logger
	config        *config
	// privateKey is the ECDSA key parsed from the JWK at Args.JwkPath.
	privateKey    *ecdsa.PrivateKey
	repoman       *RepoMan
	oauthProvider *provider.Provider
	// evtman is backed by a database persister (New passes 72h to
	// NewDbPersister — presumably a retention window; confirm).
	evtman *events.EventManager
	// passport resolves identities through an in-memory cache.
	passport *identity.Passport
	// NOTE(review): New never sets fallbackProxy; the configured value is
	// stored in config.FallbackProxy.
	fallbackProxy string

	// lastRequestCrawl rate-limits requestCrawl; guarded by requestCrawlMu.
	lastRequestCrawl time.Time
	requestCrawlMu   sync.Mutex

	// dbName is Args.DbName (the SQLite file path for the sqlite backend).
	dbName string
	// dbType is "sqlite" (default), "postgres" or "turso".
	dbType string
	// s3Config may be nil, in which case backups do not run.
	s3Config *S3Config
}
93
// Args carries everything New needs to build a Server. Addr, DbName, Did,
// ContactEmail, Hostname, AdminPassword and SessionSecret are required.
type Args struct {
	// Logger is the base logger; New uses slog.Default() when it is nil.
	Logger *slog.Logger

	// LogLevel, when non-zero, wraps Logger in a filter that drops records
	// below this level.
	LogLevel slog.Level
	// Addr is the HTTP listen address.
	Addr string
	// DbName is required; for the sqlite backend it is the database file path.
	DbName string
	// DbType selects the backend: "sqlite" (used when empty), "postgres" or "turso".
	DbType string
	// DatabaseURL is the connection URL for the postgres and turso backends.
	DatabaseURL string
	// TursoToken is appended to DatabaseURL as the authToken query parameter.
	TursoToken string
	// Version is kept in config; the value "dev" makes the server load
	// templates and static assets from disk instead of the embedded copies.
	Version string
	// Did is the server's own DID and must parse as one.
	Did string
	// Hostname is this PDS's hostname (used for PLC, OAuth and crawl requests).
	Hostname string
	// RotationKeyPath is the file whose bytes are given to the PLC client as
	// its rotation key.
	RotationKeyPath string
	// JwkPath is the file holding the JWK for the server's ECDSA private key.
	JwkPath       string
	ContactEmail  string
	// Relays are the hosts asked to crawl this server when Serve starts.
	Relays        []string
	AdminPassword string
	RequireInvite bool

	// SMTP settings; mail is configured only when all six are non-empty.
	SmtpUser  string
	SmtpPass  string
	SmtpHost  string
	SmtpPort  string
	SmtpEmail string
	SmtpName  string

	// S3Config configures S3 backups; nil disables them.
	S3Config *S3Config

	// SessionSecret keys the cookie session store and is required.
	SessionSecret    string
	SessionCookieKey string

	BlockstoreVariant BlockstoreVariant
	FallbackProxy     string

	PushBasedEvents          bool
	SubscribeReposServiceURL string
	// NonceSecret seeds the DPoP nonce secret; when empty, New reads the
	// nonce.secret file instead.
	NonceSecret string
}
132
// config is the subset of Args that the Server keeps after New returns.
type config struct {
	LogLevel     slog.Level
	Version      string
	Did          string
	Hostname     string
	ContactEmail string
	// EnforcePeering is always set to false by New.
	EnforcePeering    bool
	Relays            []string
	AdminPassword     string
	RequireInvite     bool
	SmtpEmail         string
	SmtpName          string
	SessionCookieKey  string
	BlockstoreVariant BlockstoreVariant
	FallbackProxy     string

	// PushBasedEvents makes Serve start the event-emitting goroutine.
	PushBasedEvents          bool
	SubscribeReposServiceURL string
}
152
// CustomValidator adapts a go-playground validator to echo's Validator hook;
// New installs it as e.Validator with the atproto-* tags registered.
type CustomValidator struct {
	validator *validator.Validate
}

// ValidationError is what CustomValidator.Validate returns when struct
// validation produces field errors. It wraps the full validator error and
// exposes the field name and failing tag of the first field error.
type ValidationError struct {
	error
	Field string
	Tag   string
}
162
163func (cv *CustomValidator) Validate(i any) error {
164 if err := cv.validator.Struct(i); err != nil {
165 var validateErrors validator.ValidationErrors
166 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
167 first := validateErrors[0]
168 return ValidationError{
169 error: err,
170 Field: first.Field(),
171 Tag: first.Tag(),
172 }
173 }
174
175 return err
176 }
177
178 return nil
179}
180
// templateFS holds the templates/ directory compiled into the binary;
// loadTemplates parses templates/*.html from it for non-"dev" versions.
//
//go:embed templates/*
var templateFS embed.FS

// staticFS holds the static/ directory compiled into the binary; addRoutes
// serves it under /static/ for non-"dev" versions.
//
//go:embed static/*
var staticFS embed.FS
186
// TemplateRenderer is the echo Renderer installed by loadTemplates.
type TemplateRenderer struct {
	// templates is the template set parsed when the renderer was built.
	templates *template.Template
	// isDev makes Render re-parse the templates from templatePath on each call.
	isDev bool
	// templatePath is the glob pattern used in dev mode; empty otherwise.
	templatePath string
}
192
193func (s *Server) loadTemplates() {
194 absPath, _ := filepath.Abs("server/templates/*.html")
195 if s.config.Version == "dev" {
196 tmpl := template.Must(template.ParseGlob(absPath))
197 s.echo.Renderer = &TemplateRenderer{
198 templates: tmpl,
199 isDev: true,
200 templatePath: absPath,
201 }
202 } else {
203 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
204 s.echo.Renderer = &TemplateRenderer{
205 templates: tmpl,
206 isDev: false,
207 }
208 }
209}
210
211func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
212 if t.isDev {
213 tmpl, err := template.ParseGlob(t.templatePath)
214 if err != nil {
215 return err
216 }
217 t.templates = tmpl
218 }
219
220 if viewContext, isMap := data.(map[string]any); isMap {
221 viewContext["reverse"] = c.Echo().Reverse
222 }
223
224 return t.templates.ExecuteTemplate(w, name, data)
225}
226
// filteredHandler decorates another slog.Handler with a minimum level.
// Records below level are reported as disabled; everything else is delegated
// to the wrapped handler unchanged.
type filteredHandler struct {
	level   slog.Level
	handler slog.Handler
}

// Compile-time check that *filteredHandler satisfies slog.Handler.
var _ slog.Handler = (*filteredHandler)(nil)

// Enabled reports whether level passes both this filter and the wrapped handler.
func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool {
	if level < h.level {
		return false
	}
	return h.handler.Enabled(ctx, level)
}

// Handle forwards r to the wrapped handler; level filtering happens in Enabled.
func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error {
	return h.handler.Handle(ctx, r)
}

// WithAttrs returns a filter with the same level around the wrapped handler's
// WithAttrs result.
func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	return h.rewrap(h.handler.WithAttrs(attrs))
}

// WithGroup returns a filter with the same level around the wrapped handler's
// WithGroup result.
func (h *filteredHandler) WithGroup(name string) slog.Handler {
	return h.rewrap(h.handler.WithGroup(name))
}

// rewrap builds a new filteredHandler with h's level around inner.
func (h *filteredHandler) rewrap(inner slog.Handler) slog.Handler {
	return &filteredHandler{level: h.level, handler: inner}
}
247
248func New(args *Args) (*Server, error) {
249 if args.Logger == nil {
250 args.Logger = slog.Default()
251 }
252
253 if args.LogLevel != 0 {
254 args.Logger = slog.New(&filteredHandler{
255 level: args.LogLevel,
256 handler: args.Logger.Handler(),
257 })
258 }
259
260 logger := args.Logger.With("name", "New")
261
262 if args.Addr == "" {
263 return nil, fmt.Errorf("addr must be set")
264 }
265
266 if args.DbName == "" {
267 return nil, fmt.Errorf("db name must be set")
268 }
269
270 if args.Did == "" {
271 return nil, fmt.Errorf("cocoon did must be set")
272 }
273
274 if args.ContactEmail == "" {
275 return nil, fmt.Errorf("cocoon contact email is required")
276 }
277
278 if _, err := syntax.ParseDID(args.Did); err != nil {
279 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
280 }
281
282 if args.Hostname == "" {
283 return nil, fmt.Errorf("cocoon hostname must be set")
284 }
285
286 if args.AdminPassword == "" {
287 return nil, fmt.Errorf("admin password must be set")
288 }
289
290 if args.SessionSecret == "" {
291 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
292 }
293
294 e := echo.New()
295
296 e.Pre(middleware.RemoveTrailingSlash())
297 e.Pre(slogecho.New(args.Logger.With("component", "slogecho")))
298 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
299 e.Use(echoprometheus.NewMiddleware("cocoon"))
300 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
301 AllowOrigins: []string{"*"},
302 AllowHeaders: []string{"*"},
303 AllowMethods: []string{"*"},
304 AllowCredentials: true,
305 MaxAge: 100_000_000,
306 }))
307
308 vdtor := validator.New()
309 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
310 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
311 return false
312 }
313 return true
314 })
315 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
316 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
317 return false
318 }
319 return true
320 })
321 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
322 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
323 return false
324 }
325 return true
326 })
327 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
328 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
329 return false
330 }
331 return true
332 })
333
334 e.Validator = &CustomValidator{validator: vdtor}
335
336 httpd := &http.Server{
337 Addr: args.Addr,
338 Handler: e,
339 // shitty defaults but okay for now, needed for import repo
340 ReadTimeout: 5 * time.Minute,
341 WriteTimeout: 5 * time.Minute,
342 IdleTimeout: 5 * time.Minute,
343 }
344
345 dbType := args.DbType
346 if dbType == "" {
347 dbType = "sqlite"
348 }
349
350 var gdb *gorm.DB
351 var err error
352 switch dbType {
353 case "postgres":
354 if args.DatabaseURL == "" {
355 return nil, fmt.Errorf("database-url must be set when using postgres")
356 }
357 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{})
358 if err != nil {
359 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
360 }
361 logger.Info("connected to PostgreSQL database")
362 case "turso":
363 primaryUrl := args.DatabaseURL
364 authToken := args.TursoToken
365
366 db, err := sql.Open("libsql", fmt.Sprintf("%s?authToken=%s", primaryUrl, authToken))
367 gdb, err = gorm.Open(sqlite.New(sqlite.Config{
368 Conn: db,
369 }), &gorm.Config{})
370 if err != nil {
371 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
372 }
373 logger.Info("connected to PostgreSQL database")
374
375 default:
376 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
377 if err != nil {
378 return nil, fmt.Errorf("failed to open sqlite database: %w", err)
379 }
380 gdb.Exec("PRAGMA journal_mode=WAL")
381 gdb.Exec("PRAGMA synchronous=NORMAL")
382
383 logger.Info("connected to SQLite database", "path", args.DbName)
384 }
385 dbw := db.NewDB(gdb)
386
387 rkbytes, err := os.ReadFile(args.RotationKeyPath)
388 if err != nil {
389 return nil, err
390 }
391
392 h := util.RobustHTTPClient()
393
394 plcClient, err := plc.NewClient(&plc.ClientArgs{
395 H: h,
396 Service: "https://plc.directory",
397 PdsHostname: args.Hostname,
398 RotationKey: rkbytes,
399 })
400 if err != nil {
401 return nil, err
402 }
403
404 jwkbytes, err := os.ReadFile(args.JwkPath)
405 if err != nil {
406 return nil, err
407 }
408
409 key, err := helpers.ParseJWKFromBytes(jwkbytes)
410 if err != nil {
411 return nil, err
412 }
413
414 var pkey ecdsa.PrivateKey
415 if err := key.Raw(&pkey); err != nil {
416 return nil, err
417 }
418
419 oauthCli := &http.Client{
420 Timeout: 10 * time.Second,
421 }
422
423 var nonceSecret []byte
424 if args.NonceSecret != "" {
425 nonceSecret = []byte(args.NonceSecret)
426 } else {
427 maybeSecret, err := os.ReadFile("nonce.secret")
428 if err != nil && !os.IsNotExist(err) {
429 logger.Error("error attempting to read nonce secret", "error", err)
430 } else {
431 nonceSecret = maybeSecret
432 }
433 }
434
435 evtPersister, err := NewDbPersister(gdb, 72*time.Hour)
436 if err != nil {
437 return nil, fmt.Errorf("failed to create event persister: %w", err)
438 }
439
440 s := &Server{
441 http: h,
442 httpd: httpd,
443 echo: e,
444 logger: args.Logger,
445 db: dbw,
446 plcClient: plcClient,
447 privateKey: &pkey,
448 config: &config{
449 LogLevel: args.LogLevel,
450 Version: args.Version,
451 Did: args.Did,
452 Hostname: args.Hostname,
453 ContactEmail: args.ContactEmail,
454 EnforcePeering: false,
455 Relays: args.Relays,
456 AdminPassword: args.AdminPassword,
457 RequireInvite: args.RequireInvite,
458 SmtpName: args.SmtpName,
459 SmtpEmail: args.SmtpEmail,
460 SessionCookieKey: args.SessionCookieKey,
461 BlockstoreVariant: args.BlockstoreVariant,
462 FallbackProxy: args.FallbackProxy,
463 PushBasedEvents: args.PushBasedEvents,
464 SubscribeReposServiceURL: args.SubscribeReposServiceURL,
465 },
466 evtman: events.NewEventManager(evtPersister),
467 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
468
469 dbName: args.DbName,
470 dbType: dbType,
471 s3Config: args.S3Config,
472
473 oauthProvider: provider.NewProvider(provider.Args{
474 Hostname: args.Hostname,
475 ClientManagerArgs: client.ManagerArgs{
476 Cli: oauthCli,
477 Logger: args.Logger.With("component", "oauth-client-manager"),
478 },
479 DpopManagerArgs: dpop.ManagerArgs{
480 NonceSecret: nonceSecret,
481 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
482 OnNonceSecretCreated: func(newNonce []byte) {
483 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
484 logger.Error("error writing new nonce secret", "error", err)
485 }
486 },
487 Logger: args.Logger.With("component", "dpop-manager"),
488 Hostname: args.Hostname,
489 },
490 }),
491 }
492
493 s.loadTemplates()
494
495 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
496
497 // TODO: should validate these args
498 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
499 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
500 } else {
501 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
502 mail.From(s.config.SmtpEmail)
503 mail.FromName(s.config.SmtpName)
504
505 s.mail = mail
506 s.mailLk = &sync.Mutex{}
507 }
508
509 return s, nil
510}
511
512func (s *Server) addRoutes() {
513 // static
514 if s.config.Version == "dev" {
515 s.echo.Static("/static", "server/static")
516 } else {
517 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
518 }
519
520 // random stuff
521 s.echo.GET("/", s.handleRoot)
522 s.echo.GET("/xrpc/_health", s.handleHealth)
523 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
524 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid)
525 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
526 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
527 s.echo.GET("/robots.txt", s.handleRobots)
528
529 // public
530 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
531 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
532 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
533 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
534 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)
535
536 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
537 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
538 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
539 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
540 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
541 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
542 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
543 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
544 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
545 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
546 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
547 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
548
549 // labels
550 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)
551
552 // account
553 s.echo.GET("/account", s.handleAccount)
554 s.echo.POST("/account/revoke", s.handleAccountRevoke)
555 s.echo.GET("/account/signin", s.handleAccountSigninGet)
556 s.echo.POST("/account/signin", s.handleAccountSigninPost)
557 s.echo.GET("/account/signout", s.handleAccountSignout)
558
559 // oauth account
560 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
561 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
562 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
563
564 // oauth authorization
565 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
566 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
567
568 // authed
569 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
570 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
571 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
572 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
573 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
574 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
575 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
576 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
577 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
578 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
579 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
580 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
581 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
582 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
583 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
584 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
585 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
586 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
587 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
588 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)
589
590 // repo
591 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
592 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
593 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
594 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
595 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
596 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
597 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
598
599 // stupid silly endpoints
600 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
601 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
602 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
603 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
604 // admin routes
605 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
606 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
607
608 // are there any routes that we should be allowing without auth? i dont think so but idk
609 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
610 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
611}
612
613func (s *Server) Serve(ctx context.Context) error {
614 logger := s.logger.With("name", "Serve")
615
616 s.addRoutes()
617
618 logger.Info("migrating...")
619
620 s.db.AutoMigrate(
621 &models.Actor{},
622 &models.Repo{},
623 &models.InviteCode{},
624 &models.Token{},
625 &models.RefreshToken{},
626 &models.Block{},
627 &models.Record{},
628 &models.Blob{},
629 &models.BlobPart{},
630 &models.ReservedKey{},
631 &provider.OauthToken{},
632 &provider.OauthAuthorizationRequest{},
633 )
634
635 logger.Info("starting cocoon")
636
637 go func() {
638 if err := s.httpd.ListenAndServe(); err != nil {
639 panic(err)
640 }
641 }()
642
643 go s.backupRoutine()
644
645 go func() {
646 if err := s.requestCrawl(ctx); err != nil {
647 logger.Error("error requesting crawls", "err", err)
648 }
649 }()
650
651 if s.config.PushBasedEvents {
652 slog.Info("pushed based events enabled")
653 go func() {
654 if err := s.emmitEvents(ctx); err != nil {
655 logger.Error("error emitting events", "err", err)
656 }
657 }()
658 }
659
660 <-ctx.Done()
661
662 fmt.Println("shut down")
663
664 return nil
665}
666
667func (s *Server) requestCrawl(ctx context.Context) error {
668 logger := s.logger.With("component", "request-crawl")
669 s.requestCrawlMu.Lock()
670 defer s.requestCrawlMu.Unlock()
671
672 logger.Info("requesting crawl with configured relays")
673
674 if time.Since(s.lastRequestCrawl) <= 1*time.Minute {
675 return fmt.Errorf("a crawl request has already been made within the last minute")
676 }
677
678 for _, relay := range s.config.Relays {
679 logger := logger.With("relay", relay)
680 logger.Info("requesting crawl from relay")
681 cli := xrpc.Client{Host: relay}
682 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
683 Hostname: s.config.Hostname,
684 }); err != nil {
685 logger.Error("error requesting crawl", "err", err)
686 } else {
687 logger.Info("crawl requested successfully")
688 }
689 }
690
691 s.lastRequestCrawl = time.Now()
692
693 return nil
694}
695
696func (s *Server) doBackup() {
697 logger := s.logger.With("name", "doBackup")
698
699 if s.dbType == "postgres" || s.dbType == "turso" {
700 logger.Info("skipping S3 backup - PostgreSQL or Turso backups should be handled externally (pg_dump, managed database backups, etc.)")
701 return
702 }
703
704 start := time.Now()
705
706 logger.Info("beginning backup to s3...")
707
708 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano))
709 defer os.Remove(tmpFile)
710
711 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil {
712 logger.Error("error creating tmp backup file", "err", err)
713 return
714 }
715
716 backupData, err := os.ReadFile(tmpFile)
717 if err != nil {
718 logger.Error("error reading tmp backup file", "err", err)
719 return
720 }
721
722 logger.Info("sending to s3...")
723
724 currTime := time.Now().Format("2006-01-02_15-04-05")
725 key := "cocoon-backup-" + currTime + ".db"
726
727 config := &aws.Config{
728 Region: aws.String(s.s3Config.Region),
729 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
730 }
731
732 if s.s3Config.Endpoint != "" {
733 config.Endpoint = aws.String(s.s3Config.Endpoint)
734 config.S3ForcePathStyle = aws.Bool(true)
735 }
736
737 sess, err := session.NewSession(config)
738 if err != nil {
739 logger.Error("error creating s3 session", "err", err)
740 return
741 }
742
743 svc := s3.New(sess)
744
745 if _, err := svc.PutObject(&s3.PutObjectInput{
746 Bucket: aws.String(s.s3Config.Bucket),
747 Key: aws.String(key),
748 Body: bytes.NewReader(backupData),
749 }); err != nil {
750 logger.Error("error uploading file to s3", "err", err)
751 return
752 }
753
754 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds())
755
756 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644)
757}
758
759func (s *Server) backupRoutine() {
760 logger := s.logger.With("name", "backupRoutine")
761
762 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
763 return
764 }
765
766 if s.s3Config.Region == "" {
767 logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
768 return
769 }
770
771 if s.s3Config.Bucket == "" {
772 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
773 return
774 }
775
776 if s.s3Config.AccessKey == "" {
777 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
778 return
779 }
780
781 if s.s3Config.SecretKey == "" {
782 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
783 return
784 }
785
786 shouldBackupNow := false
787 lastBackupStr, err := os.ReadFile("last-backup.txt")
788 if err != nil {
789 shouldBackupNow = true
790 } else {
791 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr))
792 if err != nil {
793 shouldBackupNow = true
794 } else if time.Since(lastBackup).Seconds() > 3600 {
795 shouldBackupNow = true
796 }
797 }
798
799 if shouldBackupNow {
800 go s.doBackup()
801 }
802
803 ticker := time.NewTicker(time.Hour)
804 for range ticker.C {
805 go s.doBackup()
806 }
807}
808
809func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
810 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
811 return err
812 }
813
814 return nil
815}