A fork of the Cocoon PDS (atproto Personal Data Server) that is being reworked to be more distributed.
0

Configure Feed

Select the types of activity you want to include in your feed.

1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "embed" 8 "errors" 9 "fmt" 10 "io" 11 "log/slog" 12 "net/http" 13 "net/smtp" 14 "os" 15 "path/filepath" 16 "sync" 17 "text/template" 18 "time" 19 20 "github.com/aws/aws-sdk-go/aws" 21 "github.com/aws/aws-sdk-go/aws/credentials" 22 "github.com/aws/aws-sdk-go/aws/session" 23 "github.com/aws/aws-sdk-go/service/s3" 24 "github.com/bluesky-social/indigo/api/atproto" 25 "github.com/bluesky-social/indigo/atproto/syntax" 26 "github.com/bluesky-social/indigo/events" 27 "github.com/bluesky-social/indigo/util" 28 "github.com/bluesky-social/indigo/xrpc" 29 "github.com/domodwyer/mailyak/v3" 30 "github.com/go-playground/validator" 31 "github.com/gorilla/sessions" 32 "github.com/haileyok/cocoon/identity" 33 "github.com/haileyok/cocoon/internal/db" 34 "github.com/haileyok/cocoon/internal/helpers" 35 "github.com/haileyok/cocoon/models" 36 "github.com/haileyok/cocoon/oauth/client" 37 "github.com/haileyok/cocoon/oauth/constants" 38 "github.com/haileyok/cocoon/oauth/dpop" 39 "github.com/haileyok/cocoon/oauth/provider" 40 "github.com/haileyok/cocoon/plc" 41 "github.com/ipfs/go-cid" 42 "github.com/labstack/echo-contrib/echoprometheus" 43 echo_session "github.com/labstack/echo-contrib/session" 44 "github.com/labstack/echo/v4" 45 "github.com/labstack/echo/v4/middleware" 46 slogecho "github.com/samber/slog-echo" 47 "gorm.io/driver/postgres" 48 "gorm.io/driver/sqlite" 49 "gorm.io/gorm" 50) 51 52const ( 53 AccountSessionMaxAge = 30 * 24 * time.Hour // one week 54) 55 56type S3Config struct { 57 BackupsEnabled bool 58 BlobstoreEnabled bool 59 Endpoint string 60 Region string 61 Bucket string 62 AccessKey string 63 SecretKey string 64 CDNUrl string 65} 66 67type Server struct { 68 http *http.Client 69 httpd *http.Server 70 mail *mailyak.MailYak 71 mailLk *sync.Mutex 72 echo *echo.Echo 73 db *db.DB 74 plcClient *plc.Client 75 logger *slog.Logger 76 config *config 77 privateKey *ecdsa.PrivateKey 78 repoman *RepoMan 79 
oauthProvider *provider.Provider 80 evtman *events.EventManager 81 passport *identity.Passport 82 fallbackProxy string 83 84 lastRequestCrawl time.Time 85 requestCrawlMu sync.Mutex 86 87 dbName string 88 dbType string 89 s3Config *S3Config 90} 91 92type Args struct { 93 Logger *slog.Logger 94 95 LogLevel slog.Level 96 Addr string 97 DbName string 98 DbType string 99 DatabaseURL string 100 Version string 101 Did string 102 Hostname string 103 RotationKeyPath string 104 JwkPath string 105 ContactEmail string 106 Relays []string 107 AdminPassword string 108 RequireInvite bool 109 110 SmtpUser string 111 SmtpPass string 112 SmtpHost string 113 SmtpPort string 114 SmtpEmail string 115 SmtpName string 116 117 S3Config *S3Config 118 119 SessionSecret string 120 SessionCookieKey string 121 122 BlockstoreVariant BlockstoreVariant 123 FallbackProxy string 124} 125 126type config struct { 127 LogLevel slog.Level 128 Version string 129 Did string 130 Hostname string 131 ContactEmail string 132 EnforcePeering bool 133 Relays []string 134 AdminPassword string 135 RequireInvite bool 136 SmtpEmail string 137 SmtpName string 138 SessionCookieKey string 139 BlockstoreVariant BlockstoreVariant 140 FallbackProxy string 141} 142 143type CustomValidator struct { 144 validator *validator.Validate 145} 146 147type ValidationError struct { 148 error 149 Field string 150 Tag string 151} 152 153func (cv *CustomValidator) Validate(i any) error { 154 if err := cv.validator.Struct(i); err != nil { 155 var validateErrors validator.ValidationErrors 156 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 157 first := validateErrors[0] 158 return ValidationError{ 159 error: err, 160 Field: first.Field(), 161 Tag: first.Tag(), 162 } 163 } 164 165 return err 166 } 167 168 return nil 169} 170 171//go:embed templates/* 172var templateFS embed.FS 173 174//go:embed static/* 175var staticFS embed.FS 176 177type TemplateRenderer struct { 178 templates *template.Template 179 isDev bool 180 
templatePath string 181} 182 183func (s *Server) loadTemplates() { 184 absPath, _ := filepath.Abs("server/templates/*.html") 185 if s.config.Version == "dev" { 186 tmpl := template.Must(template.ParseGlob(absPath)) 187 s.echo.Renderer = &TemplateRenderer{ 188 templates: tmpl, 189 isDev: true, 190 templatePath: absPath, 191 } 192 } else { 193 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 194 s.echo.Renderer = &TemplateRenderer{ 195 templates: tmpl, 196 isDev: false, 197 } 198 } 199} 200 201func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 202 if t.isDev { 203 tmpl, err := template.ParseGlob(t.templatePath) 204 if err != nil { 205 return err 206 } 207 t.templates = tmpl 208 } 209 210 if viewContext, isMap := data.(map[string]any); isMap { 211 viewContext["reverse"] = c.Echo().Reverse 212 } 213 214 return t.templates.ExecuteTemplate(w, name, data) 215} 216 217type filteredHandler struct { 218 level slog.Level 219 handler slog.Handler 220} 221 222func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool { 223 return level >= h.level && h.handler.Enabled(ctx, level) 224} 225 226func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error { 227 return h.handler.Handle(ctx, r) 228} 229 230func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler { 231 return &filteredHandler{level: h.level, handler: h.handler.WithAttrs(attrs)} 232} 233 234func (h *filteredHandler) WithGroup(name string) slog.Handler { 235 return &filteredHandler{level: h.level, handler: h.handler.WithGroup(name)} 236} 237 238func New(args *Args) (*Server, error) { 239 if args.Logger == nil { 240 args.Logger = slog.Default() 241 } 242 243 if args.LogLevel != 0 { 244 args.Logger = slog.New(&filteredHandler{ 245 level: args.LogLevel, 246 handler: args.Logger.Handler(), 247 }) 248 } 249 250 logger := args.Logger.With("name", "New") 251 252 if args.Addr == "" { 253 return nil, fmt.Errorf("addr must 
be set") 254 } 255 256 if args.DbName == "" { 257 return nil, fmt.Errorf("db name must be set") 258 } 259 260 if args.Did == "" { 261 return nil, fmt.Errorf("cocoon did must be set") 262 } 263 264 if args.ContactEmail == "" { 265 return nil, fmt.Errorf("cocoon contact email is required") 266 } 267 268 if _, err := syntax.ParseDID(args.Did); err != nil { 269 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 270 } 271 272 if args.Hostname == "" { 273 return nil, fmt.Errorf("cocoon hostname must be set") 274 } 275 276 if args.AdminPassword == "" { 277 return nil, fmt.Errorf("admin password must be set") 278 } 279 280 if args.SessionSecret == "" { 281 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 282 } 283 284 e := echo.New() 285 286 e.Pre(middleware.RemoveTrailingSlash()) 287 e.Pre(slogecho.New(args.Logger.With("component", "slogecho"))) 288 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 289 e.Use(echoprometheus.NewMiddleware("cocoon")) 290 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 291 AllowOrigins: []string{"*"}, 292 AllowHeaders: []string{"*"}, 293 AllowMethods: []string{"*"}, 294 AllowCredentials: true, 295 MaxAge: 100_000_000, 296 })) 297 298 vdtor := validator.New() 299 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 300 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 301 return false 302 } 303 return true 304 }) 305 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 306 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 307 return false 308 } 309 return true 310 }) 311 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 312 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 313 return false 314 } 315 return true 316 }) 317 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 318 if _, err := syntax.ParseNSID(fl.Field().String()); err != 
nil { 319 return false 320 } 321 return true 322 }) 323 324 e.Validator = &CustomValidator{validator: vdtor} 325 326 httpd := &http.Server{ 327 Addr: args.Addr, 328 Handler: e, 329 // shitty defaults but okay for now, needed for import repo 330 ReadTimeout: 5 * time.Minute, 331 WriteTimeout: 5 * time.Minute, 332 IdleTimeout: 5 * time.Minute, 333 } 334 335 dbType := args.DbType 336 if dbType == "" { 337 dbType = "sqlite" 338 } 339 340 var gdb *gorm.DB 341 var err error 342 switch dbType { 343 case "postgres": 344 if args.DatabaseURL == "" { 345 return nil, fmt.Errorf("database-url must be set when using postgres") 346 } 347 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{}) 348 if err != nil { 349 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 350 } 351 logger.Info("connected to PostgreSQL database") 352 default: 353 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{}) 354 if err != nil { 355 return nil, fmt.Errorf("failed to open sqlite database: %w", err) 356 } 357 gdb.Exec("PRAGMA journal_mode=WAL") 358 gdb.Exec("PRAGMA synchronous=NORMAL") 359 360 logger.Info("connected to SQLite database", "path", args.DbName) 361 } 362 dbw := db.NewDB(gdb) 363 364 rkbytes, err := os.ReadFile(args.RotationKeyPath) 365 if err != nil { 366 return nil, err 367 } 368 369 h := util.RobustHTTPClient() 370 371 plcClient, err := plc.NewClient(&plc.ClientArgs{ 372 H: h, 373 Service: "https://plc.directory", 374 PdsHostname: args.Hostname, 375 RotationKey: rkbytes, 376 }) 377 if err != nil { 378 return nil, err 379 } 380 381 jwkbytes, err := os.ReadFile(args.JwkPath) 382 if err != nil { 383 return nil, err 384 } 385 386 key, err := helpers.ParseJWKFromBytes(jwkbytes) 387 if err != nil { 388 return nil, err 389 } 390 391 var pkey ecdsa.PrivateKey 392 if err := key.Raw(&pkey); err != nil { 393 return nil, err 394 } 395 396 oauthCli := &http.Client{ 397 Timeout: 10 * time.Second, 398 } 399 400 var nonceSecret []byte 401 maybeSecret, err := 
os.ReadFile("nonce.secret") 402 if err != nil && !os.IsNotExist(err) { 403 logger.Error("error attempting to read nonce secret", "error", err) 404 } else { 405 nonceSecret = maybeSecret 406 } 407 408 evtPersister, err := NewDbPersister(gdb, 72*time.Hour) 409 if err != nil { 410 return nil, fmt.Errorf("failed to create event persister: %w", err) 411 } 412 413 s := &Server{ 414 http: h, 415 httpd: httpd, 416 echo: e, 417 logger: args.Logger, 418 db: dbw, 419 plcClient: plcClient, 420 privateKey: &pkey, 421 config: &config{ 422 LogLevel: args.LogLevel, 423 Version: args.Version, 424 Did: args.Did, 425 Hostname: args.Hostname, 426 ContactEmail: args.ContactEmail, 427 EnforcePeering: false, 428 Relays: args.Relays, 429 AdminPassword: args.AdminPassword, 430 RequireInvite: args.RequireInvite, 431 SmtpName: args.SmtpName, 432 SmtpEmail: args.SmtpEmail, 433 SessionCookieKey: args.SessionCookieKey, 434 BlockstoreVariant: args.BlockstoreVariant, 435 FallbackProxy: args.FallbackProxy, 436 }, 437 evtman: events.NewEventManager(evtPersister), 438 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 439 440 dbName: args.DbName, 441 dbType: dbType, 442 s3Config: args.S3Config, 443 444 oauthProvider: provider.NewProvider(provider.Args{ 445 Hostname: args.Hostname, 446 ClientManagerArgs: client.ManagerArgs{ 447 Cli: oauthCli, 448 Logger: args.Logger.With("component", "oauth-client-manager"), 449 }, 450 DpopManagerArgs: dpop.ManagerArgs{ 451 NonceSecret: nonceSecret, 452 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 453 OnNonceSecretCreated: func(newNonce []byte) { 454 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 455 logger.Error("error writing new nonce secret", "error", err) 456 } 457 }, 458 Logger: args.Logger.With("component", "dpop-manager"), 459 Hostname: args.Hostname, 460 }, 461 }), 462 } 463 464 s.loadTemplates() 465 466 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 467 468 // TODO: should validate 
these args 469 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 470 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.") 471 } else { 472 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 473 mail.From(s.config.SmtpEmail) 474 mail.FromName(s.config.SmtpName) 475 476 s.mail = mail 477 s.mailLk = &sync.Mutex{} 478 } 479 480 return s, nil 481} 482 483func (s *Server) addRoutes() { 484 // static 485 if s.config.Version == "dev" { 486 s.echo.Static("/static", "server/static") 487 } else { 488 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 489 } 490 491 // random stuff 492 s.echo.GET("/", s.handleRoot) 493 s.echo.GET("/xrpc/_health", s.handleHealth) 494 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 495 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid) 496 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 497 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 498 s.echo.GET("/robots.txt", s.handleRobots) 499 500 // public 501 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 502 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 503 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 504 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 505 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey) 506 507 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 508 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 509 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 510 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 511 
s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 512 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 513 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 514 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 515 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 516 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 517 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 518 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 519 520 // labels 521 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels) 522 523 // account 524 s.echo.GET("/account", s.handleAccount) 525 s.echo.POST("/account/revoke", s.handleAccountRevoke) 526 s.echo.GET("/account/signin", s.handleAccountSigninGet) 527 s.echo.POST("/account/signin", s.handleAccountSigninPost) 528 s.echo.GET("/account/signout", s.handleAccountSignout) 529 530 // oauth account 531 s.echo.GET("/oauth/jwks", s.handleOauthJwks) 532 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 533 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 534 535 // oauth authorization 536 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 537 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 538 539 // authed 540 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 541 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 542 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 543 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, 
s.handleOauthSessionMiddleware) 544 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 545 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 546 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 547 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 548 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 549 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 550 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 551 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 552 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 553 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 554 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 555 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 556 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", 
s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 557 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 558 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 559 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount) 560 561 // repo 562 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 563 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 564 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 565 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 566 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 567 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 568 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 569 570 // stupid silly endpoints 571 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 572 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 573 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 574 
s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 575 // admin routes 576 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 577 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 578 579 // are there any routes that we should be allowing without auth? i dont think so but idk 580 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 581 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 582} 583 584func (s *Server) Serve(ctx context.Context) error { 585 logger := s.logger.With("name", "Serve") 586 587 s.addRoutes() 588 589 logger.Info("migrating...") 590 591 s.db.AutoMigrate( 592 &models.Actor{}, 593 &models.Repo{}, 594 &models.InviteCode{}, 595 &models.Token{}, 596 &models.RefreshToken{}, 597 &models.Block{}, 598 &models.Record{}, 599 &models.Blob{}, 600 &models.BlobPart{}, 601 &models.ReservedKey{}, 602 &provider.OauthToken{}, 603 &provider.OauthAuthorizationRequest{}, 604 ) 605 606 logger.Info("starting cocoon") 607 608 go func() { 609 if err := s.httpd.ListenAndServe(); err != nil { 610 panic(err) 611 } 612 }() 613 614 go s.backupRoutine() 615 616 go func() { 617 if err := s.requestCrawl(ctx); err != nil { 618 logger.Error("error requesting crawls", "err", err) 619 } 620 }() 621 622 <-ctx.Done() 623 624 fmt.Println("shut down") 625 626 return nil 627} 628 629func (s *Server) requestCrawl(ctx context.Context) error { 630 logger := s.logger.With("component", "request-crawl") 631 s.requestCrawlMu.Lock() 632 defer s.requestCrawlMu.Unlock() 633 634 logger.Info("requesting crawl with configured relays") 635 636 if time.Since(s.lastRequestCrawl) <= 1*time.Minute { 637 return fmt.Errorf("a crawl request has already been made within the last minute") 638 } 639 640 for _, 
relay := range s.config.Relays { 641 logger := logger.With("relay", relay) 642 logger.Info("requesting crawl from relay") 643 cli := xrpc.Client{Host: relay} 644 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 645 Hostname: s.config.Hostname, 646 }); err != nil { 647 logger.Error("error requesting crawl", "err", err) 648 } else { 649 logger.Info("crawl requested successfully") 650 } 651 } 652 653 s.lastRequestCrawl = time.Now() 654 655 return nil 656} 657 658func (s *Server) doBackup() { 659 logger := s.logger.With("name", "doBackup") 660 661 if s.dbType == "postgres" { 662 logger.Info("skipping S3 backup - PostgreSQL backups should be handled externally (pg_dump, managed database backups, etc.)") 663 return 664 } 665 666 start := time.Now() 667 668 logger.Info("beginning backup to s3...") 669 670 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano)) 671 defer os.Remove(tmpFile) 672 673 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil { 674 logger.Error("error creating tmp backup file", "err", err) 675 return 676 } 677 678 backupData, err := os.ReadFile(tmpFile) 679 if err != nil { 680 logger.Error("error reading tmp backup file", "err", err) 681 return 682 } 683 684 logger.Info("sending to s3...") 685 686 currTime := time.Now().Format("2006-01-02_15-04-05") 687 key := "cocoon-backup-" + currTime + ".db" 688 689 config := &aws.Config{ 690 Region: aws.String(s.s3Config.Region), 691 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 692 } 693 694 if s.s3Config.Endpoint != "" { 695 config.Endpoint = aws.String(s.s3Config.Endpoint) 696 config.S3ForcePathStyle = aws.Bool(true) 697 } 698 699 sess, err := session.NewSession(config) 700 if err != nil { 701 logger.Error("error creating s3 session", "err", err) 702 return 703 } 704 705 svc := s3.New(sess) 706 707 if _, err := svc.PutObject(&s3.PutObjectInput{ 708 Bucket: 
aws.String(s.s3Config.Bucket), 709 Key: aws.String(key), 710 Body: bytes.NewReader(backupData), 711 }); err != nil { 712 logger.Error("error uploading file to s3", "err", err) 713 return 714 } 715 716 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds()) 717 718 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644) 719} 720 721func (s *Server) backupRoutine() { 722 logger := s.logger.With("name", "backupRoutine") 723 724 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 725 return 726 } 727 728 if s.s3Config.Region == "" { 729 logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 730 return 731 } 732 733 if s.s3Config.Bucket == "" { 734 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 735 return 736 } 737 738 if s.s3Config.AccessKey == "" { 739 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 740 return 741 } 742 743 if s.s3Config.SecretKey == "" { 744 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 745 return 746 } 747 748 shouldBackupNow := false 749 lastBackupStr, err := os.ReadFile("last-backup.txt") 750 if err != nil { 751 shouldBackupNow = true 752 } else { 753 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr)) 754 if err != nil { 755 shouldBackupNow = true 756 } else if time.Since(lastBackup).Seconds() > 3600 { 757 shouldBackupNow = true 758 } 759 } 760 761 if shouldBackupNow { 762 go s.doBackup() 763 } 764 765 ticker := time.NewTicker(time.Hour) 766 for range ticker.C { 767 go s.doBackup() 768 } 769} 770 771func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 772 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 773 return err 774 } 775 776 return nil 777}