A fork of the Cocoon PDS (Personal Data Server) that is being made more distributed.
0

Configure Feed

Select the types of activity you want to include in your feed.

1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "database/sql" 8 "embed" 9 "errors" 10 "fmt" 11 "io" 12 "log/slog" 13 "net/http" 14 "net/smtp" 15 "os" 16 "path/filepath" 17 "sync" 18 "text/template" 19 "time" 20 21 "github.com/aws/aws-sdk-go/aws" 22 "github.com/aws/aws-sdk-go/aws/credentials" 23 "github.com/aws/aws-sdk-go/aws/session" 24 "github.com/aws/aws-sdk-go/service/s3" 25 "github.com/bluesky-social/indigo/api/atproto" 26 "github.com/bluesky-social/indigo/atproto/syntax" 27 "github.com/bluesky-social/indigo/events" 28 "github.com/bluesky-social/indigo/util" 29 "github.com/bluesky-social/indigo/xrpc" 30 "github.com/domodwyer/mailyak/v3" 31 "github.com/go-playground/validator" 32 "github.com/gorilla/sessions" 33 "github.com/haileyok/cocoon/identity" 34 "github.com/haileyok/cocoon/internal/db" 35 "github.com/haileyok/cocoon/internal/helpers" 36 "github.com/haileyok/cocoon/models" 37 "github.com/haileyok/cocoon/oauth/client" 38 "github.com/haileyok/cocoon/oauth/constants" 39 "github.com/haileyok/cocoon/oauth/dpop" 40 "github.com/haileyok/cocoon/oauth/provider" 41 "github.com/haileyok/cocoon/plc" 42 "github.com/ipfs/go-cid" 43 "github.com/labstack/echo-contrib/echoprometheus" 44 echo_session "github.com/labstack/echo-contrib/session" 45 "github.com/labstack/echo/v4" 46 "github.com/labstack/echo/v4/middleware" 47 slogecho "github.com/samber/slog-echo" 48 _ "github.com/tursodatabase/libsql-client-go/libsql" 49 "gorm.io/driver/postgres" 50 "gorm.io/driver/sqlite" 51 "gorm.io/gorm" 52) 53 54const ( 55 AccountSessionMaxAge = 30 * 24 * time.Hour // one week 56) 57 58type S3Config struct { 59 BackupsEnabled bool 60 BlobstoreEnabled bool 61 Endpoint string 62 Region string 63 Bucket string 64 AccessKey string 65 SecretKey string 66 CDNUrl string 67} 68 69type Server struct { 70 http *http.Client 71 httpd *http.Server 72 mail *mailyak.MailYak 73 mailLk *sync.Mutex 74 echo *echo.Echo 75 db *db.DB 76 plcClient *plc.Client 77 logger *slog.Logger 78 config 
*config 79 privateKey *ecdsa.PrivateKey 80 repoman *RepoMan 81 oauthProvider *provider.Provider 82 evtman *events.EventManager 83 passport *identity.Passport 84 fallbackProxy string 85 86 lastRequestCrawl time.Time 87 requestCrawlMu sync.Mutex 88 89 dbName string 90 dbType string 91 s3Config *S3Config 92} 93 94type Args struct { 95 Logger *slog.Logger 96 97 LogLevel slog.Level 98 Addr string 99 DbName string 100 DbType string 101 DatabaseURL string 102 TursoToken string 103 Version string 104 Did string 105 Hostname string 106 RotationKeyPath string 107 JwkPath string 108 ContactEmail string 109 Relays []string 110 AdminPassword string 111 RequireInvite bool 112 113 SmtpUser string 114 SmtpPass string 115 SmtpHost string 116 SmtpPort string 117 SmtpEmail string 118 SmtpName string 119 120 S3Config *S3Config 121 122 SessionSecret string 123 SessionCookieKey string 124 125 BlockstoreVariant BlockstoreVariant 126 FallbackProxy string 127 128 PushBasedEvents bool 129 SubscribeReposServiceURL string 130 NonceSecret string 131} 132 133type config struct { 134 LogLevel slog.Level 135 Version string 136 Did string 137 Hostname string 138 ContactEmail string 139 EnforcePeering bool 140 Relays []string 141 AdminPassword string 142 RequireInvite bool 143 SmtpEmail string 144 SmtpName string 145 SessionCookieKey string 146 BlockstoreVariant BlockstoreVariant 147 FallbackProxy string 148 149 PushBasedEvents bool 150 SubscribeReposServiceURL string 151} 152 153type CustomValidator struct { 154 validator *validator.Validate 155} 156 157type ValidationError struct { 158 error 159 Field string 160 Tag string 161} 162 163func (cv *CustomValidator) Validate(i any) error { 164 if err := cv.validator.Struct(i); err != nil { 165 var validateErrors validator.ValidationErrors 166 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 167 first := validateErrors[0] 168 return ValidationError{ 169 error: err, 170 Field: first.Field(), 171 Tag: first.Tag(), 172 } 173 } 174 175 
return err 176 } 177 178 return nil 179} 180 181//go:embed templates/* 182var templateFS embed.FS 183 184//go:embed static/* 185var staticFS embed.FS 186 187type TemplateRenderer struct { 188 templates *template.Template 189 isDev bool 190 templatePath string 191} 192 193func (s *Server) loadTemplates() { 194 absPath, _ := filepath.Abs("server/templates/*.html") 195 if s.config.Version == "dev" { 196 tmpl := template.Must(template.ParseGlob(absPath)) 197 s.echo.Renderer = &TemplateRenderer{ 198 templates: tmpl, 199 isDev: true, 200 templatePath: absPath, 201 } 202 } else { 203 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 204 s.echo.Renderer = &TemplateRenderer{ 205 templates: tmpl, 206 isDev: false, 207 } 208 } 209} 210 211func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 212 if t.isDev { 213 tmpl, err := template.ParseGlob(t.templatePath) 214 if err != nil { 215 return err 216 } 217 t.templates = tmpl 218 } 219 220 if viewContext, isMap := data.(map[string]any); isMap { 221 viewContext["reverse"] = c.Echo().Reverse 222 } 223 224 return t.templates.ExecuteTemplate(w, name, data) 225} 226 227type filteredHandler struct { 228 level slog.Level 229 handler slog.Handler 230} 231 232func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool { 233 return level >= h.level && h.handler.Enabled(ctx, level) 234} 235 236func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error { 237 return h.handler.Handle(ctx, r) 238} 239 240func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler { 241 return &filteredHandler{level: h.level, handler: h.handler.WithAttrs(attrs)} 242} 243 244func (h *filteredHandler) WithGroup(name string) slog.Handler { 245 return &filteredHandler{level: h.level, handler: h.handler.WithGroup(name)} 246} 247 248func New(args *Args) (*Server, error) { 249 if args.Logger == nil { 250 args.Logger = slog.Default() 251 } 252 253 if args.LogLevel != 
0 { 254 args.Logger = slog.New(&filteredHandler{ 255 level: args.LogLevel, 256 handler: args.Logger.Handler(), 257 }) 258 } 259 260 logger := args.Logger.With("name", "New") 261 262 if args.Addr == "" { 263 return nil, fmt.Errorf("addr must be set") 264 } 265 266 if args.DbName == "" { 267 return nil, fmt.Errorf("db name must be set") 268 } 269 270 if args.Did == "" { 271 return nil, fmt.Errorf("cocoon did must be set") 272 } 273 274 if args.ContactEmail == "" { 275 return nil, fmt.Errorf("cocoon contact email is required") 276 } 277 278 if _, err := syntax.ParseDID(args.Did); err != nil { 279 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 280 } 281 282 if args.Hostname == "" { 283 return nil, fmt.Errorf("cocoon hostname must be set") 284 } 285 286 if args.AdminPassword == "" { 287 return nil, fmt.Errorf("admin password must be set") 288 } 289 290 if args.SessionSecret == "" { 291 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 292 } 293 294 e := echo.New() 295 296 e.Pre(middleware.RemoveTrailingSlash()) 297 e.Pre(slogecho.New(args.Logger.With("component", "slogecho"))) 298 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 299 e.Use(echoprometheus.NewMiddleware("cocoon")) 300 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 301 AllowOrigins: []string{"*"}, 302 AllowHeaders: []string{"*"}, 303 AllowMethods: []string{"*"}, 304 AllowCredentials: true, 305 MaxAge: 100_000_000, 306 })) 307 308 vdtor := validator.New() 309 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 310 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 311 return false 312 } 313 return true 314 }) 315 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 316 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 317 return false 318 } 319 return true 320 }) 321 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 322 if _, err := 
syntax.ParseRecordKey(fl.Field().String()); err != nil { 323 return false 324 } 325 return true 326 }) 327 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 328 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 329 return false 330 } 331 return true 332 }) 333 334 e.Validator = &CustomValidator{validator: vdtor} 335 336 httpd := &http.Server{ 337 Addr: args.Addr, 338 Handler: e, 339 // shitty defaults but okay for now, needed for import repo 340 ReadTimeout: 5 * time.Minute, 341 WriteTimeout: 5 * time.Minute, 342 IdleTimeout: 5 * time.Minute, 343 } 344 345 dbType := args.DbType 346 if dbType == "" { 347 dbType = "sqlite" 348 } 349 350 var gdb *gorm.DB 351 gormCfg := gorm.Config{ 352 TranslateError: true, 353 } 354 var err error 355 switch dbType { 356 case "postgres": 357 if args.DatabaseURL == "" { 358 return nil, fmt.Errorf("database-url must be set when using postgres") 359 } 360 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gormCfg) 361 if err != nil { 362 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 363 } 364 logger.Info("connected to PostgreSQL database") 365 case "turso": 366 primaryUrl := args.DatabaseURL 367 authToken := args.TursoToken 368 369 db, err := sql.Open("libsql", fmt.Sprintf("%s?authToken=%s", primaryUrl, authToken)) 370 gdb, err = gorm.Open(sqlite.New(sqlite.Config{ 371 Conn: db, 372 }), &gormCfg) 373 if err != nil { 374 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 375 } 376 logger.Info("connected to PostgreSQL database") 377 378 default: 379 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gormCfg) 380 if err != nil { 381 return nil, fmt.Errorf("failed to open sqlite database: %w", err) 382 } 383 gdb.Exec("PRAGMA journal_mode=WAL") 384 gdb.Exec("PRAGMA synchronous=NORMAL") 385 386 logger.Info("connected to SQLite database", "path", args.DbName) 387 } 388 dbw := db.NewDB(gdb) 389 390 rkbytes, err := os.ReadFile(args.RotationKeyPath) 391 if err != nil { 
392 return nil, err 393 } 394 395 h := util.RobustHTTPClient() 396 397 plcClient, err := plc.NewClient(&plc.ClientArgs{ 398 H: h, 399 Service: "https://plc.directory", 400 PdsHostname: args.Hostname, 401 RotationKey: rkbytes, 402 }) 403 if err != nil { 404 return nil, err 405 } 406 407 jwkbytes, err := os.ReadFile(args.JwkPath) 408 if err != nil { 409 return nil, err 410 } 411 412 key, err := helpers.ParseJWKFromBytes(jwkbytes) 413 if err != nil { 414 return nil, err 415 } 416 417 var pkey ecdsa.PrivateKey 418 if err := key.Raw(&pkey); err != nil { 419 return nil, err 420 } 421 422 oauthCli := &http.Client{ 423 Timeout: 10 * time.Second, 424 } 425 426 var nonceSecret []byte 427 if args.NonceSecret != "" { 428 nonceSecret = []byte(args.NonceSecret) 429 } else { 430 maybeSecret, err := os.ReadFile("nonce.secret") 431 if err != nil && !os.IsNotExist(err) { 432 logger.Error("error attempting to read nonce secret", "error", err) 433 } else { 434 nonceSecret = maybeSecret 435 } 436 } 437 438 evtPersister, err := NewDbPersister(gdb, 72*time.Hour) 439 if err != nil { 440 return nil, fmt.Errorf("failed to create event persister: %w", err) 441 } 442 443 s := &Server{ 444 http: h, 445 httpd: httpd, 446 echo: e, 447 logger: args.Logger, 448 db: dbw, 449 plcClient: plcClient, 450 privateKey: &pkey, 451 config: &config{ 452 LogLevel: args.LogLevel, 453 Version: args.Version, 454 Did: args.Did, 455 Hostname: args.Hostname, 456 ContactEmail: args.ContactEmail, 457 EnforcePeering: false, 458 Relays: args.Relays, 459 AdminPassword: args.AdminPassword, 460 RequireInvite: args.RequireInvite, 461 SmtpName: args.SmtpName, 462 SmtpEmail: args.SmtpEmail, 463 SessionCookieKey: args.SessionCookieKey, 464 BlockstoreVariant: args.BlockstoreVariant, 465 FallbackProxy: args.FallbackProxy, 466 PushBasedEvents: args.PushBasedEvents, 467 SubscribeReposServiceURL: args.SubscribeReposServiceURL, 468 }, 469 evtman: events.NewEventManager(evtPersister), 470 passport: identity.NewPassport(h, 
identity.NewMemCache(10_000)), 471 472 dbName: args.DbName, 473 dbType: dbType, 474 s3Config: args.S3Config, 475 476 oauthProvider: provider.NewProvider(provider.Args{ 477 Hostname: args.Hostname, 478 ClientManagerArgs: client.ManagerArgs{ 479 Cli: oauthCli, 480 Logger: args.Logger.With("component", "oauth-client-manager"), 481 }, 482 DpopManagerArgs: dpop.ManagerArgs{ 483 NonceSecret: nonceSecret, 484 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 485 OnNonceSecretCreated: func(newNonce []byte) { 486 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 487 logger.Error("error writing new nonce secret", "error", err) 488 } 489 }, 490 Logger: args.Logger.With("component", "dpop-manager"), 491 Hostname: args.Hostname, 492 }, 493 }), 494 } 495 496 s.loadTemplates() 497 498 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 499 500 // TODO: should validate these args 501 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 502 args.Logger.Warn("not enough smtp args were provided. 
mailing will not work for your server.") 503 } else { 504 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 505 mail.From(s.config.SmtpEmail) 506 mail.FromName(s.config.SmtpName) 507 508 s.mail = mail 509 s.mailLk = &sync.Mutex{} 510 } 511 512 return s, nil 513} 514 515func (s *Server) addRoutes() { 516 // static 517 if s.config.Version == "dev" { 518 s.echo.Static("/static", "server/static") 519 } else { 520 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 521 } 522 523 // random stuff 524 s.echo.GET("/", s.handleRoot) 525 s.echo.GET("/xrpc/_health", s.handleHealth) 526 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 527 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid) 528 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 529 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 530 s.echo.GET("/robots.txt", s.handleRobots) 531 532 // public 533 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 534 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 535 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 536 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 537 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey) 538 539 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 540 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 541 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 542 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 543 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 544 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 545 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 546 
s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 547 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 548 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 549 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 550 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 551 552 // labels 553 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels) 554 555 // account 556 s.echo.GET("/account", s.handleAccount) 557 s.echo.POST("/account/revoke", s.handleAccountRevoke) 558 s.echo.POST("/account/switch", s.handleAccountSwitchPost) 559 s.echo.GET("/account/signin", s.handleAccountSigninGet) 560 s.echo.POST("/account/signin", s.handleAccountSigninPost) 561 s.echo.GET("/account/signout", s.handleAccountSignout) 562 563 // oauth account 564 s.echo.GET("/oauth/jwks", s.handleOauthJwks) 565 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 566 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 567 568 // oauth authorization 569 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 570 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 571 572 // authed 573 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 574 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 575 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 576 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 577 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, 
s.handleOauthSessionMiddleware) 578 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 579 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 580 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 581 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 582 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 583 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 584 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 585 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 586 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 587 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 588 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 589 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 590 s.echo.POST("/xrpc/com.atproto.server.activateAccount", 
s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 591 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 592 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount) 593 594 // repo 595 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 596 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 597 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 598 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 599 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 600 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 601 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 602 603 // stupid silly endpoints 604 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 605 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 606 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 607 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 608 // admin routes 609 
s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 610 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 611 612 // are there any routes that we should be allowing without auth? i dont think so but idk 613 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 614 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 615} 616 617func (s *Server) Serve(ctx context.Context) error { 618 logger := s.logger.With("name", "Serve") 619 620 s.addRoutes() 621 622 logger.Info("migrating...") 623 624 s.db.AutoMigrate( 625 &models.Actor{}, 626 &models.Repo{}, 627 &models.InviteCode{}, 628 &models.Token{}, 629 &models.RefreshToken{}, 630 &models.Block{}, 631 &models.Record{}, 632 &models.Blob{}, 633 &models.BlobPart{}, 634 &models.ReservedKey{}, 635 &provider.OauthToken{}, 636 &provider.OauthAuthorizationRequest{}, 637 ) 638 639 logger.Info("starting cocoon") 640 641 go func() { 642 if err := s.httpd.ListenAndServe(); err != nil { 643 panic(err) 644 } 645 }() 646 647 go s.backupRoutine() 648 649 go func() { 650 if err := s.requestCrawl(ctx); err != nil { 651 logger.Error("error requesting crawls", "err", err) 652 } 653 }() 654 655 if s.config.PushBasedEvents { 656 slog.Info("pushed based events enabled") 657 go func() { 658 if err := s.emmitEvents(ctx); err != nil { 659 logger.Error("error emitting events", "err", err) 660 } 661 }() 662 } 663 664 <-ctx.Done() 665 666 fmt.Println("shut down") 667 668 return nil 669} 670 671func (s *Server) requestCrawl(ctx context.Context) error { 672 logger := s.logger.With("component", "request-crawl") 673 s.requestCrawlMu.Lock() 674 defer s.requestCrawlMu.Unlock() 675 676 logger.Info("requesting crawl with configured relays") 677 678 if time.Since(s.lastRequestCrawl) <= 1*time.Minute { 679 return fmt.Errorf("a crawl request has 
already been made within the last minute") 680 } 681 682 for _, relay := range s.config.Relays { 683 logger := logger.With("relay", relay) 684 logger.Info("requesting crawl from relay") 685 cli := xrpc.Client{Host: relay} 686 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 687 Hostname: s.config.Hostname, 688 }); err != nil { 689 logger.Error("error requesting crawl", "err", err) 690 } else { 691 logger.Info("crawl requested successfully") 692 } 693 } 694 695 s.lastRequestCrawl = time.Now() 696 697 return nil 698} 699 700func (s *Server) doBackup() { 701 logger := s.logger.With("name", "doBackup") 702 703 if s.dbType == "postgres" || s.dbType == "turso" { 704 logger.Info("skipping S3 backup - PostgreSQL or Turso backups should be handled externally (pg_dump, managed database backups, etc.)") 705 return 706 } 707 708 start := time.Now() 709 710 logger.Info("beginning backup to s3...") 711 712 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano)) 713 defer os.Remove(tmpFile) 714 715 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil { 716 logger.Error("error creating tmp backup file", "err", err) 717 return 718 } 719 720 backupData, err := os.ReadFile(tmpFile) 721 if err != nil { 722 logger.Error("error reading tmp backup file", "err", err) 723 return 724 } 725 726 logger.Info("sending to s3...") 727 728 currTime := time.Now().Format("2006-01-02_15-04-05") 729 key := "cocoon-backup-" + currTime + ".db" 730 731 config := &aws.Config{ 732 Region: aws.String(s.s3Config.Region), 733 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 734 } 735 736 if s.s3Config.Endpoint != "" { 737 config.Endpoint = aws.String(s.s3Config.Endpoint) 738 config.S3ForcePathStyle = aws.Bool(true) 739 } 740 741 sess, err := session.NewSession(config) 742 if err != nil { 743 logger.Error("error creating s3 session", "err", err) 744 return 745 } 746 
747 svc := s3.New(sess) 748 749 if _, err := svc.PutObject(&s3.PutObjectInput{ 750 Bucket: aws.String(s.s3Config.Bucket), 751 Key: aws.String(key), 752 Body: bytes.NewReader(backupData), 753 }); err != nil { 754 logger.Error("error uploading file to s3", "err", err) 755 return 756 } 757 758 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds()) 759 760 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644) 761} 762 763func (s *Server) backupRoutine() { 764 logger := s.logger.With("name", "backupRoutine") 765 766 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 767 return 768 } 769 770 if s.s3Config.Region == "" { 771 logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 772 return 773 } 774 775 if s.s3Config.Bucket == "" { 776 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 777 return 778 } 779 780 if s.s3Config.AccessKey == "" { 781 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 782 return 783 } 784 785 if s.s3Config.SecretKey == "" { 786 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 787 return 788 } 789 790 shouldBackupNow := false 791 lastBackupStr, err := os.ReadFile("last-backup.txt") 792 if err != nil { 793 shouldBackupNow = true 794 } else { 795 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr)) 796 if err != nil { 797 shouldBackupNow = true 798 } else if time.Since(lastBackup).Seconds() > 3600 { 799 shouldBackupNow = true 800 } 801 } 802 803 if shouldBackupNow { 804 go s.doBackup() 805 } 806 807 ticker := time.NewTicker(time.Hour) 808 for range ticker.C { 809 go s.doBackup() 810 } 811} 812 813func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 814 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? 
WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 815 return err 816 } 817 818 return nil 819}