Newer
Older
package database
import (
"context"
"testing"
"github.com/golang-migrate/migrate/v4/source"
"github.com/jmoiron/sqlx"
var dbtablenames []string
dsn := os.Getenv("TEST_DB_DSN")
if dsn == "" {
tb.Fatal("Please define a TEST_DB_DSN environment variable")
}
func clearDB(ctx context.Context, tb testing.TB, c *sqlx.Conn) {
tb.Helper()
// Drop all tables, and reinit
if _, err := c.ExecContext(ctx, "DROP SCHEMA IF EXISTS public CASCADE"); err != nil {
if _, err := c.ExecContext(ctx, "CREATE SCHEMA public"); err != nil {
// TestDB is a sqlx.DB wrapper that release a global lock on close.
type TestDB struct {
*sqlx.DB
// SetTB changes the current tb and returns a function to restore the original one.
//
// db.tb is the testing.TB through which TestDB's helpers report failures
// (via Fatal/Error). Swapping it lets a shared TestDB report into a subtest's
// TB instead of the parent's. The returned closure puts the previously
// active TB back, so the intended call pattern is:
//
//	defer db.SetTB(t)()
func (db *TestDB) SetTB(tb testing.TB) func() {
return func() {
// Restore the TB that was active before SetTB was called.
db.tb = save
}
}
// ClearDB empty the database.
tx, err := db.lockConn.BeginTxx(db.ctx, nil)
defer func() {
require.NoError(db.tb, tx.Commit())
}()
db.DisableConstraints(tx.Exec)
defer db.EnableConstraints(tx.Exec)
for _, name := range dbtablenames {
if _, err := tx.Exec(
); err != nil {
db.tb.Fatal(err)
}
}
var alltables []string
require.NoError(db.tb,
tx.Select(&alltables,
"SELECT tablename "+
"FROM pg_catalog.pg_tables "+
"WHERE schemaname = 'public' "+
"AND tablename <> 'schema_migrations';",
))
alltables = slices.DeleteFunc(alltables, func(name string) bool {
return slices.Contains(dbtablenames, name)
})
for _, name := range alltables {
_, err := tx.Exec("DROP TABLE " + name)
require.NoError(db.tb, err)
func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB {
db, err := Open(dsn, 0)
require.NoError(tb, err)
defer func() {
if !success {
if err := db.Close(); err != nil {
tb.Log("Error closing db:", err)
}
}
}()
c, err := db.Connx(ctx)
require.NoError(tb, err)
defer func() {
if !success {
if err := c.Close(); err != nil {
tb.Log("Error closing conn:", err)
}
}
}()
if _, err := c.ExecContext(
ctx,
"SELECT pg_advisory_lock($1)", dbLockID,
); err != nil {
tb.Fatal(err)
}
// GetTestDB creates a db and returns it. It must be closed within the test.
// If it fails, t.Fatal() is called.
func GetTestDB(ctx context.Context, tb testing.TB, sourceDriver source.Driver) *TestDB {
tb.Helper()
db := GetTestDBNoClear(ctx, tb, sourceDriver)
func GetTestDBNoClear(ctx context.Context, tb testing.TB, sourceDriver source.Driver) *TestDB {
tb.Helper()
var success bool
db, err := Open(dsn, 0)
if err != nil {
defer func() {
if !success {
if ctx.Err() == nil {
if err := db.Close(); err != nil {
}
}
}()
defer func() {
if !success {
if err := c.Close(); err != nil {
}
}
}()
"SELECT pg_advisory_lock($1)", dbLockID,
); err != nil {
m, err := NewMigrate(dsn, sourceDriver)
if err != nil {
}
if err := m.Up(); err != nil {
}
if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil {
if len(dbtablenames) == 0 {
if err := db.Select(&dbtablenames,
"SELECT tablename "+
"FROM pg_catalog.pg_tables "+
"WHERE schemaname = 'public' "+
"AND tablename <> 'schema_migrations';",
success = true
func (db *TestDB) DisableConstraints(exec func(string, ...any) (sql.Result, error)) {
"ALTER TABLE " + name + " DISABLE TRIGGER ALL",
); err != nil {
db.tb.Fatal(err)
}
}
}
func (db *TestDB) EnableConstraints(exec func(string, ...any) (sql.Result, error)) {
"ALTER TABLE " + name + " ENABLE TRIGGER ALL",
); err != nil {
db.tb.Fatal(err)
}
}
}
// Close the lock connection, then the database.
if db.lockConn != nil {
if err := db.lockConn.Close(); err != nil {
db.tb.Error("lockConn.Close() failed", err)
}
db.lockConn = nil
if db.DB != nil {
if err := db.DB.Close(); err != nil {
db.tb.Error("db.Close() failed", err)
}
db.DB = nil