diff --git a/cmd/migrate.go b/cmd/migrate.go index 39a574cadd10d2efbd4a4437a188ed392cb516a5_Y21kL21pZ3JhdGUuZ28=..5021a3b4a23a2900eb97308912ebd136937e3c32_Y21kL21pZ3JhdGUuZ28= 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -1,8 +1,9 @@ package cmd import ( + "database/sql" "errors" "fmt" "github.com/golang-migrate/migrate/v4" "github.com/rs/zerolog" @@ -4,13 +5,11 @@ "errors" "fmt" "github.com/golang-migrate/migrate/v4" "github.com/rs/zerolog" - - "orus.io/orus-io/go-orusapi/database" ) type MigrateCmd[E any] struct { program *Program[E] } @@ -11,9 +10,9 @@ ) type MigrateCmd[E any] struct { program *Program[E] } -func (cmd *MigrateCmd[E]) Execute([]string) error { +func (cmd *MigrateCmd[E]) Execute([]string) (finalErr error) { cmd.program.LoggingOptions.SetMinLoggingLevel(zerolog.InfoLevel) log := cmd.program.Logger @@ -18,4 +17,5 @@ cmd.program.LoggingOptions.SetMinLoggingLevel(zerolog.InfoLevel) log := cmd.program.Logger - m, err := database.NewMigrate(cmd.program.DatabaseOptions.DSN, cmd.program.dbMigrateSource) + + db, err := sql.Open("postgres", cmd.program.DatabaseOptions.DSN) if err != nil { @@ -21,4 +21,4 @@ if err != nil { - return fmt.Errorf("failed to init migration engine: %w", err) + return err } defer func() { @@ -23,12 +23,7 @@ } defer func() { - if sourceErr, databaseErr := m.Close(); sourceErr != nil || databaseErr != nil { - if sourceErr != nil { - log.Err(err).Msg("error closing Migrate source") - } - if databaseErr != nil { - log.Err(err).Msg("error closing Migrate database") - } + if err := db.Close(); err != nil && finalErr != nil { + finalErr = fmt.Errorf("could not close db: %w", err) } }() @@ -32,7 +27,7 @@ } }() - if err := m.Up(); err != nil { + if err := cmd.program.dbSchema.Migrate(db, cmd.program.Logger); err != nil { if errors.Is(err, migrate.ErrNoChange) { log.Info().Msg("The database is already up-to-date") } else { diff --git a/cmd/program.go b/cmd/program.go index 39a574cadd10d2efbd4a4437a188ed392cb516a5_Y21kL3Byb2dyYW0uZ28=..5021a3b4a23a2900eb97308912ebd136937e3c32_Y21kL3Byb2dyYW0uZ28= 100644 --- a/cmd/program.go +++ b/cmd/program.go @@ -40,6 +40,7 @@ } type Program[E any] struct { + Name string Version Version BootstrapParser *flags.Parser Parser *flags.Parser @@ -62,8 +63,8 @@ MigrateCmd *MigrateCmd[E] GenerateConfigCmd *GenerateConfigCmd - hasDB bool - dbMigrateSource source.Driver + hasDB bool + dbSchema database.SchemaSet middlewares []Middleware @@ -67,7 +68,8 @@ middlewares []Middleware - setupXbusActors func(*Program[E]) []XbusActorFactory + hasXbus bool + setupXbusActors []func(*Program[E]) []XbusActorFactory xbusActorNames []string setupHandler func(*Program[E]) http.Handler setupSubcommands []subcommand[E] @@ -117,5 +119,13 @@ func WithXbusActors[E any](factories func(*Program[E]) []XbusActorFactory) Option[E] { return func(program *Program[E]) { - program.setupXbusActors = factories + program.setupXbusActors = append(program.setupXbusActors, factories) + + if !program.hasXbus { + g, err := program.Parser.AddGroup("Xbus", "Xbus options", &program.XbusOptions) + if err != nil { + panic(err) + } + g.Namespace = "xbus" + g.EnvNamespace = "XBUS" @@ -121,5 +131,4 @@ - g, err := program.Parser.AddGroup("Xbus", "Xbus options", &program.XbusOptions) - if err != nil { - panic(err) + SetupXbusCmd(program) + program.hasXbus = true } @@ -125,4 +134,4 @@ } - g.Namespace = "xbus" - g.EnvNamespace = "XBUS" + } +} @@ -128,5 +137,9 @@ - SetupXbusCmd(program) +func CombineOptions[E any](options ...Option[E]) Option[E] { + return func(program *Program[E]) { + for _, option := range options { + option(program) + } } } @@ -154,7 +167,18 @@ func WithDatabase[E any](migrateSource source.Driver) Option[E] { return func(program *Program[E]) { program.hasDB = true - program.dbMigrateSource = migrateSource + program.dbSchema.Main.Source = migrateSource + } +} + +func WithSchema[E any](name string, migrateSource source.Driver) Option[E] { + return func(program *Program[E]) { + program.dbSchema.Extra = append( + program.dbSchema.Extra, + database.Schema{ + Name: name, Source: migrateSource, + }, + ) } } @@ -185,8 +209,9 @@ parser := flags.NewNamedParser(name, flags.HelpFlag|flags.PassDoubleDash) program := Program[E]{ + Name: name, Ext: ext, Version: version, BootstrapParser: bootstrapParser, Parser: parser, Logger: orusapi.DefaultLogger(os.Stdout), @@ -188,8 +213,9 @@ Ext: ext, Version: version, BootstrapParser: bootstrapParser, Parser: parser, Logger: orusapi.DefaultLogger(os.Stdout), + dbSchema: database.SchemaSet{Main: database.Schema{Name: name}}, } for _, opt := range options { @@ -334,6 +360,9 @@ if !program.hasDB { return nil } + if program.DB != nil { + return nil + } program.Logger.Debug().Msg("Connecting to the database...") @@ -337,21 +366,7 @@ program.Logger.Debug().Msg("Connecting to the database...") - if automigrate { - if err := database.AutoMigrate( - program.DatabaseOptions.DSN, program.dbMigrateSource, program.Logger, - ); err != nil { - return err - } - } else { - if err := database.IsUptodate( - program.DatabaseOptions.DSN, program.dbMigrateSource, - ); err != nil { - return err - } - } - db, err := program.DatabaseOptions.Open() if err != nil { return err } @@ -354,9 +369,16 @@ db, err := program.DatabaseOptions.Open() if err != nil { return err } + defer func() { + if program.DB == nil { + if err := db.Close(); err != nil { + program.Logger.Err(err).Msg("could not close database connection") + } + } + }() if err := db.Ping(); err != nil { return err } @@ -358,7 +380,19 @@ if err := db.Ping(); err != nil { return err } + if err := program.dbSchema.IsUptodate(db.DB); err != nil { + if automigrate && errors.Is(err, database.ErrDBNeedUpgrade) { + program.Logger.Warn().Msg("Database is not up-to-date, will auto-migrate...") + if err := program.dbSchema.Migrate(db.DB, program.Logger); err != nil { + return err + } + program.Logger.Info().Msg("Database successfully migrated") + } else { + return err + } + } + program.DB = db @@ -363,6 +397,6 @@ program.DB = db - program.Logger.Info().Msg("Connected to the database") + program.Logger.Info().Msg("Database is ready") return nil } diff --git a/cmd/xbus.go b/cmd/xbus.go index 39a574cadd10d2efbd4a4437a188ed392cb516a5_Y21kL3hidXMuZ28=..5021a3b4a23a2900eb97308912ebd136937e3c32_Y21kL3hidXMuZ28= 100644 --- a/cmd/xbus.go +++ b/cmd/xbus.go @@ -36,13 +36,30 @@ } func (p *Program[E]) xbusRegisterActors() error { - for _, factory := range p.setupXbusActors(p) { - switch f := factory.Factory.(type) { - case service.ConsumerFactory: - if err := service.RegisterConsumer(factory.Name, f); err != nil { - return err - } - case service.ConsumerFunc: - if err := service.RegisterConsumerFunc(factory.Name, f); err != nil { - return err + for _, ff := range p.setupXbusActors { + for _, factory := range ff(p) { + switch f := factory.Factory.(type) { + case service.ConsumerFactory: + if err := service.RegisterConsumer(factory.Name, f); err != nil { + return err + } + case service.ConsumerFunc: + if err := service.RegisterConsumerFunc(factory.Name, f); err != nil { + return err + } + case service.WorkerFactory: + if err := service.RegisterWorker(factory.Name, f); err != nil { + return err + } + case service.WorkerFunc: + if err := service.RegisterWorkerFunc(factory.Name, f); err != nil { + return err + } + case xbus.NewActorServiceFunc: + xbus.RegisterActorService(factory.Name, f) + default: + return fmt.Errorf( + "%s type is not a recognised as a factory: %t", + factory.Name, factory.Factory, + ) } @@ -48,14 +65,3 @@ } - case service.WorkerFactory: - if err := service.RegisterWorker(factory.Name, f); err != nil { - return err - } - case service.WorkerFunc: - if err := service.RegisterWorkerFunc(factory.Name, f); err != nil { - return err - } - case xbus.NewActorServiceFunc: - xbus.RegisterActorService(factory.Name, f) - default: - return fmt.Errorf("%s has an factory type: %t", factory.Name, factory.Factory) + p.xbusActorNames = append(p.xbusActorNames, factory.Name) } @@ -61,5 +67,4 @@ } - p.xbusActorNames = append(p.xbusActorNames, factory.Name) } return nil diff --git a/database/migrations.go b/database/migrations.go index 39a574cadd10d2efbd4a4437a188ed392cb516a5_ZGF0YWJhc2UvbWlncmF0aW9ucy5nbw==..5021a3b4a23a2900eb97308912ebd136937e3c32_ZGF0YWJhc2UvbWlncmF0aW9ucy5nbw== 100644 --- a/database/migrations.go +++ b/database/migrations.go @@ -5,6 +5,7 @@ "errors" "fmt" "os" + "strings" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/postgres" @@ -13,8 +14,6 @@ ) //nolint:dupword -//go:generate go-bindata -pkg database -prefix migrations migrations -//go:generate sed -i "1s;^;// Code generated by go-bindata. DO NOT EDIT.\\n\\n;" bindata.go // ErrDBNotVersioned is returned by IsUptodate if the database is not versionned // at all. @@ -18,6 +17,11 @@ // ErrDBNotVersioned is returned by IsUptodate if the database is not versionned // at all. -var ErrDBNotVersioned = errors.New( - "the database is not versionned, please run the 'migrate' command") +var ErrDBNotVersioned = errors.New("not versionned") + +var ErrDBNeedUpgrade = errors.New("version is too old") +var ErrDBFutureVersion = errors.New("version is too new") + +var ErrDBInconsistent = errors.New("database is inconsistent") +var ErrDBAny = errors.New("database error") @@ -23,3 +27,3 @@ -// ErrDBNeedUpgrade is returned if the database is not up-to-date with the +// DBNeedUpgradeError is returned if the database is not up-to-date with the // server version. @@ -25,7 +29,23 @@ // server version. -type ErrDBNeedUpgrade struct { +type DBNeedUpgradeError struct { + SchemaName string + CurrentVersion uint + RequiredVersion uint +} + +// Error returns the formatted error message. +func (e DBNeedUpgradeError) Error() string { + return fmt.Sprintf("%w: schema: %s, required version: %d, current version: %d", + ErrDBNeedUpgrade, e.SchemaName, e.RequiredVersion, e.CurrentVersion, + ) +} + +// DBFutureVersionError is returned if the database is more recent than the server. +// It generally means the server is not at the right version. +type DBFutureVersionError struct { + SchemaName string CurrentVersion uint RequiredVersion uint } // Error returns the formatted error message. @@ -27,12 +47,11 @@ CurrentVersion uint RequiredVersion uint } // Error returns the formatted error message. -func (e ErrDBNeedUpgrade) Error() string { - return fmt.Sprintf("Database version is too old. Please run '%s migrate'."+ - " Required version: %d, current version: %d", - os.Args[0], e.RequiredVersion, e.CurrentVersion, +func (e DBFutureVersionError) Error() string { + return fmt.Sprintf("%w: schema: %s, required version: %d, current version: %d", + ErrDBFutureVersion, e.SchemaName, e.RequiredVersion, e.CurrentVersion, ) } @@ -36,10 +55,82 @@ ) } -// ErrDBFutureVersion is returned if the database is more recent than the server. -// It generally means the server is not at the right version. -type ErrDBFutureVersion struct { - CurrentVersion uint - RequiredVersion uint +type Schema struct { + Name string + Source source.Driver +} + +func (s Schema) GetMigrate(db *sql.DB, isMain bool) (*migrate.Migrate, error) { + var migrationTable string + if !isMain { + migrationTable = "schema_migrations_" + s.Name + } + driver, err := postgres.WithInstance(db, &postgres.Config{ + MigrationsTable: migrationTable, + MultiStatementEnabled: true, + }) + if err != nil { + return nil, err + } + m, err := migrate.NewWithInstance( + s.Name+"_schema", s.Source, + "postgres", driver, + ) + if err != nil { + return nil, err + } + + return m, nil +} + +func (s Schema) IsUptodate(db *sql.DB, isMain bool) error { + m, err := s.GetMigrate(db, isMain) + if err != nil { + return fmt.Errorf("failed to check database version: %w", err) + } + + // Lookup the last available db version + lastVersion, err := s.Source.First() + if err != nil { + return err + } + for { + next, err := s.Source.Next(lastVersion) + var pathErr *os.PathError + if errors.As(err, &pathErr) && errors.Is(err, os.ErrNotExist) { + break + } else if err != nil { + return err + } + lastVersion = next + } + + version, dirty, err := m.Version() + if errors.Is(err, migrate.ErrNilVersion) { + return fmt.Errorf("schema '%s' is %w", s.Name, ErrDBNotVersioned) + } + if err != nil { + return fmt.Errorf("error while checking the database version: %w", err) + } + if dirty { + return errors.New("schema '%s' is marked 'dirty'") + } + + if version < lastVersion { + return DBNeedUpgradeError{ + SchemaName: s.Name, + CurrentVersion: version, + RequiredVersion: lastVersion, + } + } + if version > lastVersion { + return DBFutureVersionError{ + SchemaName: s.Name, + CurrentVersion: version, + RequiredVersion: lastVersion, + } + } + + return nil } @@ -44,12 +135,94 @@ } -// Error returns the formatted error message. -func (e ErrDBFutureVersion) Error() string { - return fmt.Sprintf("Database version is too new. It probably means your "+ - "version of '%s' is too old."+ - " Required version: %d, current version: %d", - os.Args[0], e.RequiredVersion, e.CurrentVersion, - ) +func (s Schema) Migrate(db *sql.DB, isMain bool) error { + m, err := s.GetMigrate(db, isMain) + if err != nil { + return fmt.Errorf("failed to migrate database schema '%s': %w", s.Name, err) + } + + if err := m.Up(); err != nil { + return fmt.Errorf("could not migrate '%s': %w", s.Name, err) + } + + return nil +} + +type SchemaSet struct { + Main Schema + Extra []Schema +} + +func (s SchemaSet) Migrate(db *sql.DB, log zerolog.Logger) error { + if err := s.Main.Migrate(db, true); err != nil && !errors.Is(err, migrate.ErrNoChange) { + return err + } + log.Info().Str("schema", s.Main.Name).Msg("successfully migrated") + + for i := range s.Extra { + if err := s.Extra[i].Migrate( + db, false, + ); err != nil && !errors.Is(err, migrate.ErrNoChange) { + return nil + } + log.Info().Str("schema", s.Extra[i].Name).Msg("successfully migrated") + } + + return nil +} + +func (s SchemaSet) IsUptodate(db *sql.DB) error { + var errs []error + + if err := s.Main.IsUptodate(db, true); err != nil { + errs = append(errs, err) + } + + for i := range s.Extra { + if err := s.Extra[i].IsUptodate(db, false); err != nil { + errs = append(errs, err) + } + } + + if len(errs) != 0 { + var tooOld, tooNew, notVersionned, otherErrors bool + for _, err := range errs { + if errors.Is(err, ErrDBFutureVersion) { + tooNew = true + } else if errors.Is(err, ErrDBNeedUpgrade) { + tooOld = true + } else if errors.Is(err, ErrDBNotVersioned) { + notVersionned = true + } else { + otherErrors = true + } + } + args := make([]any, len(errs)+1) + switch { + case otherErrors: + args[0] = ErrDBAny + case tooOld && tooNew && notVersionned: + args[0] = ErrDBInconsistent + case tooOld && tooNew: + args[0] = ErrDBInconsistent + case tooOld && notVersionned: + args[0] = ErrDBNeedUpgrade + case tooNew && notVersionned: + args[0] = ErrDBInconsistent + case tooOld: + args[0] = ErrDBNeedUpgrade + case tooNew: + args[0] = ErrDBFutureVersion + case notVersionned: + args[0] = ErrDBNeedUpgrade + } + for i, err := range errs { + args[i+1] = err + } + + return fmt.Errorf("%w:"+strings.Repeat(" (%w)", len(errs)), args...) + } + + return nil } // NewMigrate initializes a github.com/golang-migrate/migrate/v4.Migrate for @@ -82,7 +255,7 @@ if err == nil { return nil } - var upgradeErr ErrDBNeedUpgrade + var upgradeErr DBNeedUpgradeError if ok := errors.As(err, &upgradeErr); !errors.Is(err, ErrDBNotVersioned) == !ok { return err } @@ -136,9 +309,9 @@ } if version < lastVersion { - return ErrDBNeedUpgrade{ + return DBNeedUpgradeError{ CurrentVersion: version, RequiredVersion: lastVersion, } } if version > lastVersion { @@ -140,9 +313,9 @@ CurrentVersion: version, RequiredVersion: lastVersion, } } if version > lastVersion { - return ErrDBFutureVersion{ + return DBFutureVersionError{ CurrentVersion: version, RequiredVersion: lastVersion, } diff --git a/database/test.go b/database/test.go index 39a574cadd10d2efbd4a4437a188ed392cb516a5_ZGF0YWJhc2UvdGVzdC5nbw==..5021a3b4a23a2900eb97308912ebd136937e3c32_ZGF0YWJhc2UvdGVzdC5nbw== 100644 --- a/database/test.go +++ b/database/test.go @@ -9,6 +9,7 @@ "github.com/golang-migrate/migrate/v4/source" "github.com/jmoiron/sqlx" + "github.com/rs/zerolog" "github.com/stretchr/testify/require" ) @@ -83,7 +84,7 @@ "SELECT tablename "+ "FROM pg_catalog.pg_tables "+ "WHERE schemaname = 'public' "+ - "AND tablename <> 'schema_migrations';", + "AND tablename NOT LIKE 'schema_migrations%';", )) alltables = slices.DeleteFunc(alltables, func(name string) bool { return slices.Contains(dbtablenames, name) @@ -134,5 +135,5 @@ return &TestDB{db, dsn, ctx, tb, c} } -// GetTestDB creates a db and returns it. It must be closed within the test. +// GetTestDBWSchemaSet creates a db and returns it. It must be closed within the test. // If it fails, t.Fatal() is called. @@ -138,3 +139,3 @@ // If it fails, t.Fatal() is called. -func GetTestDB(ctx context.Context, tb testing.TB, sourceDriver source.Driver) *TestDB { +func GetTestDBWSchemaSet(ctx context.Context, tb testing.TB, sset SchemaSet) *TestDB { tb.Helper() @@ -140,7 +141,7 @@ tb.Helper() - db := GetTestDBNoClear(ctx, tb, sourceDriver) + db := GetTestDBWSchemaSetNoClear(ctx, tb, sset) db.ClearDB() return db } @@ -142,7 +143,9 @@ db.ClearDB() return db } -func GetTestDBNoClear(ctx context.Context, tb testing.TB, sourceDriver source.Driver) *TestDB { +func GetTestDBWSchemaSetNoClear( + ctx context.Context, tb testing.TB, sset SchemaSet, +) *TestDB { tb.Helper() @@ -148,4 +151,6 @@ tb.Helper() + log := zerolog.Ctx(ctx) + var success bool dsn := getDSN(tb) @@ -185,5 +190,5 @@ tb.Fatal(err) } - if len(dbtablenames) == 0 { + if err := sset.IsUptodate(db.DB); err != nil { clearDB(ctx, tb, c) @@ -189,2 +194,3 @@ clearDB(ctx, tb, c) + dbtablenames = nil @@ -190,14 +196,5 @@ - m, err := NewMigrate(dsn, sourceDriver) - if err != nil { - tb.Fatal(err) - } - if err := m.Up(); err != nil { - tb.Fatal(err) - } - if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil { - tb.Fatal(srcErr, dbErr) - } + require.NoError(tb, sset.Migrate(db.DB, *log)) } testDB := TestDB{db, dsn, ctx, tb, c} @@ -206,7 +203,7 @@ "SELECT tablename "+ "FROM pg_catalog.pg_tables "+ "WHERE schemaname = 'public' "+ - "AND tablename <> 'schema_migrations';", + "AND tablename NOT LIKE 'schema_migrations%';", ); err != nil { tb.Fatal(err) } @@ -217,6 +214,24 @@ return &testDB } +// GetTestDB creates a db and returns it. It must be closed within the test. +// If it fails, t.Fatal() is called. +func GetTestDB(ctx context.Context, tb testing.TB, sourceDriver source.Driver) *TestDB { + tb.Helper() + db := GetTestDBNoClear(ctx, tb, sourceDriver) + db.ClearDB() + + return db +} + +func GetTestDBNoClear(ctx context.Context, tb testing.TB, sourceDriver source.Driver) *TestDB { + tb.Helper() + + return GetTestDBWSchemaSetNoClear( + ctx, tb, SchemaSet{Main: Schema{Name: "main", Source: sourceDriver}}, + ) +} + func (db *TestDB) DisableConstraints(exec func(string, ...any) (sql.Result, error)) { for _, name := range dbtablenames { if _, err := exec(