forked from
willdot.net/cocoon
A fork of the Cocoon PDS that is being made more distributed.
1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "database/sql"
8 "embed"
9 "errors"
10 "fmt"
11 "io"
12 "log/slog"
13 "net/http"
14 "net/smtp"
15 "os"
16 "path/filepath"
17 "sync"
18 "text/template"
19 "time"
20
21 "github.com/aws/aws-sdk-go/aws"
22 "github.com/aws/aws-sdk-go/aws/credentials"
23 "github.com/aws/aws-sdk-go/aws/session"
24 "github.com/aws/aws-sdk-go/service/s3"
25 "github.com/bluesky-social/indigo/api/atproto"
26 "github.com/bluesky-social/indigo/atproto/syntax"
27 "github.com/bluesky-social/indigo/events"
28 "github.com/bluesky-social/indigo/util"
29 "github.com/bluesky-social/indigo/xrpc"
30 "github.com/domodwyer/mailyak/v3"
31 "github.com/go-playground/validator"
32 "github.com/gorilla/sessions"
33 "github.com/haileyok/cocoon/identity"
34 "github.com/haileyok/cocoon/internal/db"
35 "github.com/haileyok/cocoon/internal/helpers"
36 "github.com/haileyok/cocoon/models"
37 "github.com/haileyok/cocoon/oauth/client"
38 "github.com/haileyok/cocoon/oauth/constants"
39 "github.com/haileyok/cocoon/oauth/dpop"
40 "github.com/haileyok/cocoon/oauth/provider"
41 "github.com/haileyok/cocoon/plc"
42 "github.com/ipfs/go-cid"
43 "github.com/labstack/echo-contrib/echoprometheus"
44 echo_session "github.com/labstack/echo-contrib/session"
45 "github.com/labstack/echo/v4"
46 "github.com/labstack/echo/v4/middleware"
47 slogecho "github.com/samber/slog-echo"
48 _ "github.com/tursodatabase/libsql-client-go/libsql"
49 "gorm.io/driver/postgres"
50 "gorm.io/driver/sqlite"
51 "gorm.io/gorm"
52)
53
54const (
55 AccountSessionMaxAge = 30 * 24 * time.Hour // one week
56)
57
// S3Config holds S3 (or S3-compatible) storage settings. In this file only
// the backup-related fields are read: backupRoutine requires Region, Bucket,
// AccessKey and SecretKey, and doBackup uses Endpoint when it is set.
type S3Config struct {
	// BackupsEnabled turns on the periodic SQLite backup upload;
	// BlobstoreEnabled is not read in this file.
	BackupsEnabled, BlobstoreEnabled bool

	// Endpoint overrides the AWS endpoint (path-style addressing is forced
	// when set). CDNUrl is not read in this file.
	Endpoint, Region, Bucket, AccessKey, SecretKey, CDNUrl string
}
68
69type Server struct {
70 http *http.Client
71 httpd *http.Server
72 mail *mailyak.MailYak
73 mailLk *sync.Mutex
74 echo *echo.Echo
75 db *db.DB
76 plcClient *plc.Client
77 logger *slog.Logger
78 config *config
79 privateKey *ecdsa.PrivateKey
80 repoman *RepoMan
81 oauthProvider *provider.Provider
82 evtman *events.EventManager
83 passport *identity.Passport
84 fallbackProxy string
85
86 lastRequestCrawl time.Time
87 requestCrawlMu sync.Mutex
88
89 dbName string
90 dbType string
91 s3Config *S3Config
92}
93
94type Args struct {
95 Logger *slog.Logger
96
97 LogLevel slog.Level
98 Addr string
99 DbName string
100 DbType string
101 DatabaseURL string
102 TursoToken string
103 Version string
104 Did string
105 Hostname string
106 RotationKeyPath string
107 JwkPath string
108 ContactEmail string
109 Relays []string
110 AdminPassword string
111 RequireInvite bool
112
113 SmtpUser string
114 SmtpPass string
115 SmtpHost string
116 SmtpPort string
117 SmtpEmail string
118 SmtpName string
119
120 S3Config *S3Config
121
122 SessionSecret string
123 SessionCookieKey string
124
125 BlockstoreVariant BlockstoreVariant
126 FallbackProxy string
127
128 PushBasedEvents bool
129 SubscribeReposServiceURL string
130 NonceSecret string
131}
132
133type config struct {
134 LogLevel slog.Level
135 Version string
136 Did string
137 Hostname string
138 ContactEmail string
139 EnforcePeering bool
140 Relays []string
141 AdminPassword string
142 RequireInvite bool
143 SmtpEmail string
144 SmtpName string
145 SessionCookieKey string
146 BlockstoreVariant BlockstoreVariant
147 FallbackProxy string
148
149 PushBasedEvents bool
150 SubscribeReposServiceURL string
151}
152
153type CustomValidator struct {
154 validator *validator.Validate
155}
156
157type ValidationError struct {
158 error
159 Field string
160 Tag string
161}
162
163func (cv *CustomValidator) Validate(i any) error {
164 if err := cv.validator.Struct(i); err != nil {
165 var validateErrors validator.ValidationErrors
166 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
167 first := validateErrors[0]
168 return ValidationError{
169 error: err,
170 Field: first.Field(),
171 Tag: first.Tag(),
172 }
173 }
174
175 return err
176 }
177
178 return nil
179}
180
// templateFS embeds the files under templates/. loadTemplates parses the
// *.html files from it when config.Version is not "dev".
//
//go:embed templates/*
var templateFS embed.FS

