# HG changeset patch # User Christophe de Vienne <christophe@cdevienne.info> # Date 1599742253 -7200 # Thu Sep 10 14:50:53 2020 +0200 # Node ID fcf7843defb03511ccc57f78b3d87715210640f2 # Parent a05e099f7b300257697d4a56eb02b75d3753f623 Import the rednerd 'database' module Made it generic too. diff --git a/database/database.go b/database/database.go new file mode 100644 --- /dev/null +++ b/database/database.go @@ -0,0 +1,17 @@ +package database + +import ( + "github.com/jmoiron/sqlx" +) + +// Open opens a connection to the database +func Open(dsn string, maxConn int) (*sqlx.DB, error) { + db, err := sqlx.Open("postgres", dsn) + if err != nil { + return nil, err + } + if maxConn != 0 { + db.SetMaxOpenConns(maxConn) + } + return db, nil +} diff --git a/database/join.go b/database/join.go new file mode 100644 --- /dev/null +++ b/database/join.go @@ -0,0 +1,9 @@ +package database + +import "fmt" + +// Join builds a join clause +func Join(tableFrom, fieldFrom, tableTo, fieldTo string) string { + return fmt.Sprintf("%[3]s ON %[3]s.%[4]s = %[1]s.%[2]s", + tableFrom, fieldFrom, tableTo, fieldTo) +} diff --git a/database/migrations.go b/database/migrations.go new file mode 100644 --- /dev/null +++ b/database/migrations.go @@ -0,0 +1,146 @@ +package database + +import ( + "database/sql" + "errors" + "fmt" + "os" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/source" + "github.com/rs/zerolog" +) + +//go:generate go-bindata -pkg database -prefix migrations migrations +//go:generate sed -i "1s;^;// Code generated by go-bindata. DO NOT EDIT.\\n\\n;" bindata.go + +// ErrDBNotVersioned is returned by IsUptodate if the database is not versionned +// at all +var ErrDBNotVersioned = errors.New( + "the database is not versionned, please run the 'migrate' command") + +// ErrDBNeedUpgrade is returned if the database is not up-to-date with the +// server version +type ErrDBNeedUpgrade struct { + CurrentVersion uint + RequiredVersion uint +} + +// Error returns the formatted error message +func (e ErrDBNeedUpgrade) Error() string { + return fmt.Sprintf("Database version is too old. Please run '%s migrate'."+ + " Required version: %d, current version: %d", + os.Args[0], e.RequiredVersion, e.CurrentVersion, + ) +} + +// ErrDBFutureVersion is returned if the database is more recent than the server. +// It generally means the server is not at the right version +type ErrDBFutureVersion struct { + CurrentVersion uint + RequiredVersion uint +} + +// Error returns the formatted error message +func (e ErrDBFutureVersion) Error() string { + return fmt.Sprintf("Database version is too new. It probably means your "+ + "version of '%s' is too old."+ + " Required version: %d, current version: %d", + os.Args[0], e.RequiredVersion, e.CurrentVersion, + ) +} + +// NewMigrate initializes a github.com/golang-migrate/migrate/v4.Migrate for +// the given database options +func NewMigrate(dsn string, sourceDriver source.Driver) (*migrate.Migrate, error) { + db, err := sql.Open("postgres", dsn) + if err != nil { + panic(err) + } + + driver, err := postgres.WithInstance(db, &postgres.Config{}) + if err != nil { + return nil, err + } + m, err := migrate.NewWithInstance( + "go-bindata", sourceDriver, + "postgres", driver, + ) + if err != nil { + return nil, err + } + return m, nil +} + +// AutoMigrate brings the db up-to-date and logs a warning if it needed some +// changes +func AutoMigrate(dsn string, sourceDriver source.Driver, log zerolog.Logger) error { + err := IsUptodate(dsn, sourceDriver) + if err == nil { + return nil + } + if _, ok := err.(ErrDBNeedUpgrade); err != ErrDBNotVersioned == !ok { + return err + } + log.Warn().Msg("Database is not up-to-date, it will be migrated automatically") + m, err := NewMigrate(dsn, sourceDriver) + if err != nil { + return fmt.Errorf("failed to init migration engine: %s", err) + } + defer m.Close() + + if err := m.Up(); err != nil { + return fmt.Errorf("error during auto-migration: %s", err) + } + log.Info().Msg("Successfully upgraded database") + return nil +} + +// IsUptodate returns nil if the database version is up to date +func IsUptodate(dsn string, sourceDriver source.Driver) error { + m, err := NewMigrate(dsn, sourceDriver) + if err != nil { + return fmt.Errorf("failed to check database version: %s", err) + } + + // Lookup the last available db version + lastVersion, err := sourceDriver.First() + if err != nil { + return err + } + for { + next, err := sourceDriver.Next(lastVersion) + if pathError, ok := err.(*os.PathError); err == os.ErrNotExist || ok && pathError.Err == os.ErrNotExist { + break + } else if err != nil { + return err + } + lastVersion = next + } + + version, dirty, err := m.Version() + if err == migrate.ErrNilVersion { + return ErrDBNotVersioned + } + if err != nil { + return fmt.Errorf("error while checking the database version: %s", err) + } + if dirty { + return fmt.Errorf("database is marked 'dirty'") + } + + if version < lastVersion { + return ErrDBNeedUpgrade{ + CurrentVersion: version, + RequiredVersion: lastVersion, + } + } + if version > lastVersion { + return ErrDBFutureVersion{ + CurrentVersion: version, + RequiredVersion: lastVersion, + } + } + return nil +} diff --git a/database/options.go b/database/options.go new file mode 100644 --- /dev/null +++ b/database/options.go @@ -0,0 +1,14 @@ +package database + +import "github.com/jmoiron/sqlx" + +// Options is a jessevdk/go-flags compatible struct for db-related options +type Options struct { + DSN string `long:"db-dsn" env:"DB_DSN" ini-name:"dsn" description:"DSN of the database"` + MaxConn int `long:"db-max-conn" env:"DB_MAX_CONN" ini-name:"max-conn" description:"Database max connection" default:"0"` +} + +// Open a connection to the database +func (o Options) Open() (*sqlx.DB, error) { + return Open(o.DSN, o.MaxConn) +} diff --git a/database/sequence.go b/database/sequence.go new file mode 100644 --- /dev/null +++ b/database/sequence.go @@ -0,0 +1,27 @@ +package database + +import ( + sq "github.com/Masterminds/squirrel" +) + +type SequenceBuilder struct { + name string +} + +func Sequence(name string) SequenceBuilder { + return SequenceBuilder{name} +} + +func (s SequenceBuilder) Set(value int) sq.SelectBuilder { + return sq.Select().Column("setval(?, ?)", s.name, value) +} + +func (s SequenceBuilder) Last() sq.SelectBuilder { + return sq.Select().Column("last_value").From(s.name) +} +func (s SequenceBuilder) Next() sq.SelectBuilder { + return sq.Select().Column("nextval(?)", s.name) +} +func (s SequenceBuilder) Current() sq.SelectBuilder { + return sq.Select().Column("currval(?)", s.name) +} diff --git a/database/sql.go b/database/sql.go new file mode 100644 --- /dev/null +++ b/database/sql.go @@ -0,0 +1,159 @@ +package database + +import ( + "context" + "database/sql" + "strings" + + "github.com/Masterminds/squirrel" + "github.com/lann/builder" + "github.com/rs/zerolog" +) + +// Mapped is the common interface of all structs that are mapped in the database +type Mapped interface { + Table() string + PKeyColumn() string + Columns(withPKey bool) []string + Values(columns ...string) []interface{} +} + +// SQ is a squirrel StatementBuilder +var SQ = squirrel.StatementBuilder + +// SQLTrace logs a sql query and args as TRACE level +func SQLTrace(log *zerolog.Logger, sql string, args []interface{}) { + if log != nil && log.GetLevel() < 0 { + arr := zerolog.Arr() + for _, arg := range args { + arr.Interface(arg) + } + logger := log.With().Array("args", arr).Logger() + logger.Trace().Msg("SQL: " + sql) + } +} + +// SQLExecutor is the common interface of sqlx.DB and sqlx.Tx that we use in +// the functions below +type SQLExecutor interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error + SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error +} + +func setPlaceHolderFormat(query squirrel.Sqlizer) squirrel.Sqlizer { + return builder.Set(query, "PlaceholderFormat", squirrel.Dollar).(squirrel.Sqlizer) +} + +// Get loads an object +func Get(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error { + return GetContext(context.Background(), e, obj, query, log) +} + +// GetContext loads an object +func GetContext( + ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger, +) error { + sqlQuery, args, err := setPlaceHolderFormat(query).ToSql() + if err != nil { + return err + } + SQLTrace(log, sqlQuery, args) + return e.GetContext(ctx, obj, sqlQuery, args...) +} + +// Select load an object list +func Select(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error { + return SelectContext(context.Background(), e, obj, query, log) +} + +// SelectContext loads an object list +func SelectContext( + ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger, +) error { + sqlQuery, args, err := setPlaceHolderFormat(query).ToSql() + if err != nil { + return err + } + SQLTrace(log, sqlQuery, args) + return e.SelectContext(ctx, obj, sqlQuery, args...) +} + +// Exec runs a squirrel query on the given db or tx +func Exec(e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) { + return ExecContext(context.Background(), e, query, log) +} + +// ExecContext runs a squirrel query on the given db or tx +func ExecContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) { + query = setPlaceHolderFormat(query) + sqlQuery, args, err := query. + ToSql() + if err != nil { + return nil, err + } + SQLTrace(log, sqlQuery, args) + return e.ExecContext(ctx, sqlQuery, args...) +} + +// ValuesMap returns the values for a list of columns as a map. If a column does +// not exits, the corresponding value is set to nil +func ValuesMap(m Mapped, columns ...string) map[string]interface{} { + valuesList := m.Values(columns...) + values := make(map[string]interface{}) + for i, column := range columns { + values[column] = valuesList[i] + } + return values +} + +// SQLUpsert generates a squirrel "upsert" statement +func SQLUpsert(m Mapped) squirrel.InsertBuilder { + updateSQL, updateArgs, err := squirrel. + Update(m.Table()). + SetMap(ValuesMap(m, m.Columns(false)...)). + ToSql() + if err != nil { + panic(err) + } + + updateParts := strings.Split(updateSQL, " SET ") + if len(updateParts) != 2 { + panic("Could not split the UPDATE query: " + updateSQL) //nolint:gosec + } + suffix := "ON CONFLICT (" + m.PKeyColumn() + ") DO UPDATE SET " + updateParts[1] + + allColumns := m.Columns(true) + q := squirrel. + Insert(m.Table()). + PlaceholderFormat(squirrel.Dollar). + Columns(allColumns...). + Values(m.Values(allColumns...)...). + Suffix(suffix, updateArgs...) + return q +} + +// SQLInsert build a Insert statement to insert one or several mapped instances +// in the database. All the instances must be of the same actual type +func SQLInsert(instances ...Mapped) squirrel.InsertBuilder { + allColumns := instances[0].Columns(true) + q := SQ.Insert(instances[0].Table()). + Columns(allColumns...) + for _, m := range instances { + q = q.Values(m.Values(allColumns...)...) + } + return q +} + +// SQLInsertNoPKey build a Insert statement to insert one or several mapped instances +// in the database, but leave the pkey undefined so it gets auto-generated +// by the database. All the instances must be of the same actual type +func SQLInsertNoPKey(instances ...Mapped) squirrel.InsertBuilder { + allColumns := instances[0].Columns(false) + q := SQ.Insert(instances[0].Table()). + Columns(allColumns...) + for _, m := range instances { + q = q.Values(m.Values(allColumns...)...) + } + return q +} diff --git a/database/sql_helper.go b/database/sql_helper.go new file mode 100644 --- /dev/null +++ b/database/sql_helper.go @@ -0,0 +1,73 @@ +package database + +import ( + "context" + "database/sql" + + sq "github.com/Masterminds/squirrel" + "github.com/rs/zerolog" +) + +// NewSQLHelper creates a SQLHelper. It must not outlive the given context and sql executor +func NewSQLHelper(ctx context.Context, sqle SQLExecutor, log zerolog.Logger) SQLHelper { + return SQLHelper{ctx, sqle, log} +} + +// SQLHelper is a short lived helper to easily get/select Mapped structures +type SQLHelper struct { + ctx context.Context + sqle SQLExecutor + log zerolog.Logger +} + +// Get loads a mapped structure +func (h *SQLHelper) Get(obj interface{}, query sq.Sqlizer) error { + return GetContext(h.ctx, h.sqle, obj, query, &h.log) +} + +// GetByPKey loads a mapped structure +func (h *SQLHelper) GetByPKey(obj Mapped, value interface{}) error { + return h.GetBy(obj, obj.PKeyColumn(), value) +} + +// GetBy ... +func (h *SQLHelper) GetBy(obj Mapped, column string, value interface{}) error { + return h.GetWhere(obj, sq.Eq{column: value}) +} + +// GetWhere loads a mapped structure +func (h *SQLHelper) GetWhere(obj Mapped, pred interface{}, args ...interface{}) error { + query := SQ. + Select(obj.Columns(true)...). + From(obj.Table()). + Where(pred, args...) + + return h.Get(obj, query) +} + +// Select loads a structure list +func (h *SQLHelper) Select(obj interface{}, query sq.Sqlizer) error { + return SelectContext(h.ctx, h.sqle, obj, query, &h.log) +} + +// Exec executes a query +func (h *SQLHelper) Exec(query sq.Sqlizer) (sql.Result, error) { + return ExecContext(h.ctx, h.sqle, query, &h.log) +} + +// Insert inserts a Mapped into the db +func (h *SQLHelper) Insert(instances ...Mapped) (sql.Result, error) { + query := SQLInsert(instances...) + return h.Exec(query) +} + +// Upsert upserts a Mapped into the db +func (h *SQLHelper) Upsert(instances ...Mapped) error { + for _, instance := range instances { + query := SQLUpsert(instance) + if _, err := h.Exec(query); err != nil { + return err + } + } + return nil +} diff --git a/database/test.go b/database/test.go new file mode 100644 --- /dev/null +++ b/database/test.go @@ -0,0 +1,126 @@ +package database + +import ( + "context" + "database/sql" + "os" + "testing" + + "github.com/golang-migrate/migrate/v4/source" + "github.com/jmoiron/sqlx" +) + +var ( + dbtablenames []string +) + +const dbLockID = 15104 + +func clearDB(t *testing.T, dsn string) { + // Drop all tables, and reinit + db, err := Open(dsn, 0) + if err != nil { + t.Fatal(err) + } + defer db.Close() + if _, err := db.Exec("DROP SCHEMA IF EXISTS public CASCADE"); err != nil { + t.Fatal(err) + } + if _, err := db.Exec("CREATE SCHEMA public"); err != nil { + t.Fatal(err) + } +} + +// TestDB is a sqlx.DB wrapper that release a global lock on close +type TestDB struct { + *sqlx.DB + tb testing.TB + lockConn *sql.Conn +} + +// GetTestDB creates a db and returns it. It must be closed within the test. +// If it fails, t.Fatal() is called +func GetTestDB(t *testing.T, sourceDriver source.Driver) *TestDB { + dsn := os.Getenv("TEST_DB_DSN") + if dsn == "" { + t.Fatal("Please define a TEST_DB_DSN environment variable") + } + if len(dbtablenames) == 0 { + clearDB(t, dsn) + + m, err := NewMigrate(dsn, sourceDriver) + if err != nil { + t.Fatal(err) + } + if err := m.Up(); err != nil { + t.Fatal(err) + } + if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil { + t.Fatal(srcErr, dbErr) + } + } + db, err := Open(dsn, 0) + if err != nil { + t.Fatal(err) + } + if len(dbtablenames) == 0 { + if err := db.Select(&dbtablenames, + "SELECT tablename "+ + "FROM pg_catalog.pg_tables "+ + "WHERE schemaname = 'public'", + ); err != nil { + t.Fatal(err) + } + } else { + for _, name := range dbtablenames { + if _, err := db.Exec( + "ALTER TABLE " + name + " DISABLE TRIGGER ALL", + ); err != nil { + t.Fatal(err) + return nil + } + } + for _, name := range dbtablenames { + if _, err := db.Exec( + "DELETE FROM " + name, //nolint:gosec + ); err != nil { + t.Fatal(err) + return nil + } + } + for _, name := range dbtablenames { + if _, err := db.Exec( + "ALTER TABLE " + name + " ENABLE TRIGGER ALL", + ); err != nil { + t.Fatal(err) + return nil + } + } + } + + c, err := db.Conn(context.Background()) + + if err != nil { + _ = db.Close() + t.Fatal(err) + } + + if _, err := c.ExecContext( + context.Background(), + "SELECT pg_advisory_lock($1)", dbLockID, + ); err != nil { + _ = db.Close() + t.Fatal(err) + } + + return &TestDB{db, t, c} +} + +// Close the lock connection, then the database +func (db *TestDB) Close() error { + if err := db.lockConn.Close(); err != nil { + db.tb.Error("lockConn.Close() failed", err) + } + + return db.DB.Close() +} diff --git a/database/tx.go b/database/tx.go new file mode 100644 --- /dev/null +++ b/database/tx.go @@ -0,0 +1,69 @@ +package database + +import ( + "context" + "fmt" + + "github.com/jmoiron/sqlx" + "github.com/rs/zerolog" +) + +// TxState is a enum that describes the states of a transaction +type TxState int + +const ( + // TxOpened means the transaction begun and is not yet closed + TxOpened TxState = iota + // TxCommitted means the transaction was committed (maybe with an error) + TxCommitted + // TxRolledback means the transaction was rolled back (maybe with an error) + TxRolledback +) + +// Begin begins a transaction and returns a *database.Tx +func Begin(ctx context.Context, db *sqlx.DB) (*Tx, error) { + tx, err := db.BeginTxx(ctx, nil) + if err != nil { + return nil, err + } + return &Tx{tx, TxOpened}, nil +} + +// Tx wraps a sqlx.Tx +type Tx struct { + *sqlx.Tx + + state TxState +} + +// Commit commits the transaction +func (tx *Tx) Commit() error { + if tx.state != TxOpened { + return fmt.Errorf("transaction is not open") + } + tx.state = TxCommitted + return tx.Tx.Commit() +} + +// Rollback rollbacks the transaction +func (tx *Tx) Rollback() error { + if tx.state != TxOpened { + return fmt.Errorf("transaction is not open") + } + tx.state = TxRolledback + return tx.Tx.Rollback() +} + +// LoggedRollback rollbacks the transaction and log the error if it fails +func (tx *Tx) LoggedRollback(log zerolog.Logger) { + if err := tx.Rollback(); err != nil { + log.Err(err).Msg("rollback failed") + } +} + +// RollbackIfOpened rollbacks the transaction if it is still opened +func (tx *Tx) RollbackIfOpened(log zerolog.Logger) { + if tx.state == TxOpened { + tx.LoggedRollback(log) + } +}