A fork of the Cocoon PDS that is being made more distributed.
0

Configure Feed

Select the types of activity you want to include in your feed.

1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "database/sql" 8 "embed" 9 "errors" 10 "fmt" 11 "io" 12 "log/slog" 13 "net/http" 14 "net/smtp" 15 "os" 16 "path/filepath" 17 "sync" 18 "text/template" 19 "time" 20 21 "github.com/aws/aws-sdk-go/aws" 22 "github.com/aws/aws-sdk-go/aws/credentials" 23 "github.com/aws/aws-sdk-go/aws/session" 24 "github.com/aws/aws-sdk-go/service/s3" 25 "github.com/bluesky-social/indigo/api/atproto" 26 "github.com/bluesky-social/indigo/atproto/syntax" 27 "github.com/bluesky-social/indigo/events" 28 "github.com/bluesky-social/indigo/util" 29 "github.com/bluesky-social/indigo/xrpc" 30 "github.com/domodwyer/mailyak/v3" 31 "github.com/go-playground/validator" 32 "github.com/gorilla/sessions" 33 "github.com/haileyok/cocoon/identity" 34 "github.com/haileyok/cocoon/internal/db" 35 "github.com/haileyok/cocoon/internal/helpers" 36 "github.com/haileyok/cocoon/models" 37 "github.com/haileyok/cocoon/oauth/client" 38 "github.com/haileyok/cocoon/oauth/constants" 39 "github.com/haileyok/cocoon/oauth/dpop" 40 "github.com/haileyok/cocoon/oauth/provider" 41 "github.com/haileyok/cocoon/plc" 42 "github.com/ipfs/go-cid" 43 "github.com/labstack/echo-contrib/echoprometheus" 44 echo_session "github.com/labstack/echo-contrib/session" 45 "github.com/labstack/echo/v4" 46 "github.com/labstack/echo/v4/middleware" 47 slogecho "github.com/samber/slog-echo" 48 _ "github.com/tursodatabase/libsql-client-go/libsql" 49 "gorm.io/driver/postgres" 50 "gorm.io/driver/sqlite" 51 "gorm.io/gorm" 52) 53 54const ( 55 AccountSessionMaxAge = 30 * 24 * time.Hour // one week 56) 57 58type S3Config struct { 59 BackupsEnabled bool 60 BlobstoreEnabled bool 61 Endpoint string 62 Region string 63 Bucket string 64 AccessKey string 65 SecretKey string 66 CDNUrl string 67} 68 69type Server struct { 70 http *http.Client 71 httpd *http.Server 72 mail *mailyak.MailYak 73 mailLk *sync.Mutex 74 echo *echo.Echo 75 db *db.DB 76 plcClient *plc.Client 77 logger *slog.Logger 78 config 
*config 79 privateKey *ecdsa.PrivateKey 80 repoman *RepoMan 81 oauthProvider *provider.Provider 82 evtman *events.EventManager 83 passport *identity.Passport 84 fallbackProxy string 85 86 lastRequestCrawl time.Time 87 requestCrawlMu sync.Mutex 88 89 dbName string 90 dbType string 91 s3Config *S3Config 92} 93 94type Args struct { 95 Logger *slog.Logger 96 97 LogLevel slog.Level 98 Addr string 99 DbName string 100 DbType string 101 DatabaseURL string 102 TursoToken string 103 Version string 104 Did string 105 Hostname string 106 RotationKeyPath string 107 JwkPath string 108 ContactEmail string 109 Relays []string 110 AdminPassword string 111 RequireInvite bool 112 113 SmtpUser string 114 SmtpPass string 115 SmtpHost string 116 SmtpPort string 117 SmtpEmail string 118 SmtpName string 119 120 S3Config *S3Config 121 122 SessionSecret string 123 SessionCookieKey string 124 125 BlockstoreVariant BlockstoreVariant 126 FallbackProxy string 127 128 PushBasedEvents bool 129 SubscribeReposServiceURL string 130 NonceSecret string 131} 132 133type config struct { 134 LogLevel slog.Level 135 Version string 136 Did string 137 Hostname string 138 ContactEmail string 139 EnforcePeering bool 140 Relays []string 141 AdminPassword string 142 RequireInvite bool 143 SmtpEmail string 144 SmtpName string 145 SessionCookieKey string 146 BlockstoreVariant BlockstoreVariant 147 FallbackProxy string 148 149 PushBasedEvents bool 150 SubscribeReposServiceURL string 151} 152 153type CustomValidator struct { 154 validator *validator.Validate 155} 156 157type ValidationError struct { 158 error 159 Field string 160 Tag string 161} 162 163func (cv *CustomValidator) Validate(i any) error { 164 if err := cv.validator.Struct(i); err != nil { 165 var validateErrors validator.ValidationErrors 166 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 167 first := validateErrors[0] 168 return ValidationError{ 169 error: err, 170 Field: first.Field(), 171 Tag: first.Tag(), 172 } 173 } 174 175 
return err 176 } 177 178 return nil 179} 180 181//go:embed templates/* 182var templateFS embed.FS 183 184//go:embed static/* 185var staticFS embed.FS 186 187type TemplateRenderer struct { 188 templates *template.Template 189 isDev bool 190 templatePath string 191} 192 193func (s *Server) loadTemplates() { 194 absPath, _ := filepath.Abs("server/templates/*.html") 195 if s.config.Version == "dev" { 196 tmpl := template.Must(template.ParseGlob(absPath)) 197 s.echo.Renderer = &TemplateRenderer{ 198 templates: tmpl, 199 isDev: true, 200 templatePath: absPath, 201 } 202 } else { 203 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 204 s.echo.Renderer = &TemplateRenderer{ 205 templates: tmpl, 206 isDev: false, 207 } 208 } 209} 210 211func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 212 if t.isDev { 213 tmpl, err := template.ParseGlob(t.templatePath) 214 if err != nil { 215 return err 216 } 217 t.templates = tmpl 218 } 219 220 if viewContext, isMap := data.(map[string]any); isMap { 221 viewContext["reverse"] = c.Echo().Reverse 222 } 223 224 return t.templates.ExecuteTemplate(w, name, data) 225} 226 227type filteredHandler struct { 228 level slog.Level 229 handler slog.Handler 230} 231 232func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool { 233 return level >= h.level && h.handler.Enabled(ctx, level) 234} 235 236func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error { 237 return h.handler.Handle(ctx, r) 238} 239 240func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler { 241 return &filteredHandler{level: h.level, handler: h.handler.WithAttrs(attrs)} 242} 243 244func (h *filteredHandler) WithGroup(name string) slog.Handler { 245 return &filteredHandler{level: h.level, handler: h.handler.WithGroup(name)} 246} 247 248func New(args *Args) (*Server, error) { 249 if args.Logger == nil { 250 args.Logger = slog.Default() 251 } 252 253 if args.LogLevel != 
0 { 254 args.Logger = slog.New(&filteredHandler{ 255 level: args.LogLevel, 256 handler: args.Logger.Handler(), 257 }) 258 } 259 260 logger := args.Logger.With("name", "New") 261 262 if args.Addr == "" { 263 return nil, fmt.Errorf("addr must be set") 264 } 265 266 if args.DbName == "" { 267 return nil, fmt.Errorf("db name must be set") 268 } 269 270 if args.Did == "" { 271 return nil, fmt.Errorf("cocoon did must be set") 272 } 273 274 if args.ContactEmail == "" { 275 return nil, fmt.Errorf("cocoon contact email is required") 276 } 277 278 if _, err := syntax.ParseDID(args.Did); err != nil { 279 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 280 } 281 282 if args.Hostname == "" { 283 return nil, fmt.Errorf("cocoon hostname must be set") 284 } 285 286 if args.AdminPassword == "" { 287 return nil, fmt.Errorf("admin password must be set") 288 } 289 290 if args.SessionSecret == "" { 291 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 292 } 293 294 e := echo.New() 295 296 e.Pre(middleware.RemoveTrailingSlash()) 297 e.Pre(slogecho.New(args.Logger.With("component", "slogecho"))) 298 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 299 e.Use(echoprometheus.NewMiddleware("cocoon")) 300 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 301 AllowOrigins: []string{"*"}, 302 AllowHeaders: []string{"*"}, 303 AllowMethods: []string{"*"}, 304 AllowCredentials: true, 305 MaxAge: 100_000_000, 306 })) 307 308 vdtor := validator.New() 309 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 310 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 311 return false 312 } 313 return true 314 }) 315 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 316 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 317 return false 318 } 319 return true 320 }) 321 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 322 if _, err := 
syntax.ParseRecordKey(fl.Field().String()); err != nil { 323 return false 324 } 325 return true 326 }) 327 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 328 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 329 return false 330 } 331 return true 332 }) 333 334 e.Validator = &CustomValidator{validator: vdtor} 335 336 httpd := &http.Server{ 337 Addr: args.Addr, 338 Handler: e, 339 // shitty defaults but okay for now, needed for import repo 340 ReadTimeout: 5 * time.Minute, 341 WriteTimeout: 5 * time.Minute, 342 IdleTimeout: 5 * time.Minute, 343 } 344 345 dbType := args.DbType 346 if dbType == "" { 347 dbType = "sqlite" 348 } 349 350 var gdb *gorm.DB 351 gormCfg := gorm.Config{ 352 TranslateError: true, 353 } 354 var err error 355 switch dbType { 356 case "postgres": 357 if args.DatabaseURL == "" { 358 return nil, fmt.Errorf("database-url must be set when using postgres") 359 } 360 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gormCfg) 361 if err != nil { 362 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 363 } 364 logger.Info("connected to PostgreSQL database") 365 case "turso": 366 primaryUrl := args.DatabaseURL 367 authToken := args.TursoToken 368 369 db, err := sql.Open("libsql", fmt.Sprintf("%s?authToken=%s", primaryUrl, authToken)) 370 gdb, err = gorm.Open(sqlite.New(sqlite.Config{ 371 Conn: db, 372 }), &gormCfg) 373 if err != nil { 374 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 375 } 376 logger.Info("connected to PostgreSQL database") 377 378 default: 379 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gormCfg) 380 if err != nil { 381 return nil, fmt.Errorf("failed to open sqlite database: %w", err) 382 } 383 gdb.Exec("PRAGMA journal_mode=WAL") 384 gdb.Exec("PRAGMA synchronous=NORMAL") 385 386 logger.Info("connected to SQLite database", "path", args.DbName) 387 } 388 dbw := db.NewDB(gdb) 389 390 rkbytes, err := os.ReadFile(args.RotationKeyPath) 391 if err != nil { 
392 return nil, err 393 } 394 395 h := util.RobustHTTPClient() 396 397 plcClient, err := plc.NewClient(&plc.ClientArgs{ 398 H: h, 399 Service: "https://plc.directory", 400 PdsHostname: args.Hostname, 401 RotationKey: rkbytes, 402 }) 403 if err != nil { 404 return nil, err 405 } 406 407 jwkbytes, err := os.ReadFile(args.JwkPath) 408 if err != nil { 409 return nil, err 410 } 411 412 key, err := helpers.ParseJWKFromBytes(jwkbytes) 413 if err != nil { 414 return nil, err 415 } 416 417 var pkey ecdsa.PrivateKey 418 if err := key.Raw(&pkey); err != nil { 419 return nil, err 420 } 421 422 oauthCli := &http.Client{ 423 Timeout: 10 * time.Second, 424 } 425 426 var nonceSecret []byte 427 if args.NonceSecret != "" { 428 nonceSecret = []byte(args.NonceSecret) 429 } else { 430 maybeSecret, err := os.ReadFile("nonce.secret") 431 if err != nil && !os.IsNotExist(err) { 432 logger.Error("error attempting to read nonce secret", "error", err) 433 } else { 434 nonceSecret = maybeSecret 435 } 436 } 437 438 evtPersister, err := NewDbPersister(gdb, 72*time.Hour) 439 if err != nil { 440 return nil, fmt.Errorf("failed to create event persister: %w", err) 441 } 442 443 s := &Server{ 444 http: h, 445 httpd: httpd, 446 echo: e, 447 logger: args.Logger, 448 db: dbw, 449 plcClient: plcClient, 450 privateKey: &pkey, 451 config: &config{ 452 LogLevel: args.LogLevel, 453 Version: args.Version, 454 Did: args.Did, 455 Hostname: args.Hostname, 456 ContactEmail: args.ContactEmail, 457 EnforcePeering: false, 458 Relays: args.Relays, 459 AdminPassword: args.AdminPassword, 460 RequireInvite: args.RequireInvite, 461 SmtpName: args.SmtpName, 462 SmtpEmail: args.SmtpEmail, 463 SessionCookieKey: args.SessionCookieKey, 464 BlockstoreVariant: args.BlockstoreVariant, 465 FallbackProxy: args.FallbackProxy, 466 PushBasedEvents: args.PushBasedEvents, 467 SubscribeReposServiceURL: args.SubscribeReposServiceURL, 468 }, 469 evtman: events.NewEventManager(evtPersister), 470 passport: identity.NewPassport(h, 
identity.NewMemCache(10_000)), 471 472 dbName: args.DbName, 473 dbType: dbType, 474 s3Config: args.S3Config, 475 476 oauthProvider: provider.NewProvider(provider.Args{ 477 Hostname: args.Hostname, 478 ClientManagerArgs: client.ManagerArgs{ 479 Cli: oauthCli, 480 Logger: args.Logger.With("component", "oauth-client-manager"), 481 }, 482 DpopManagerArgs: dpop.ManagerArgs{ 483 NonceSecret: nonceSecret, 484 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 485 OnNonceSecretCreated: func(newNonce []byte) { 486 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 487 logger.Error("error writing new nonce secret", "error", err) 488 } 489 }, 490 Logger: args.Logger.With("component", "dpop-manager"), 491 Hostname: args.Hostname, 492 }, 493 }), 494 } 495 496 s.loadTemplates() 497 498 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 499 500 // TODO: should validate these args 501 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 502 args.Logger.Warn("not enough smtp args were provided. 
mailing will not work for your server.") 503 } else { 504 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 505 mail.From(s.config.SmtpEmail) 506 mail.FromName(s.config.SmtpName) 507 508 s.mail = mail 509 s.mailLk = &sync.Mutex{} 510 } 511 512 return s, nil 513} 514 515func (s *Server) addRoutes() { 516 // static 517 if s.config.Version == "dev" { 518 s.echo.Static("/static", "server/static") 519 } else { 520 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 521 } 522 523 // random stuff 524 s.echo.GET("/", s.handleRoot) 525 s.echo.GET("/xrpc/_health", s.handleHealth) 526 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 527 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid) 528 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 529 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 530 s.echo.GET("/robots.txt", s.handleRobots) 531 532 // public 533 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 534 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 535 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 536 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 537 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey) 538 539 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 540 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 541 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 542 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 543 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 544 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 545 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 546 
s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 547 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 548 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 549 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 550 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 551 552 // labels 553 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels) 554 555 // account 556 s.echo.GET("/account", s.handleAccount) 557 s.echo.POST("/account/revoke", s.handleAccountRevoke) 558 s.echo.GET("/account/signin", s.handleAccountSigninGet) 559 s.echo.POST("/account/signin", s.handleAccountSigninPost) 560 s.echo.GET("/account/signout", s.handleAccountSignout) 561 562 // oauth account 563 s.echo.GET("/oauth/jwks", s.handleOauthJwks) 564 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 565 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 566 567 // oauth authorization 568 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 569 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 570 571 // authed 572 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 573 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 574 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 575 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 576 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 577 
s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 578 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 579 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 580 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 581 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 582 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 583 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 584 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 585 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 586 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 587 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 588 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 589 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, 
s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 590 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 591 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount) 592 593 // repo 594 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 595 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 596 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 597 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 598 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 599 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 600 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 601 602 // stupid silly endpoints 603 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 604 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 605 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 606 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 607 // admin routes 608 
s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 609 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 610 611 // are there any routes that we should be allowing without auth? i dont think so but idk 612 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 613 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 614} 615 616func (s *Server) Serve(ctx context.Context) error { 617 logger := s.logger.With("name", "Serve") 618 619 s.addRoutes() 620 621 logger.Info("migrating...") 622 623 s.db.AutoMigrate( 624 &models.Actor{}, 625 &models.Repo{}, 626 &models.InviteCode{}, 627 &models.Token{}, 628 &models.RefreshToken{}, 629 &models.Block{}, 630 &models.Record{}, 631 &models.Blob{}, 632 &models.BlobPart{}, 633 &models.ReservedKey{}, 634 &provider.OauthToken{}, 635 &provider.OauthAuthorizationRequest{}, 636 ) 637 638 logger.Info("starting cocoon") 639 640 go func() { 641 if err := s.httpd.ListenAndServe(); err != nil { 642 panic(err) 643 } 644 }() 645 646 go s.backupRoutine() 647 648 go func() { 649 if err := s.requestCrawl(ctx); err != nil { 650 logger.Error("error requesting crawls", "err", err) 651 } 652 }() 653 654 if s.config.PushBasedEvents { 655 slog.Info("pushed based events enabled") 656 go func() { 657 if err := s.emmitEvents(ctx); err != nil { 658 logger.Error("error emitting events", "err", err) 659 } 660 }() 661 } 662 663 <-ctx.Done() 664 665 fmt.Println("shut down") 666 667 return nil 668} 669 670func (s *Server) requestCrawl(ctx context.Context) error { 671 logger := s.logger.With("component", "request-crawl") 672 s.requestCrawlMu.Lock() 673 defer s.requestCrawlMu.Unlock() 674 675 logger.Info("requesting crawl with configured relays") 676 677 if time.Since(s.lastRequestCrawl) <= 1*time.Minute { 678 return fmt.Errorf("a crawl request has 
already been made within the last minute") 679 } 680 681 for _, relay := range s.config.Relays { 682 logger := logger.With("relay", relay) 683 logger.Info("requesting crawl from relay") 684 cli := xrpc.Client{Host: relay} 685 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 686 Hostname: s.config.Hostname, 687 }); err != nil { 688 logger.Error("error requesting crawl", "err", err) 689 } else { 690 logger.Info("crawl requested successfully") 691 } 692 } 693 694 s.lastRequestCrawl = time.Now() 695 696 return nil 697} 698 699func (s *Server) doBackup() { 700 logger := s.logger.With("name", "doBackup") 701 702 if s.dbType == "postgres" || s.dbType == "turso" { 703 logger.Info("skipping S3 backup - PostgreSQL or Turso backups should be handled externally (pg_dump, managed database backups, etc.)") 704 return 705 } 706 707 start := time.Now() 708 709 logger.Info("beginning backup to s3...") 710 711 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano)) 712 defer os.Remove(tmpFile) 713 714 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil { 715 logger.Error("error creating tmp backup file", "err", err) 716 return 717 } 718 719 backupData, err := os.ReadFile(tmpFile) 720 if err != nil { 721 logger.Error("error reading tmp backup file", "err", err) 722 return 723 } 724 725 logger.Info("sending to s3...") 726 727 currTime := time.Now().Format("2006-01-02_15-04-05") 728 key := "cocoon-backup-" + currTime + ".db" 729 730 config := &aws.Config{ 731 Region: aws.String(s.s3Config.Region), 732 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 733 } 734 735 if s.s3Config.Endpoint != "" { 736 config.Endpoint = aws.String(s.s3Config.Endpoint) 737 config.S3ForcePathStyle = aws.Bool(true) 738 } 739 740 sess, err := session.NewSession(config) 741 if err != nil { 742 logger.Error("error creating s3 session", "err", err) 743 return 744 } 745 
746 svc := s3.New(sess) 747 748 if _, err := svc.PutObject(&s3.PutObjectInput{ 749 Bucket: aws.String(s.s3Config.Bucket), 750 Key: aws.String(key), 751 Body: bytes.NewReader(backupData), 752 }); err != nil { 753 logger.Error("error uploading file to s3", "err", err) 754 return 755 } 756 757 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds()) 758 759 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644) 760} 761 762func (s *Server) backupRoutine() { 763 logger := s.logger.With("name", "backupRoutine") 764 765 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 766 return 767 } 768 769 if s.s3Config.Region == "" { 770 logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 771 return 772 } 773 774 if s.s3Config.Bucket == "" { 775 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 776 return 777 } 778 779 if s.s3Config.AccessKey == "" { 780 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 781 return 782 } 783 784 if s.s3Config.SecretKey == "" { 785 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 786 return 787 } 788 789 shouldBackupNow := false 790 lastBackupStr, err := os.ReadFile("last-backup.txt") 791 if err != nil { 792 shouldBackupNow = true 793 } else { 794 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr)) 795 if err != nil { 796 shouldBackupNow = true 797 } else if time.Since(lastBackup).Seconds() > 3600 { 798 shouldBackupNow = true 799 } 800 } 801 802 if shouldBackupNow { 803 go s.doBackup() 804 } 805 806 ticker := time.NewTicker(time.Hour) 807 for range ticker.C { 808 go s.doBackup() 809 } 810} 811 812func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 813 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? 
WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 814 return err 815 } 816 817 return nil 818}