// staticFS embeds the files under static/. addRoutes serves it under
// /static/* when config.Version is not "dev".
//
//go:embed static/*
var staticFS embed.FS
186
187type TemplateRenderer struct {
188 templates *template.Template
189 isDev bool
190 templatePath string
191}
192
// loadTemplates installs a TemplateRenderer on the echo instance.
// Non-dev builds parse the embedded templates once. Dev builds parse
// server/templates/*.html from disk, relative to the working directory, and
// mark the renderer to re-parse on every render. Parse failures panic via
// template.Must.
func (s *Server) loadTemplates() {
	if s.config.Version != "dev" {
		s.echo.Renderer = &TemplateRenderer{
			templates: template.Must(template.ParseFS(templateFS, "templates/*.html")),
			isDev:     false,
		}
		return
	}

	// The filepath.Abs error is deliberately ignored; the pattern is then
	// used as-is by ParseGlob.
	pattern, _ := filepath.Abs("server/templates/*.html")
	s.echo.Renderer = &TemplateRenderer{
		templates:    template.Must(template.ParseGlob(pattern)),
		isDev:        true,
		templatePath: pattern,
	}
}
210
// Render executes the named template with data and writes the output to w.
//
// In dev mode the templates are re-parsed from templatePath on every call, so
// on-disk edits show up without a restart. The freshly parsed set stays local
// to this call rather than being stored back into t.templates. Requests are
// rendered concurrently, and writing the shared field from several
// goroutines would be a data race.
//
// When data is a map[string]any, a "reverse" entry bound to the echo
// instance's Reverse method is added to it first. This mutates the caller's
// map.
//
// NOTE(review): the templates are parsed with text/template (per the file's
// imports), which does not HTML-escape interpolated values even though these
// are .html files. Consider html/template if any rendered data is
// user-controlled.
func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
	tmpl := t.templates
	if t.isDev {
		fresh, err := template.ParseGlob(t.templatePath)
		if err != nil {
			return err
		}
		tmpl = fresh
	}

	if viewContext, isMap := data.(map[string]any); isMap {
		viewContext["reverse"] = c.Echo().Reverse
	}

	return tmpl.ExecuteTemplate(w, name, data)
}
226
227type filteredHandler struct {
228 level slog.Level
229 handler slog.Handler
230}
231
232func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool {
233 return level >= h.level && h.handler.Enabled(ctx, level)
234}
235
236func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error {
237 return h.handler.Handle(ctx, r)
238}
239
240func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
241 return &filteredHandler{level: h.level, handler: h.handler.WithAttrs(attrs)}
242}
243
244func (h *filteredHandler) WithGroup(name string) slog.Handler {
245 return &filteredHandler{level: h.level, handler: h.handler.WithGroup(name)}
246}
247
// New validates args, opens the configured database, loads the PLC rotation
// key and the JWK signing key, and assembles a Server with its echo router,
// OAuth provider, event manager and (when fully configured) SMTP mailer.
// Routes are not registered and nothing listens until Serve is called.
//
// It returns an error for a missing required setting (addr, db name, DID,
// contact email, hostname, admin password, session secret), an unparsable
// DID, a database connection failure, or an unreadable or invalid key file.
func New(args *Args) (*Server, error) {
	if args.Logger == nil {
		args.Logger = slog.Default()
	}

	// slog.LevelInfo is 0, so an explicit Info level cannot be told apart
	// from "unset"; in that case the wrapped handler's own level applies.
	if args.LogLevel != 0 {
		args.Logger = slog.New(&filteredHandler{
			level:   args.LogLevel,
			handler: args.Logger.Handler(),
		})
	}

	logger := args.Logger.With("name", "New")

	if args.Addr == "" {
		return nil, fmt.Errorf("addr must be set")
	}

	if args.DbName == "" {
		return nil, fmt.Errorf("db name must be set")
	}

	if args.Did == "" {
		return nil, fmt.Errorf("cocoon did must be set")
	}

	if args.ContactEmail == "" {
		return nil, fmt.Errorf("cocoon contact email is required")
	}

	if _, err := syntax.ParseDID(args.Did); err != nil {
		return nil, fmt.Errorf("error parsing cocoon did: %w", err)
	}

	if args.Hostname == "" {
		return nil, fmt.Errorf("cocoon hostname must be set")
	}

	if args.AdminPassword == "" {
		return nil, fmt.Errorf("admin password must be set")
	}

	// Reported like every other missing required setting instead of
	// panicking inside a constructor that already returns errors.
	if args.SessionSecret == "" {
		return nil, fmt.Errorf("session secret must be set")
	}

	e := echo.New()

	e.Pre(middleware.RemoveTrailingSlash())
	e.Pre(slogecho.New(args.Logger.With("component", "slogecho")))
	e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
	e.Use(echoprometheus.NewMiddleware("cocoon"))
	e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
		AllowOrigins:     []string{"*"},
		AllowHeaders:     []string{"*"},
		AllowMethods:     []string{"*"},
		AllowCredentials: true,
		MaxAge:           100_000_000,
	}))

	// Custom validation tags for atproto identifiers: a field is valid when
	// the matching indigo syntax parser accepts its string value.
	vdtor := validator.New()
	atprotoTags := []struct {
		tag   string
		parse func(string) error
	}{
		{"atproto-handle", func(v string) error {
			_, err := syntax.ParseHandle(v)
			return err
		}},
		{"atproto-did", func(v string) error {
			_, err := syntax.ParseDID(v)
			return err
		}},
		{"atproto-rkey", func(v string) error {
			_, err := syntax.ParseRecordKey(v)
			return err
		}},
		{"atproto-nsid", func(v string) error {
			_, err := syntax.ParseNSID(v)
			return err
		}},
	}
	for _, at := range atprotoTags {
		parse := at.parse // per-iteration copy for the closure (pre-Go 1.22 loop semantics)
		if err := vdtor.RegisterValidation(at.tag, func(fl validator.FieldLevel) bool {
			return parse(fl.Field().String()) == nil
		}); err != nil {
			return nil, fmt.Errorf("registering %q validator: %w", at.tag, err)
		}
	}

	e.Validator = &CustomValidator{validator: vdtor}

	httpd := &http.Server{
		Addr:    args.Addr,
		Handler: e,
		// shitty defaults but okay for now, needed for import repo
		ReadTimeout:  5 * time.Minute,
		WriteTimeout: 5 * time.Minute,
		IdleTimeout:  5 * time.Minute,
	}

	dbType := args.DbType
	if dbType == "" {
		dbType = "sqlite"
	}

	var gdb *gorm.DB
	gormCfg := gorm.Config{
		TranslateError: true,
	}
	var err error
	switch dbType {
	case "postgres":
		if args.DatabaseURL == "" {
			return nil, fmt.Errorf("database-url must be set when using postgres")
		}
		gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gormCfg)
		if err != nil {
			return nil, fmt.Errorf("failed to connect to postgres: %w", err)
		}
		logger.Info("connected to PostgreSQL database")
	case "turso":
		if args.DatabaseURL == "" {
			return nil, fmt.Errorf("database-url must be set when using turso")
		}
		// Named sqlDB (not db) so the internal db package is not shadowed,
		// and assigned with = so the outer err is used and checked.
		var sqlDB *sql.DB
		sqlDB, err = sql.Open("libsql", fmt.Sprintf("%s?authToken=%s", args.DatabaseURL, args.TursoToken))
		if err != nil {
			return nil, fmt.Errorf("failed to open turso database: %w", err)
		}
		gdb, err = gorm.Open(sqlite.New(sqlite.Config{
			Conn: sqlDB,
		}), &gormCfg)
		if err != nil {
			sqlDB.Close()
			return nil, fmt.Errorf("failed to connect to turso: %w", err)
		}
		logger.Info("connected to Turso database")
	default:
		gdb, err = gorm.Open(sqlite.Open(args.DbName), &gormCfg)
		if err != nil {
			return nil, fmt.Errorf("failed to open sqlite database: %w", err)
		}
		// Tuning only: a failure is logged, not fatal.
		if err := gdb.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
			logger.Warn("failed to set sqlite journal_mode=WAL", "error", err)
		}
		if err := gdb.Exec("PRAGMA synchronous=NORMAL").Error; err != nil {
			logger.Warn("failed to set sqlite synchronous=NORMAL", "error", err)
		}

		logger.Info("connected to SQLite database", "path", args.DbName)
	}
	dbw := db.NewDB(gdb)

	rkbytes, err := os.ReadFile(args.RotationKeyPath)
	if err != nil {
		return nil, fmt.Errorf("reading rotation key %q: %w", args.RotationKeyPath, err)
	}

	h := util.RobustHTTPClient()

	plcClient, err := plc.NewClient(&plc.ClientArgs{
		H:           h,
		Service:     "https://plc.directory",
		PdsHostname: args.Hostname,
		RotationKey: rkbytes,
	})
	if err != nil {
		return nil, fmt.Errorf("creating plc client: %w", err)
	}

	jwkbytes, err := os.ReadFile(args.JwkPath)
	if err != nil {
		return nil, fmt.Errorf("reading jwk %q: %w", args.JwkPath, err)
	}

	key, err := helpers.ParseJWKFromBytes(jwkbytes)
	if err != nil {
		return nil, fmt.Errorf("parsing jwk: %w", err)
	}

	var pkey ecdsa.PrivateKey
	if err := key.Raw(&pkey); err != nil {
		return nil, fmt.Errorf("extracting ecdsa private key from jwk: %w", err)
	}

	oauthCli := &http.Client{
		Timeout: 10 * time.Second,
	}

	// Prefer the explicit secret; otherwise reuse one persisted by a previous
	// run. A missing file leaves nonceSecret nil.
	var nonceSecret []byte
	if args.NonceSecret != "" {
		nonceSecret = []byte(args.NonceSecret)
	} else {
		maybeSecret, err := os.ReadFile("nonce.secret")
		if err != nil && !os.IsNotExist(err) {
			logger.Error("error attempting to read nonce secret", "error", err)
		} else {
			nonceSecret = maybeSecret
		}
	}

	evtPersister, err := NewDbPersister(gdb, 72*time.Hour)
	if err != nil {
		return nil, fmt.Errorf("failed to create event persister: %w", err)
	}

	s := &Server{
		http:       h,
		httpd:      httpd,
		echo:       e,
		logger:     args.Logger,
		db:         dbw,
		plcClient:  plcClient,
		privateKey: &pkey,
		config: &config{
			LogLevel:                 args.LogLevel,
			Version:                  args.Version,
			Did:                      args.Did,
			Hostname:                 args.Hostname,
			ContactEmail:             args.ContactEmail,
			EnforcePeering:           false,
			Relays:                   args.Relays,
			AdminPassword:            args.AdminPassword,
			RequireInvite:            args.RequireInvite,
			SmtpName:                 args.SmtpName,
			SmtpEmail:                args.SmtpEmail,
			SessionCookieKey:         args.SessionCookieKey,
			BlockstoreVariant:        args.BlockstoreVariant,
			FallbackProxy:            args.FallbackProxy,
			PushBasedEvents:          args.PushBasedEvents,
			SubscribeReposServiceURL: args.SubscribeReposServiceURL,
		},
		evtman:   events.NewEventManager(evtPersister),
		passport: identity.NewPassport(h, identity.NewMemCache(10_000)),

		dbName:   args.DbName,
		dbType:   dbType,
		s3Config: args.S3Config,

		oauthProvider: provider.NewProvider(provider.Args{
			Hostname: args.Hostname,
			ClientManagerArgs: client.ManagerArgs{
				Cli:    oauthCli,
				Logger: args.Logger.With("component", "oauth-client-manager"),
			},
			DpopManagerArgs: dpop.ManagerArgs{
				NonceSecret:           nonceSecret,
				NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
				OnNonceSecretCreated: func(newNonce []byte) {
					// 0600: this is secret material and should not be readable
					// by other users on the host.
					if err := os.WriteFile("nonce.secret", newNonce, 0600); err != nil {
						logger.Error("error writing new nonce secret", "error", err)
					}
				},
				Logger:   args.Logger.With("component", "dpop-manager"),
				Hostname: args.Hostname,
			},
		}),
	}

	s.loadTemplates()

	s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it

	// TODO: should validate these args
	// Mail is all-or-nothing: any missing SMTP setting leaves s.mail nil.
	if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
		args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
	} else {
		mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
		mail.From(s.config.SmtpEmail)
		mail.FromName(s.config.SmtpName)

		s.mail = mail
		s.mailLk = &sync.Mutex{}
	}

	return s, nil
}
514
// addRoutes registers every HTTP route on the echo instance. In order:
//   - static assets
//   - well-known and health endpoints
//   - public XRPC methods
//   - account and OAuth pages
//   - session-authenticated XRPC methods
//   - admin-only invite-code methods
//   - a /xrpc/* proxy for any other XRPC path
//
// Authenticated routes get the legacy-session middleware and then the
// OAuth-session middleware, in that order.
func (s *Server) addRoutes() {
	e := s.echo
	authed := []echo.MiddlewareFunc{s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware}

	// static: embedded files in release builds, the on-disk directory in dev
	if s.config.Version != "dev" {
		e.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
	} else {
		e.Static("/static", "server/static")
	}

	// random stuff
	e.GET("/", s.handleRoot)
	e.GET("/xrpc/_health", s.handleHealth)
	e.GET("/.well-known/did.json", s.handleWellKnown)
	e.GET("/.well-known/atproto-did", s.handleAtprotoDid)
	e.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
	e.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
	e.GET("/robots.txt", s.handleRobots)

	// public
	e.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
	e.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
	e.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
	e.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
	e.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)

	e.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
	e.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
	e.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
	e.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
	e.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
	e.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
	e.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
	e.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
	e.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
	e.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
	e.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
	e.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)

	// labels
	e.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)

	// account
	e.GET("/account", s.handleAccount)
	e.POST("/account/revoke", s.handleAccountRevoke)
	e.GET("/account/signin", s.handleAccountSigninGet)
	e.POST("/account/signin", s.handleAccountSigninPost)
	e.GET("/account/signout", s.handleAccountSignout)

	// oauth account
	e.GET("/oauth/jwks", s.handleOauthJwks)
	e.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
	e.POST("/oauth/authorize", s.handleOauthAuthorizePost)

	// oauth authorization
	e.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
	e.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)

	// authed
	e.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, authed...)
	e.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, authed...)
	e.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, authed...)
	e.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, authed...)
	e.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, authed...)
	e.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, authed...)
	e.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, authed...)
	e.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, authed...)
	e.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, authed...)
	e.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, authed...)
	e.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
	e.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, authed...)
	// NOTE(review): resetPassword requires a session while requestPasswordReset
	// does not. Confirm that a signed-out user can actually complete a reset.
	e.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, authed...)
	e.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, authed...)
	e.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, authed...)
	e.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, authed...)
	e.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, authed...)
	e.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, authed...)
	e.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, authed...)
	e.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)

	// repo
	e.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, authed...)
	e.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, authed...)
	e.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, authed...)
	e.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, authed...)
	e.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, authed...)
	e.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, authed...)
	e.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, authed...)

	// stupid silly endpoints
	e.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, authed...)
	e.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, authed...)
	e.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, authed...)
	e.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, authed...)

	// admin routes
	e.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
	e.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)

	// are there any routes that we should be allowing without auth? i dont think so but idk
	e.GET("/xrpc/*", s.handleProxy, authed...)
	e.POST("/xrpc/*", s.handleProxy, authed...)
}
615
// Serve registers the routes, runs the schema migrations, and starts the
// HTTP server plus these background workers:
//   - periodic S3 backups
//   - relay crawl requests
//   - push-based event emission, when enabled
//
// It then blocks until ctx is cancelled or the HTTP listener fails.
//
// On cancellation the HTTP server is shut down gracefully, bounded by a
// 10-second timeout, and nil is returned. If the listener stops for any
// reason other than that shutdown (for example, the address is already in
// use), Serve returns the error instead of panicking.
func (s *Server) Serve(ctx context.Context) error {
	logger := s.logger.With("name", "Serve")

	s.addRoutes()

	logger.Info("migrating...")

	// NOTE(review): any result of AutoMigrate is discarded here. If
	// db.DB.AutoMigrate returns an error, migration failures go unnoticed;
	// confirm its signature in internal/db.
	s.db.AutoMigrate(
		&models.Actor{},
		&models.Repo{},
		&models.InviteCode{},
		&models.Token{},
		&models.RefreshToken{},
		&models.Block{},
		&models.Record{},
		&models.Blob{},
		&models.BlobPart{},
		&models.ReservedKey{},
		&provider.OauthToken{},
		&provider.OauthAuthorizationRequest{},
	)

	logger.Info("starting cocoon")

	// Buffered so the listener goroutine can always deliver its error and exit,
	// even if Serve is no longer receiving.
	listenErr := make(chan error, 1)
	go func() {
		// ErrServerClosed is the expected result of the Shutdown call below.
		if err := s.httpd.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			listenErr <- err
		}
	}()

	go s.backupRoutine()

	go func() {
		if err := s.requestCrawl(ctx); err != nil {
			logger.Error("error requesting crawls", "err", err)
		}
	}()

	if s.config.PushBasedEvents {
		logger.Info("pushed based events enabled")
		go func() {
			if err := s.emmitEvents(ctx); err != nil {
				logger.Error("error emitting events", "err", err)
			}
		}()
	}

	select {
	case <-ctx.Done():
	case err := <-listenErr:
		return fmt.Errorf("http server stopped: %w", err)
	}

	shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := s.httpd.Shutdown(shutdownCtx); err != nil {
		logger.Error("error shutting down http server", "err", err)
	}

	logger.Info("shut down")

	return nil
}
669
// requestCrawl asks every configured relay to crawl this server's hostname.
//
// Calls are serialized by requestCrawlMu and limited to one per minute. A
// call within a minute of the previous one returns an error without
// contacting any relay. Individual relay failures are logged and skipped;
// they do not make the call fail, and the rate-limit timestamp is updated
// either way.
func (s *Server) requestCrawl(ctx context.Context) error {
	crawlLogger := s.logger.With("component", "request-crawl")

	s.requestCrawlMu.Lock()
	defer s.requestCrawlMu.Unlock()

	crawlLogger.Info("requesting crawl with configured relays")

	if elapsed := time.Since(s.lastRequestCrawl); elapsed <= time.Minute {
		return fmt.Errorf("a crawl request has already been made within the last minute")
	}

	for _, relay := range s.config.Relays {
		relayLogger := crawlLogger.With("relay", relay)
		relayLogger.Info("requesting crawl from relay")

		err := atproto.SyncRequestCrawl(ctx, &xrpc.Client{Host: relay}, &atproto.SyncRequestCrawl_Input{
			Hostname: s.config.Hostname,
		})
		if err != nil {
			relayLogger.Error("error requesting crawl", "err", err)
			continue
		}

		relayLogger.Info("crawl requested successfully")
	}

	s.lastRequestCrawl = time.Now()
	return nil
}
698
// doBackup snapshots the SQLite database with VACUUM INTO a temporary file
// and uploads the snapshot to the configured S3 bucket under a timestamped
// key. On success it records the completion time in last-backup.txt.
// Postgres and Turso deployments are skipped. Failures are logged, not
// returned.
//
// It dereferences s.s3Config without a nil check; backupRoutine only calls
// it after verifying the S3 settings.
func (s *Server) doBackup() {
	logger := s.logger.With("name", "doBackup")

	if s.dbType == "postgres" || s.dbType == "turso" {
		logger.Info("skipping S3 backup - PostgreSQL or Turso backups should be handled externally (pg_dump, managed database backups, etc.)")
		return
	}

	start := time.Now()

	logger.Info("beginning backup to s3...")

	// os.TempDir honours $TMPDIR rather than assuming /tmp.
	tmpFile := filepath.Join(os.TempDir(), fmt.Sprintf("cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano)))
	defer os.Remove(tmpFile)

	// The destination is bound as a parameter (VACUUM INTO accepts any string
	// expression) rather than spliced into the SQL text, since the temp dir
	// comes from the environment.
	if err := s.db.Client().Exec("VACUUM INTO ?", tmpFile).Error; err != nil {
		logger.Error("error creating tmp backup file", "err", err)
		return
	}

	// NOTE(review): the whole snapshot is held in memory for the upload. An
	// *os.File also satisfies PutObjectInput.Body (io.ReadSeeker) if streaming
	// is preferred for large databases.
	backupData, err := os.ReadFile(tmpFile)
	if err != nil {
		logger.Error("error reading tmp backup file", "err", err)
		return
	}

	logger.Info("sending to s3...")

	currTime := time.Now().Format("2006-01-02_15-04-05")
	key := "cocoon-backup-" + currTime + ".db"

	// Named awsCfg so it does not shadow the package-level config type.
	awsCfg := &aws.Config{
		Region:      aws.String(s.s3Config.Region),
		Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
	}

	if s.s3Config.Endpoint != "" {
		awsCfg.Endpoint = aws.String(s.s3Config.Endpoint)
		awsCfg.S3ForcePathStyle = aws.Bool(true)
	}

	sess, err := session.NewSession(awsCfg)
	if err != nil {
		logger.Error("error creating s3 session", "err", err)
		return
	}

	svc := s3.New(sess)

	if _, err := svc.PutObject(&s3.PutObjectInput{
		Bucket: aws.String(s.s3Config.Bucket),
		Key:    aws.String(key),
		Body:   bytes.NewReader(backupData),
	}); err != nil {
		logger.Error("error uploading file to s3", "err", err)
		return
	}

	logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds())

	// backupRoutine reads this timestamp at startup to decide whether a backup
	// is overdue.
	if err := os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644); err != nil {
		logger.Error("error recording last backup time", "err", err)
	}
}
761
// backupRoutine schedules S3 backups when they are enabled and fully
// configured. If any required setting is missing, it logs one warning and
// returns. Otherwise it starts a backup immediately when last-backup.txt is
// missing, unparsable, or more than an hour old, and then starts one every
// hour. Each backup runs in its own goroutine.
//
// The hourly loop has no stop signal, so the goroutine running this lives for
// the rest of the process.
func (s *Server) backupRoutine() {
	logger := s.logger.With("name", "backupRoutine")

	if s.s3Config == nil || !s.s3Config.BackupsEnabled {
		return
	}

	// Checked in order; the first empty setting stops the routine.
	requirements := []struct {
		value   string
		warning string
	}{
		{s.s3Config.Region, "no s3 region configured but backups are enabled. backups will not run."},
		{s.s3Config.Bucket, "no s3 bucket configured but backups are enabled. backups will not run."},
		{s.s3Config.AccessKey, "no s3 access key configured but backups are enabled. backups will not run."},
		{s.s3Config.SecretKey, "no s3 secret key configured but backups are enabled. backups will not run."},
	}
	for _, req := range requirements {
		if req.value == "" {
			logger.Warn(req.warning)
			return
		}
	}

	overdue := func() bool {
		raw, err := os.ReadFile("last-backup.txt")
		if err != nil {
			return true
		}
		last, err := time.Parse(time.RFC3339Nano, string(raw))
		if err != nil {
			return true
		}
		return time.Since(last) > time.Hour
	}

	if overdue() {
		go s.doBackup()
	}

	ticker := time.NewTicker(time.Hour)
	for range ticker.C {
		go s.doBackup()
	}
}
811
812func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
813 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
814 return err
815 }
816
817 return nil
818}