package database import ( "context" "os" "testing" "github.com/golang-migrate/migrate/v4/source" "github.com/jmoiron/sqlx" "github.com/stretchr/testify/require" ) var ( dbtablenames []string ) const dbLockID = 15104 func getDSN(tb testing.TB) string { dsn := os.Getenv("TEST_DB_DSN") if dsn == "" { tb.Fatal("Please define a TEST_DB_DSN environment variable") } return dsn } //nolint:revive func clearDB(ctx context.Context, t testing.TB, c *sqlx.Conn, dsn string) { // Drop all tables, and reinit if _, err := c.ExecContext(ctx, "DROP SCHEMA IF EXISTS public CASCADE"); err != nil { t.Fatal(err) } if _, err := c.ExecContext(ctx, "CREATE SCHEMA public"); err != nil { t.Fatal(err) } } // TestDB is a sqlx.DB wrapper that release a global lock on close type TestDB struct { *sqlx.DB DSN string ctx context.Context tb testing.TB lockConn *sqlx.Conn } // SetTB changes the current tb and returns a function to restore the original one func (db *TestDB) SetTB(tb testing.TB) func() { save := db.tb db.tb = tb return func() { db.tb = save } } // Clear empty the database func (db *TestDB) Clear() { tx, err := db.lockConn.BeginTxx(db.ctx, nil) require.NoError(db.tb, err) for _, name := range dbtablenames { if _, err := tx.Exec( "ALTER TABLE " + name + " DISABLE TRIGGER ALL", ); err != nil { db.tb.Fatal(err) } } for _, name := range dbtablenames { if _, err := tx.Exec( "DELETE FROM " + name, //nolint:gosec ); err != nil { db.tb.Fatal(err) } } for _, name := range dbtablenames { if _, err := tx.Exec( "ALTER TABLE " + name + " ENABLE TRIGGER ALL", ); err != nil { db.tb.Fatal(err) } } require.NoError(db.tb, tx.Commit()) } func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB { var success bool dsn := getDSN(tb) db, err := Open(dsn, 0) require.NoError(tb, err) defer func() { if !success { if err := db.Close(); err != nil { tb.Log("Error closing db:", err) } } }() c, err := db.Connx(ctx) require.NoError(tb, err) defer func() { if !success { if err := c.Close(); err != nil { tb.Log("Error closing conn:", err) } } }() if _, err := c.ExecContext( ctx, "SELECT pg_advisory_lock($1)", dbLockID, ); err != nil { tb.Fatal(err) } success = true return &TestDB{db, dsn, ctx, tb, c} } // GetTestDB creates a db and returns it. It must be closed within the test. // If it fails, t.Fatal() is called func GetTestDB(ctx context.Context, t testing.TB, sourceDriver source.Driver) *TestDB { var success bool dsn := getDSN(t) db, err := Open(dsn, 0) if err != nil { t.Fatal(err) } defer func() { if !success { if ctx.Err() == nil { if err := db.Close(); err != nil { t.Log("Error closing db:", err) } } } }() c, err := db.Connx(ctx) if err != nil { t.Fatal(err) } defer func() { if !success { if err := c.Close(); err != nil { t.Log("Error closing conn:", err) } } }() if _, err := c.ExecContext( ctx, "SELECT pg_advisory_lock($1)", dbLockID, ); err != nil { t.Fatal(err) } if len(dbtablenames) == 0 { clearDB(ctx, t, c, dsn) m, err := NewMigrate(dsn, sourceDriver) if err != nil { t.Fatal(err) } if err := m.Up(); err != nil { t.Fatal(err) } if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil { t.Fatal(srcErr, dbErr) } } testDB := TestDB{db, dsn, ctx, t, c} if len(dbtablenames) == 0 { if err := db.Select(&dbtablenames, "SELECT tablename "+ "FROM pg_catalog.pg_tables "+ "WHERE schemaname = 'public'", ); err != nil { t.Fatal(err) } } else { testDB.ClearDB() } success = true return &testDB } func (db *TestDB) DisableConstraints() { for _, name := range dbtablenames { if _, err := db.Exec( "ALTER TABLE " + name + " DISABLE TRIGGER ALL", ); err != nil { db.tb.Fatal(err) } } } func (db *TestDB) EnableConstraints() { for _, name := range dbtablenames { if _, err := db.Exec( "ALTER TABLE " + name + " ENABLE TRIGGER ALL", ); err != nil { db.tb.Fatal(err) } } } func (db *TestDB) ClearDB() { db.DisableConstraints() defer db.EnableConstraints() for _, name := range dbtablenames { if _, err := db.Exec( "DELETE FROM " + name, //nolint:gosec ); err != nil { db.tb.Fatal(err) } } } // Close the lock connection, then the database func (db *TestDB) Close() { if db.lockConn != nil { if err := db.lockConn.Close(); err != nil { db.tb.Error("lockConn.Close() failed", err) } db.lockConn = nil } if db.DB != nil { if err := db.DB.Close(); err != nil { db.tb.Error("db.Close() failed", err) } db.DB = nil } }