Newer
Older
package database
import (
"context"
"testing"
"github.com/golang-migrate/migrate/v4/source"
"github.com/jmoiron/sqlx"
// dbtablenames caches the names of the public-schema tables (excluding
// schema_migrations). GetTestDBNoClear fills it, after running migrations,
// whenever it is still empty. ClearDB deletes all rows from these tables and
// drops every other public table. DisableConstraints/EnableConstraints
// presumably iterate over it as well (their loop headers are not shown here).
// NOTE(review): this is package-level mutable state. Concurrent access appears
// to be serialized only by the pg_advisory_lock taken before it is written —
// confirm before running these tests with t.Parallel.
var dbtablenames []string
dsn := os.Getenv("TEST_DB_DSN")
if dsn == "" {
tb.Fatal("Please define a TEST_DB_DSN environment variable")
}
return dsn
}
// clearDB resets the database reachable through c. It drops the public schema
// with CASCADE, which removes every object in it, and then recreates the schema
// empty.
// NOTE(review): t and dsn are not referenced in the lines visible here —
// confirm they are still needed.
func clearDB(ctx context.Context, t testing.TB, c *sqlx.Conn, dsn string) {
// Drop the public schema and everything in it, then recreate it empty.
if _, err := c.ExecContext(ctx, "DROP SCHEMA IF EXISTS public CASCADE"); err != nil {
if _, err := c.ExecContext(ctx, "CREATE SCHEMA public"); err != nil {
// TestDB is a sqlx.DB wrapper that releases a global lock on close.
type TestDB struct {
*sqlx.DB
// SetTB replaces the testing.TB that db reports failures to. It returns a
// closure that reinstates the previously installed one, so callers can write:
//
//	defer db.SetTB(t)()
func (db *TestDB) SetTB(tb testing.TB) func() {
	previous := db.tb
	db.tb = tb
	restore := func() {
		db.tb = previous
	}
	return restore
}
// ClearDB empties the database: it deletes all rows from the known tables and drops any other public table.
tx, err := db.lockConn.BeginTxx(db.ctx, nil)
defer func() {
require.NoError(db.tb, tx.Commit())
}()
db.DisableConstraints(tx.Exec)
defer db.EnableConstraints(tx.Exec)
for _, name := range dbtablenames {
if _, err := tx.Exec(
"DELETE FROM " + name,
); err != nil {
db.tb.Fatal(err)
}
}
var alltables []string
require.NoError(db.tb,
tx.Select(&alltables,
"SELECT tablename "+
"FROM pg_catalog.pg_tables "+
"WHERE schemaname = 'public' "+
"AND tablename <> 'schema_migrations';",
))
alltables = slices.DeleteFunc(alltables, func(name string) bool {
return slices.Contains(dbtablenames, name)
})
for _, name := range alltables {
_, err := tx.Exec("DROP TABLE " + name)
require.NoError(db.tb, err)
// GetTestDBNoInit opens the test database and acquires the global PostgreSQL
// advisory lock (dbLockID) on a dedicated connection. pg_advisory_lock blocks
// until any other session holding the lock releases it.
// Setup failures call tb.Fatal (directly or through require.NoError). The
// deferred cleanups then close the connection and the database, because
// success is still false at that point.
// NOTE(review): dsn is presumably the TEST_DB_DSN value. Unlike
// GetTestDBNoClear, this variant does not visibly run migrations. The rest of
// the body is not shown here, so confirm both points against the full source.
func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB {
var success bool
db, err := Open(dsn, 0)
require.NoError(tb, err)
defer func() {
if !success {
if err := db.Close(); err != nil {
tb.Log("Error closing db:", err)
}
}
}()
c, err := db.Connx(ctx)
require.NoError(tb, err)
defer func() {
if !success {
if err := c.Close(); err != nil {
tb.Log("Error closing conn:", err)
}
}
}()
// Session-level lock held on c. Closing c releases it.
if _, err := c.ExecContext(
ctx,
"SELECT pg_advisory_lock($1)", dbLockID,
); err != nil {
tb.Fatal(err)
}
// GetTestDB returns a migrated, locked test database whose tables have been
// emptied by ClearDB. The caller must Close it before the test ends.
// Any setup failure aborts the test via t.Fatal.
func GetTestDB(ctx context.Context, t testing.TB, sourceDriver source.Driver) *TestDB {
	cleared := GetTestDBNoClear(ctx, t, sourceDriver)
	cleared.ClearDB()
	return cleared
}
// GetTestDBNoClear is like GetTestDB but leaves existing rows in place. It:
//   - opens the database;
//   - takes the global pg_advisory_lock(dbLockID) on a dedicated connection,
//     blocking until other holders release it;
//   - runs the migrations from sourceDriver (NewMigrate + Up);
//   - records the public table names in dbtablenames if that cache is still
//     empty.
//
// Setup failures call t.Fatal. The deferred cleanups then close the
// connection and, if ctx has not been cancelled, the database.
// NOTE(review): golang-migrate's Up returns migrate.ErrNoChange when there is
// nothing to apply, and this code treats that as fatal. Confirm the schema is
// always reset before migrating (e.g. via clearDB in the lines not shown here).
func GetTestDBNoClear(ctx context.Context, t testing.TB, sourceDriver source.Driver) *TestDB {
var success bool
db, err := Open(dsn, 0)
if err != nil {
t.Fatal(err)
}
defer func() {
if !success {
if ctx.Err() == nil {
if err := db.Close(); err != nil {
t.Log("Error closing db:", err)
}
}
}
}()
if err != nil {
t.Fatal(err)
}
defer func() {
if !success {
if err := c.Close(); err != nil {
t.Log("Error closing conn:", err)
}
}
}()
"SELECT pg_advisory_lock($1)", dbLockID,
); err != nil {
t.Fatal(err)
}
m, err := NewMigrate(dsn, sourceDriver)
if err != nil {
t.Fatal(err)
}
if err := m.Up(); err != nil {
t.Fatal(err)
}
if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil {
t.Fatal(srcErr, dbErr)
}
}
testDB := TestDB{db, dsn, ctx, t, c}
// Snapshot the post-migration table list once. ClearDB later drops any
// table not in this list.
if len(dbtablenames) == 0 {
if err := db.Select(&dbtablenames,
"SELECT tablename "+
"FROM pg_catalog.pg_tables "+
"WHERE schemaname = 'public' "+
"AND tablename <> 'schema_migrations';",
); err != nil {
t.Fatal(err)
}
}
success = true
// DisableConstraints executes "ALTER TABLE <name> DISABLE TRIGGER ALL" through
// exec (for example tx.Exec) and calls db.tb.Fatal on the first error.
// In PostgreSQL, TRIGGER ALL also covers the internal triggers that enforce
// foreign keys, which is presumably why this is named "Constraints". Disabling
// those system triggers typically requires superuser privileges.
// NOTE(review): the loop header is not visible here. It presumably ranges over
// dbtablenames, whose names come from pg_catalog and are concatenated
// unquoted — confirm.
func (db *TestDB) DisableConstraints(exec func(string, ...any) (sql.Result, error)) {
"ALTER TABLE " + name + " DISABLE TRIGGER ALL",
); err != nil {
db.tb.Fatal(err)
}
}
}
// EnableConstraints is the inverse of DisableConstraints. It executes
// "ALTER TABLE <name> ENABLE TRIGGER ALL" through exec and calls db.tb.Fatal
// on the first error. Re-enabled foreign-key triggers do not re-check rows
// written while they were disabled.
// NOTE(review): the loop header is not visible here. It presumably ranges over
// dbtablenames — confirm.
func (db *TestDB) EnableConstraints(exec func(string, ...any) (sql.Result, error)) {
"ALTER TABLE " + name + " ENABLE TRIGGER ALL",
); err != nil {
db.tb.Fatal(err)
}
}
}
// Close the lock connection, then the database.
if db.lockConn != nil {
if err := db.lockConn.Close(); err != nil {
db.tb.Error("lockConn.Close() failed", err)
}
db.lockConn = nil
if db.DB != nil {
if err := db.DB.Close(); err != nil {
db.tb.Error("db.Close() failed", err)
}
db.DB = nil