package database import ( "context" "database/sql" "os" "slices" "testing" "github.com/golang-migrate/migrate/v4/source" "github.com/jmoiron/sqlx" "github.com/stretchr/testify/require" ) var dbtablenames []string const dbLockID = 15104 func getDSN(tb testing.TB) string { tb.Helper() dsn := os.Getenv("TEST_DB_DSN") if dsn == "" { tb.Fatal("Please define a TEST_DB_DSN environment variable") } return dsn } func clearDB(ctx context.Context, tb testing.TB, c *sqlx.Conn) { tb.Helper() // Drop all tables, and reinit if _, err := c.ExecContext(ctx, "DROP SCHEMA IF EXISTS public CASCADE"); err != nil { tb.Fatal(err) } if _, err := c.ExecContext(ctx, "CREATE SCHEMA public"); err != nil { tb.Fatal(err) } } // TestDB is a sqlx.DB wrapper that release a global lock on close. type TestDB struct { *sqlx.DB DSN string ctx context.Context tb testing.TB lockConn *sqlx.Conn } // SetTB changes the current tb and returns a function to restore the original one. func (db *TestDB) SetTB(tb testing.TB) func() { tb.Helper() save := db.tb db.tb = tb return func() { db.tb = save } } // ClearDB empty the database. func (db *TestDB) ClearDB() { tx, err := db.lockConn.BeginTxx(db.ctx, nil) require.NoError(db.tb, err) defer func() { require.NoError(db.tb, tx.Commit()) }() db.DisableConstraints(tx.Exec) defer db.EnableConstraints(tx.Exec) for _, name := range dbtablenames { if _, err := tx.Exec( "DELETE FROM " + name, ); err != nil { db.tb.Fatal(err) } } var alltables []string require.NoError(db.tb, tx.Select(&alltables, "SELECT tablename "+ "FROM pg_catalog.pg_tables "+ "WHERE schemaname = 'public' "+ "AND tablename <> 'schema_migrations';", )) alltables = slices.DeleteFunc(alltables, func(name string) bool { return slices.Contains(dbtablenames, name) }) for _, name := range alltables { _, err := tx.Exec("DROP TABLE " + name) require.NoError(db.tb, err) } } func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB { tb.Helper() var success bool dsn := getDSN(tb) db, err := Open(dsn, 0) require.NoError(tb, err) defer func() { if !success { if err := db.Close(); err != nil { tb.Log("Error closing db:", err) } } }() c, err := db.Connx(ctx) require.NoError(tb, err) defer func() { if !success { if err := c.Close(); err != nil { tb.Log("Error closing conn:", err) } } }() if _, err := c.ExecContext( ctx, "SELECT pg_advisory_lock($1)", dbLockID, ); err != nil { tb.Fatal(err) } success = true return &TestDB{db, dsn, ctx, tb, c} } // GetTestDB creates a db and returns it. It must be closed within the test. // If it fails, t.Fatal() is called. func GetTestDB(ctx context.Context, tb testing.TB, sourceDriver source.Driver) *TestDB { tb.Helper() db := GetTestDBNoClear(ctx, tb, sourceDriver) db.ClearDB() return db } func GetTestDBNoClear(ctx context.Context, tb testing.TB, sourceDriver source.Driver) *TestDB { tb.Helper() var success bool dsn := getDSN(tb) db, err := Open(dsn, 0) if err != nil { tb.Fatal(err) } defer func() { if !success { if ctx.Err() == nil { if err := db.Close(); err != nil { tb.Log("Error closing db:", err) } } } }() c, err := db.Connx(ctx) if err != nil { tb.Fatal(err) } defer func() { if !success { if err := c.Close(); err != nil { tb.Log("Error closing conn:", err) } } }() if _, err := c.ExecContext( ctx, "SELECT pg_advisory_lock($1)", dbLockID, ); err != nil { tb.Fatal(err) } if len(dbtablenames) == 0 { clearDB(ctx, tb, c) m, err := NewMigrate(dsn, sourceDriver) if err != nil { tb.Fatal(err) } if err := m.Up(); err != nil { tb.Fatal(err) } if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil { tb.Fatal(srcErr, dbErr) } } testDB := TestDB{db, dsn, ctx, tb, c} if len(dbtablenames) == 0 { if err := db.Select(&dbtablenames, "SELECT tablename "+ "FROM pg_catalog.pg_tables "+ "WHERE schemaname = 'public' "+ "AND tablename <> 'schema_migrations';", ); err != nil { tb.Fatal(err) } } success = true return &testDB } func (db *TestDB) DisableConstraints(exec func(string, ...any) (sql.Result, error)) { for _, name := range dbtablenames { if _, err := exec( "ALTER TABLE " + name + " DISABLE TRIGGER ALL", ); err != nil { db.tb.Fatal(err) } } } func (db *TestDB) EnableConstraints(exec func(string, ...any) (sql.Result, error)) { for _, name := range dbtablenames { if _, err := exec( "ALTER TABLE " + name + " ENABLE TRIGGER ALL", ); err != nil { db.tb.Fatal(err) } } } // Close the lock connection, then the database. func (db *TestDB) Close() { if db.lockConn != nil { if err := db.lockConn.Close(); err != nil { db.tb.Error("lockConn.Close() failed", err) } db.lockConn = nil } if db.DB != nil { if err := db.DB.Close(); err != nil { db.tb.Error("db.Close() failed", err) } db.DB = nil } }