Newer
Older
package database
import (
"context"
"os"
"testing"
"github.com/golang-migrate/migrate/v4/source"
"github.com/jmoiron/sqlx"
var dbtablenames []string
dsn := os.Getenv("TEST_DB_DSN")
if dsn == "" {
tb.Fatal("Please define a TEST_DB_DSN environment variable")
}
return dsn
}
//nolint:revive
func clearDB(ctx context.Context, t testing.TB, c *sqlx.Conn, dsn string) {
// Drop all tables, and reinit
if _, err := c.ExecContext(ctx, "DROP SCHEMA IF EXISTS public CASCADE"); err != nil {
if _, err := c.ExecContext(ctx, "CREATE SCHEMA public"); err != nil {
t.Fatal(err)
}
}
// TestDB is a sqlx.DB wrapper that releases a global lock on close.
type TestDB struct {
*sqlx.DB
// SetTB installs tb as the testing.TB used by this TestDB and returns a
// closure that puts the previously installed one back when called.
func (db *TestDB) SetTB(tb testing.TB) func() {
	previous := db.tb
	restore := func() {
		db.tb = previous
	}

	db.tb = tb

	return restore
}
// Clear empties every table listed in dbtablenames.
//
// All statements run in a single transaction opened on the lock connection:
// triggers are disabled on every table, the rows are deleted, then triggers
// are re-enabled, and the transaction is committed. Any failure is reported
// through db.tb (require.NoError / Fatal); in that case the transaction is
// rolled back instead of being left open on the lock connection.
func (db *TestDB) Clear() {
	tx, err := db.lockConn.BeginTxx(db.ctx, nil)
	require.NoError(db.tb, err)

	// Fatal/FailNow on a *testing.T or *testing.B stop the test with
	// runtime.Goexit, which still runs deferred calls, so this rollback fires
	// on every failure path. After a successful Commit it only returns
	// sql.ErrTxDone, which is deliberately ignored.
	defer func() {
		_ = tx.Rollback()
	}()

	// execOnAllTables runs prefix+<table>+suffix for each known table and
	// fails the test on the first error.
	execOnAllTables := func(prefix, suffix string) {
		for _, name := range dbtablenames {
			if _, err := tx.Exec(
				prefix + name + suffix, //nolint:gosec
			); err != nil {
				db.tb.Fatal(err)
			}
		}
	}

	execOnAllTables("ALTER TABLE ", " DISABLE TRIGGER ALL")
	execOnAllTables("DELETE FROM ", "")
	execOnAllTables("ALTER TABLE ", " ENABLE TRIGGER ALL")

	require.NoError(db.tb, tx.Commit())
}
func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB {
var success bool
db, err := Open(dsn, 0)
require.NoError(tb, err)
defer func() {
if !success {
if err := db.Close(); err != nil {
tb.Log("Error closing db:", err)
}
}
}()
c, err := db.Connx(ctx)
require.NoError(tb, err)
defer func() {
if !success {
if err := c.Close(); err != nil {
tb.Log("Error closing conn:", err)
}
}
}()
if _, err := c.ExecContext(
ctx,
"SELECT pg_advisory_lock($1)", dbLockID,
); err != nil {
tb.Fatal(err)
}
// GetTestDB obtains a test database through GetTestDBNoClear, wipes its
// tables with ClearDB, and returns it.
// The caller must close the returned TestDB within the test.
// If it fails, t.Fatal() is called.
func GetTestDB(ctx context.Context, t testing.TB, sourceDriver source.Driver) *TestDB {
	testDB := GetTestDBNoClear(ctx, t, sourceDriver)
	testDB.ClearDB()

	return testDB
}
func GetTestDBNoClear(ctx context.Context, t testing.TB, sourceDriver source.Driver) *TestDB {
var success bool
db, err := Open(dsn, 0)
if err != nil {
t.Fatal(err)
}
defer func() {
if !success {
if ctx.Err() == nil {
if err := db.Close(); err != nil {
t.Log("Error closing db:", err)
}
}
}
}()
if err != nil {
t.Fatal(err)
}
defer func() {
if !success {
if err := c.Close(); err != nil {
t.Log("Error closing conn:", err)
}
}
}()
"SELECT pg_advisory_lock($1)", dbLockID,
); err != nil {
t.Fatal(err)
}
m, err := NewMigrate(dsn, sourceDriver)
if err != nil {
t.Fatal(err)
}
if err := m.Up(); err != nil {
t.Fatal(err)
}
if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil {
t.Fatal(srcErr, dbErr)
}
}
testDB := TestDB{db, dsn, ctx, t, c}
if len(dbtablenames) == 0 {
if err := db.Select(&dbtablenames,
"SELECT tablename "+
"FROM pg_catalog.pg_tables "+
"WHERE schemaname = 'public' "+
"AND tablename <> 'schema_migrations';",
); err != nil {
t.Fatal(err)
}
}
success = true
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
return &testDB
}
// DisableConstraints issues "ALTER TABLE ... DISABLE TRIGGER ALL" for every
// table listed in dbtablenames, calling Fatal on the current tb if a
// statement fails.
func (db *TestDB) DisableConstraints() {
	for _, table := range dbtablenames {
		stmt := "ALTER TABLE " + table + " DISABLE TRIGGER ALL"

		_, err := db.Exec(stmt)
		if err != nil {
			db.tb.Fatal(err)
		}
	}
}
// EnableConstraints issues "ALTER TABLE ... ENABLE TRIGGER ALL" for every
// table listed in dbtablenames, calling Fatal on the current tb if a
// statement fails.
func (db *TestDB) EnableConstraints() {
	for _, table := range dbtablenames {
		stmt := "ALTER TABLE " + table + " ENABLE TRIGGER ALL"

		_, err := db.Exec(stmt)
		if err != nil {
			db.tb.Fatal(err)
		}
	}
}
// ClearDB deletes all rows from every table listed in dbtablenames.
// Triggers are switched off with DisableConstraints beforehand and switched
// back on with EnableConstraints when the method returns. A failing DELETE is
// reported with Fatal on the current tb.
func (db *TestDB) ClearDB() {
	db.DisableConstraints()
	defer db.EnableConstraints()

	for _, table := range dbtablenames {
		query := "DELETE FROM " + table //nolint:gosec

		_, err := db.Exec(query)
		if err != nil {
			db.tb.Fatal(err)
		}
	}
}
// Close the lock connection, then the database
if db.lockConn != nil {
if err := db.lockConn.Close(); err != nil {
db.tb.Error("lockConn.Close() failed", err)
}
db.lockConn = nil
if db.DB != nil {
if err := db.DB.Close(); err != nil {
db.tb.Error("db.Close() failed", err)
}
db.DB = nil