A fork of the Cocoon PDS that is being made more distributed.
0

Configure Feed

Select the types of activity you want to include in your feed.

1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "embed" 8 "errors" 9 "fmt" 10 "io" 11 "log/slog" 12 "net/http" 13 "net/smtp" 14 "os" 15 "path/filepath" 16 "sync" 17 "text/template" 18 "time" 19 20 "github.com/aws/aws-sdk-go/aws" 21 "github.com/aws/aws-sdk-go/aws/credentials" 22 "github.com/aws/aws-sdk-go/aws/session" 23 "github.com/aws/aws-sdk-go/service/s3" 24 "github.com/bluesky-social/indigo/api/atproto" 25 "github.com/bluesky-social/indigo/atproto/syntax" 26 "github.com/bluesky-social/indigo/events" 27 "github.com/bluesky-social/indigo/util" 28 "github.com/bluesky-social/indigo/xrpc" 29 "github.com/domodwyer/mailyak/v3" 30 "github.com/go-playground/validator" 31 "github.com/gorilla/sessions" 32 "github.com/haileyok/cocoon/identity" 33 "github.com/haileyok/cocoon/internal/db" 34 "github.com/haileyok/cocoon/internal/helpers" 35 "github.com/haileyok/cocoon/models" 36 "github.com/haileyok/cocoon/oauth/client" 37 "github.com/haileyok/cocoon/oauth/constants" 38 "github.com/haileyok/cocoon/oauth/dpop" 39 "github.com/haileyok/cocoon/oauth/provider" 40 "github.com/haileyok/cocoon/plc" 41 "github.com/ipfs/go-cid" 42 "github.com/labstack/echo-contrib/echoprometheus" 43 echo_session "github.com/labstack/echo-contrib/session" 44 "github.com/labstack/echo/v4" 45 "github.com/labstack/echo/v4/middleware" 46 slogecho "github.com/samber/slog-echo" 47 "gorm.io/driver/postgres" 48 "gorm.io/driver/sqlite" 49 "gorm.io/gorm" 50) 51 52const ( 53 AccountSessionMaxAge = 30 * 24 * time.Hour // one week 54) 55 56type S3Config struct { 57 BackupsEnabled bool 58 BlobstoreEnabled bool 59 Endpoint string 60 Region string 61 Bucket string 62 AccessKey string 63 SecretKey string 64 CDNUrl string 65} 66 67type Server struct { 68 http *http.Client 69 httpd *http.Server 70 mail *mailyak.MailYak 71 mailLk *sync.Mutex 72 echo *echo.Echo 73 db *db.DB 74 plcClient *plc.Client 75 logger *slog.Logger 76 config *config 77 privateKey *ecdsa.PrivateKey 78 repoman *RepoMan 79 
oauthProvider *provider.Provider 80 evtman *events.EventManager 81 passport *identity.Passport 82 fallbackProxy string 83 84 lastRequestCrawl time.Time 85 requestCrawlMu sync.Mutex 86 87 dbName string 88 dbType string 89 s3Config *S3Config 90} 91 92type Args struct { 93 Logger *slog.Logger 94 95 LogLevel slog.Level 96 Addr string 97 DbName string 98 DbType string 99 DatabaseURL string 100 Version string 101 Did string 102 Hostname string 103 RotationKeyPath string 104 JwkPath string 105 ContactEmail string 106 Relays []string 107 AdminPassword string 108 RequireInvite bool 109 110 SmtpUser string 111 SmtpPass string 112 SmtpHost string 113 SmtpPort string 114 SmtpEmail string 115 SmtpName string 116 117 S3Config *S3Config 118 119 SessionSecret string 120 SessionCookieKey string 121 122 BlockstoreVariant BlockstoreVariant 123 FallbackProxy string 124} 125 126type config struct { 127 LogLevel slog.Level 128 Version string 129 Did string 130 Hostname string 131 ContactEmail string 132 EnforcePeering bool 133 Relays []string 134 AdminPassword string 135 RequireInvite bool 136 SmtpEmail string 137 SmtpName string 138 SessionCookieKey string 139 BlockstoreVariant BlockstoreVariant 140 FallbackProxy string 141} 142 143type CustomValidator struct { 144 validator *validator.Validate 145} 146 147type ValidationError struct { 148 error 149 Field string 150 Tag string 151} 152 153func (cv *CustomValidator) Validate(i any) error { 154 if err := cv.validator.Struct(i); err != nil { 155 var validateErrors validator.ValidationErrors 156 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 157 first := validateErrors[0] 158 return ValidationError{ 159 error: err, 160 Field: first.Field(), 161 Tag: first.Tag(), 162 } 163 } 164 165 return err 166 } 167 168 return nil 169} 170 171//go:embed templates/* 172var templateFS embed.FS 173 174//go:embed static/* 175var staticFS embed.FS 176 177type TemplateRenderer struct { 178 templates *template.Template 179 isDev bool 180 
templatePath string 181} 182 183func (s *Server) loadTemplates() { 184 absPath, _ := filepath.Abs("server/templates/*.html") 185 if s.config.Version == "dev" { 186 tmpl := template.Must(template.ParseGlob(absPath)) 187 s.echo.Renderer = &TemplateRenderer{ 188 templates: tmpl, 189 isDev: true, 190 templatePath: absPath, 191 } 192 } else { 193 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 194 s.echo.Renderer = &TemplateRenderer{ 195 templates: tmpl, 196 isDev: false, 197 } 198 } 199} 200 201func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 202 if t.isDev { 203 tmpl, err := template.ParseGlob(t.templatePath) 204 if err != nil { 205 return err 206 } 207 t.templates = tmpl 208 } 209 210 if viewContext, isMap := data.(map[string]any); isMap { 211 viewContext["reverse"] = c.Echo().Reverse 212 } 213 214 return t.templates.ExecuteTemplate(w, name, data) 215} 216 217type filteredHandler struct { 218 level slog.Level 219 handler slog.Handler 220} 221 222func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool { 223 return level >= h.level && h.handler.Enabled(ctx, level) 224} 225 226func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error { 227 return h.handler.Handle(ctx, r) 228} 229 230func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler { 231 return &filteredHandler{level: h.level, handler: h.handler.WithAttrs(attrs)} 232} 233 234func (h *filteredHandler) WithGroup(name string) slog.Handler { 235 return &filteredHandler{level: h.level, handler: h.handler.WithGroup(name)} 236} 237 238func New(args *Args) (*Server, error) { 239 if args.Logger == nil { 240 args.Logger = slog.Default() 241 } 242 243 if args.LogLevel != 0 { 244 args.Logger = slog.New(&filteredHandler{ 245 level: args.LogLevel, 246 handler: args.Logger.Handler(), 247 }) 248 } 249 250 logger := args.Logger.With("name", "New") 251 252 if args.Addr == "" { 253 return nil, fmt.Errorf("addr must 
be set") 254 } 255 256 if args.DbName == "" { 257 return nil, fmt.Errorf("db name must be set") 258 } 259 260 if args.Did == "" { 261 return nil, fmt.Errorf("cocoon did must be set") 262 } 263 264 if args.ContactEmail == "" { 265 return nil, fmt.Errorf("cocoon contact email is required") 266 } 267 268 if _, err := syntax.ParseDID(args.Did); err != nil { 269 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 270 } 271 272 if args.Hostname == "" { 273 return nil, fmt.Errorf("cocoon hostname must be set") 274 } 275 276 if args.AdminPassword == "" { 277 return nil, fmt.Errorf("admin password must be set") 278 } 279 280 if args.SessionSecret == "" { 281 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 282 } 283 284 e := echo.New() 285 286 e.Pre(middleware.RemoveTrailingSlash()) 287 e.Pre(slogecho.New(args.Logger.With("component", "slogecho"))) 288 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 289 e.Use(echoprometheus.NewMiddleware("cocoon")) 290 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 291 AllowOrigins: []string{"*"}, 292 AllowHeaders: []string{"*"}, 293 AllowMethods: []string{"*"}, 294 AllowCredentials: true, 295 MaxAge: 100_000_000, 296 })) 297 298 vdtor := validator.New() 299 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 300 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 301 return false 302 } 303 return true 304 }) 305 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 306 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 307 return false 308 } 309 return true 310 }) 311 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 312 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 313 return false 314 } 315 return true 316 }) 317 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 318 if _, err := syntax.ParseNSID(fl.Field().String()); err != 
nil { 319 return false 320 } 321 return true 322 }) 323 324 e.Validator = &CustomValidator{validator: vdtor} 325 326 httpd := &http.Server{ 327 Addr: args.Addr, 328 Handler: e, 329 // shitty defaults but okay for now, needed for import repo 330 ReadTimeout: 5 * time.Minute, 331 WriteTimeout: 5 * time.Minute, 332 IdleTimeout: 5 * time.Minute, 333 } 334 335 dbType := args.DbType 336 if dbType == "" { 337 dbType = "sqlite" 338 } 339 340 var gdb *gorm.DB 341 var err error 342 switch dbType { 343 case "postgres": 344 if args.DatabaseURL == "" { 345 return nil, fmt.Errorf("database-url must be set when using postgres") 346 } 347 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{}) 348 if err != nil { 349 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 350 } 351 logger.Info("connected to PostgreSQL database") 352 default: 353 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{}) 354 if err != nil { 355 return nil, fmt.Errorf("failed to open sqlite database: %w", err) 356 } 357 gdb.Exec("PRAGMA journal_mode=WAL") 358 gdb.Exec("PRAGMA synchronous=NORMAL") 359 360 logger.Info("connected to SQLite database", "path", args.DbName) 361 } 362 dbw := db.NewDB(gdb) 363 364 rkbytes, err := os.ReadFile(args.RotationKeyPath) 365 if err != nil { 366 return nil, err 367 } 368 369 h := util.RobustHTTPClient() 370 371 plcClient, err := plc.NewClient(&plc.ClientArgs{ 372 H: h, 373 Service: "https://plc.directory", 374 PdsHostname: args.Hostname, 375 RotationKey: rkbytes, 376 }) 377 if err != nil { 378 return nil, err 379 } 380 381 jwkbytes, err := os.ReadFile(args.JwkPath) 382 if err != nil { 383 return nil, err 384 } 385 386 key, err := helpers.ParseJWKFromBytes(jwkbytes) 387 if err != nil { 388 return nil, err 389 } 390 391 var pkey ecdsa.PrivateKey 392 if err := key.Raw(&pkey); err != nil { 393 return nil, err 394 } 395 396 oauthCli := &http.Client{ 397 Timeout: 10 * time.Second, 398 } 399 400 var nonceSecret []byte 401 maybeSecret, err := 
os.ReadFile("nonce.secret") 402 if err != nil && !os.IsNotExist(err) { 403 logger.Error("error attempting to read nonce secret", "error", err) 404 } else { 405 nonceSecret = maybeSecret 406 } 407 408 evtPersister, err := NewDbPersister(gdb, 72*time.Hour) 409 if err != nil { 410 return nil, fmt.Errorf("failed to create event persister: %w", err) 411 } 412 413 s := &Server{ 414 http: h, 415 httpd: httpd, 416 echo: e, 417 logger: args.Logger, 418 db: dbw, 419 plcClient: plcClient, 420 privateKey: &pkey, 421 config: &config{ 422 LogLevel: args.LogLevel, 423 Version: args.Version, 424 Did: args.Did, 425 Hostname: args.Hostname, 426 ContactEmail: args.ContactEmail, 427 EnforcePeering: false, 428 Relays: args.Relays, 429 AdminPassword: args.AdminPassword, 430 RequireInvite: args.RequireInvite, 431 SmtpName: args.SmtpName, 432 SmtpEmail: args.SmtpEmail, 433 SessionCookieKey: args.SessionCookieKey, 434 BlockstoreVariant: args.BlockstoreVariant, 435 FallbackProxy: args.FallbackProxy, 436 }, 437 evtman: events.NewEventManager(evtPersister), 438 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 439 440 dbName: args.DbName, 441 dbType: dbType, 442 s3Config: args.S3Config, 443 444 oauthProvider: provider.NewProvider(provider.Args{ 445 Hostname: args.Hostname, 446 ClientManagerArgs: client.ManagerArgs{ 447 Cli: oauthCli, 448 Logger: args.Logger.With("component", "oauth-client-manager"), 449 }, 450 DpopManagerArgs: dpop.ManagerArgs{ 451 NonceSecret: nonceSecret, 452 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 453 OnNonceSecretCreated: func(newNonce []byte) { 454 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 455 logger.Error("error writing new nonce secret", "error", err) 456 } 457 }, 458 Logger: args.Logger.With("component", "dpop-manager"), 459 Hostname: args.Hostname, 460 }, 461 }), 462 } 463 464 s.loadTemplates() 465 466 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 467 468 // TODO: should validate 
these args 469 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 470 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.") 471 } else { 472 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 473 mail.From(s.config.SmtpEmail) 474 mail.FromName(s.config.SmtpName) 475 476 s.mail = mail 477 s.mailLk = &sync.Mutex{} 478 } 479 480 return s, nil 481} 482 483func (s *Server) addRoutes() { 484 // static 485 if s.config.Version == "dev" { 486 s.echo.Static("/static", "server/static") 487 } else { 488 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 489 } 490 491 // random stuff 492 s.echo.GET("/", s.handleRoot) 493 s.echo.GET("/xrpc/_health", s.handleHealth) 494 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 495 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid) 496 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 497 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 498 s.echo.GET("/robots.txt", s.handleRobots) 499 500 // public 501 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 502 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 503 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 504 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 505 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey) 506 507 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 508 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 509 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 510 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 511 
s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 512 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 513 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 514 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 515 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 516 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 517 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 518 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 519 520 // labels 521 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels) 522 523 // account 524 s.echo.GET("/account", s.handleAccount) 525 s.echo.POST("/account/revoke", s.handleAccountRevoke) 526 s.echo.POST("/account/switch", s.handleAccountSwitchPost) 527 s.echo.GET("/account/signin", s.handleAccountSigninGet) 528 s.echo.POST("/account/signin", s.handleAccountSigninPost) 529 s.echo.GET("/account/signout", s.handleAccountSignout) 530 531 // oauth account 532 s.echo.GET("/oauth/jwks", s.handleOauthJwks) 533 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 534 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 535 536 // oauth authorization 537 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 538 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 539 540 // authed 541 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 542 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 543 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 544 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", 
s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 545 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 546 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 547 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 548 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 549 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 550 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 551 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 552 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 553 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 554 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 555 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 556 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 557 
s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 558 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 559 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 560 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount) 561 562 // repo 563 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 564 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 565 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 566 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 567 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 568 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 569 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 570 571 // stupid silly endpoints 572 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 573 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 574 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, 
s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 575 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 576 // admin routes 577 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 578 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 579 580 // are there any routes that we should be allowing without auth? i dont think so but idk 581 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 582 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 583} 584 585func (s *Server) Serve(ctx context.Context) error { 586 logger := s.logger.With("name", "Serve") 587 588 s.addRoutes() 589 590 logger.Info("migrating...") 591 592 s.db.AutoMigrate( 593 &models.Actor{}, 594 &models.Repo{}, 595 &models.InviteCode{}, 596 &models.Token{}, 597 &models.RefreshToken{}, 598 &models.Block{}, 599 &models.Record{}, 600 &models.Blob{}, 601 &models.BlobPart{}, 602 &models.ReservedKey{}, 603 &provider.OauthToken{}, 604 &provider.OauthAuthorizationRequest{}, 605 ) 606 607 logger.Info("starting cocoon") 608 609 go func() { 610 if err := s.httpd.ListenAndServe(); err != nil { 611 panic(err) 612 } 613 }() 614 615 go s.backupRoutine() 616 617 go func() { 618 if err := s.requestCrawl(ctx); err != nil { 619 logger.Error("error requesting crawls", "err", err) 620 } 621 }() 622 623 <-ctx.Done() 624 625 fmt.Println("shut down") 626 627 return nil 628} 629 630func (s *Server) requestCrawl(ctx context.Context) error { 631 logger := s.logger.With("component", "request-crawl") 632 s.requestCrawlMu.Lock() 633 defer s.requestCrawlMu.Unlock() 634 635 logger.Info("requesting crawl with configured relays") 636 637 if time.Since(s.lastRequestCrawl) <= 1*time.Minute { 638 return fmt.Errorf("a crawl request 
has already been made within the last minute") 639 } 640 641 for _, relay := range s.config.Relays { 642 logger := logger.With("relay", relay) 643 logger.Info("requesting crawl from relay") 644 cli := xrpc.Client{Host: relay} 645 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 646 Hostname: s.config.Hostname, 647 }); err != nil { 648 logger.Error("error requesting crawl", "err", err) 649 } else { 650 logger.Info("crawl requested successfully") 651 } 652 } 653 654 s.lastRequestCrawl = time.Now() 655 656 return nil 657} 658 659func (s *Server) doBackup() { 660 logger := s.logger.With("name", "doBackup") 661 662 if s.dbType == "postgres" { 663 logger.Info("skipping S3 backup - PostgreSQL backups should be handled externally (pg_dump, managed database backups, etc.)") 664 return 665 } 666 667 start := time.Now() 668 669 logger.Info("beginning backup to s3...") 670 671 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano)) 672 defer os.Remove(tmpFile) 673 674 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil { 675 logger.Error("error creating tmp backup file", "err", err) 676 return 677 } 678 679 backupData, err := os.ReadFile(tmpFile) 680 if err != nil { 681 logger.Error("error reading tmp backup file", "err", err) 682 return 683 } 684 685 logger.Info("sending to s3...") 686 687 currTime := time.Now().Format("2006-01-02_15-04-05") 688 key := "cocoon-backup-" + currTime + ".db" 689 690 config := &aws.Config{ 691 Region: aws.String(s.s3Config.Region), 692 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 693 } 694 695 if s.s3Config.Endpoint != "" { 696 config.Endpoint = aws.String(s.s3Config.Endpoint) 697 config.S3ForcePathStyle = aws.Bool(true) 698 } 699 700 sess, err := session.NewSession(config) 701 if err != nil { 702 logger.Error("error creating s3 session", "err", err) 703 return 704 } 705 706 svc := s3.New(sess) 707 
708 if _, err := svc.PutObject(&s3.PutObjectInput{ 709 Bucket: aws.String(s.s3Config.Bucket), 710 Key: aws.String(key), 711 Body: bytes.NewReader(backupData), 712 }); err != nil { 713 logger.Error("error uploading file to s3", "err", err) 714 return 715 } 716 717 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds()) 718 719 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644) 720} 721 722func (s *Server) backupRoutine() { 723 logger := s.logger.With("name", "backupRoutine") 724 725 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 726 return 727 } 728 729 if s.s3Config.Region == "" { 730 logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 731 return 732 } 733 734 if s.s3Config.Bucket == "" { 735 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 736 return 737 } 738 739 if s.s3Config.AccessKey == "" { 740 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 741 return 742 } 743 744 if s.s3Config.SecretKey == "" { 745 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 746 return 747 } 748 749 shouldBackupNow := false 750 lastBackupStr, err := os.ReadFile("last-backup.txt") 751 if err != nil { 752 shouldBackupNow = true 753 } else { 754 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr)) 755 if err != nil { 756 shouldBackupNow = true 757 } else if time.Since(lastBackup).Seconds() > 3600 { 758 shouldBackupNow = true 759 } 760 } 761 762 if shouldBackupNow { 763 go s.doBackup() 764 } 765 766 ticker := time.NewTicker(time.Hour) 767 for range ticker.C { 768 go s.doBackup() 769 } 770} 771 772func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 773 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? 
WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 774 return err 775 } 776 777 return nil 778}