package database

import (
	"context"
	"os"
	"strings"
	"testing"

	"github.com/golang-migrate/migrate/v4/source"
	"github.com/jmoiron/sqlx"
)

var (
	// dbtablenames caches the tables found in the public schema after the
	// first successful migration in this process. While it is empty,
	// GetTestDB rebuilds the schema from scratch; once populated, GetTestDB
	// only deletes the rows of these tables.
	dbtablenames []string
)

// dbLockID is the key passed to pg_advisory_lock so that only one caller at
// a time uses the shared test database.
const dbLockID = 15104

// clearDB drops the public schema (and everything in it) and recreates it
// empty, using connection c. It calls t.Fatal on any error.
// The dsn parameter is currently unused.
func clearDB(t *testing.T, c *sqlx.Conn, dsn string) {
	for _, stmt := range []string{
		"DROP SCHEMA IF EXISTS public CASCADE",
		"CREATE SCHEMA public",
	} {
		if _, err := c.ExecContext(context.Background(), stmt); err != nil {
			t.Fatal(err)
		}
	}
}

// TestDB is a sqlx.DB wrapper that holds the test-database advisory lock
// until Close is called.
type TestDB struct {
	*sqlx.DB
	tb       testing.TB
	lockConn *sqlx.Conn // connection whose session holds the advisory lock
}

// GetTestDB returns a reset test database. It must be closed (with Close)
// within the test, because it holds a lock that blocks every other caller
// until then.
//
// The DSN is read from the TEST_DB_DSN environment variable. The first call
// in a process drops and recreates the public schema, runs the migrations
// from sourceDriver and records the resulting table names; later calls only
// delete all rows from those tables. If anything fails, t.Fatal is called
// and the connection and pool opened so far are closed.
func GetTestDB(t *testing.T, sourceDriver source.Driver) *TestDB {
	dsn := os.Getenv("TEST_DB_DSN")
	if dsn == "" {
		t.Fatal("Please define a TEST_DB_DSN environment variable")
	}

	// success stays false until the very end; the deferred cleanups below
	// then close the connection and the pool when t.Fatal aborts early.
	var success bool

	db, err := Open(dsn, 0)
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		if !success {
			if err := db.Close(); err != nil {
				t.Log("Error closing db:", err)
			}
		}
	}()

	c, err := db.Connx(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		if !success {
			if err := c.Close(); err != nil {
				t.Log("Error closing conn:", err)
			}
		}
	}()

	// Serialize access to the test database. pg_advisory_lock is
	// session-level, so the lock stays held while c's session is open.
	if _, err := c.ExecContext(
		context.Background(),
		"SELECT pg_advisory_lock($1)", dbLockID,
	); err != nil {
		t.Fatal(err)
	}

	if len(dbtablenames) == 0 {
		// First use in this process: start from an empty schema and migrate.
		clearDB(t, c, dsn)
		migrateTestDB(t, dsn, sourceDriver)

		// Scan into a local slice so that a failed query cannot leave a
		// partially filled cache that would skip migrations next time.
		var names []string
		if err := db.Select(&names, "SELECT tablename "+
			"FROM pg_catalog.pg_tables "+
			"WHERE schemaname = 'public'",
		); err != nil {
			t.Fatal(err)
		}
		dbtablenames = names
	} else {
		emptyTestTables(t, db, dbtablenames)
	}

	success = true
	return &TestDB{DB: db, tb: t, lockConn: c}
}

// migrateTestDB applies the migrations from sourceDriver to dsn. The migrate
// instance is closed even when Up fails, and the Up error is the one
// reported. It calls t.Fatal on any error.
func migrateTestDB(t *testing.T, dsn string, sourceDriver source.Driver) {
	m, err := NewMigrate(dsn, sourceDriver)
	if err != nil {
		t.Fatal(err)
	}
	upErr := m.Up()
	srcErr, dbErr := m.Close()
	if upErr != nil {
		t.Fatal(upErr)
	}
	if srcErr != nil || dbErr != nil {
		t.Fatal(srcErr, dbErr)
	}
}

// emptyTestTables deletes every row from the given public-schema tables.
// All triggers on each table are disabled around the DELETEs
// (ALTER TABLE ... DISABLE TRIGGER ALL) so rows can be removed regardless of
// table order. Everything runs in one transaction, so if any statement fails
// the rollback also leaves the triggers enabled. It calls t.Fatal on any
// error.
func emptyTestTables(t *testing.T, db *sqlx.DB, names []string) {
	stmts := make([]string, 0, 3*len(names))
	for _, n := range names {
		stmts = append(stmts, "ALTER TABLE "+quoteTestTable(n)+" DISABLE TRIGGER ALL")
	}
	for _, n := range names {
		stmts = append(stmts, "DELETE FROM "+quoteTestTable(n)) //nolint:gosec // quoted name read from pg_catalog
	}
	for _, n := range names {
		stmts = append(stmts, "ALTER TABLE "+quoteTestTable(n)+" ENABLE TRIGGER ALL")
	}

	tx, err := db.Beginx()
	if err != nil {
		t.Fatal(err)
	}
	done := false
	defer func() {
		if !done {
			if err := tx.Rollback(); err != nil {
				t.Log("Error rolling back:", err)
			}
		}
	}()

	for _, stmt := range stmts {
		if _, err := tx.Exec(stmt); err != nil {
			t.Fatal(err)
		}
	}

	err = tx.Commit()
	// After Commit (successful or not) the transaction is finished, so the
	// deferred Rollback must not run.
	done = true
	if err != nil {
		t.Fatal(err)
	}
}

// quoteTestTable returns the public-schema-qualified, double-quoted
// identifier for table name. Quoting makes reserved words (e.g. "user"),
// upper-case letters and embedded quotes refer to the table verbatim.
func quoteTestTable(name string) string {
	return `public."` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// Close releases the test database. It first returns the lock connection to
// the pool, then closes the pool, which closes that connection's session and
// with it the advisory lock taken by GetTestDB. Errors are reported with
// tb.Error, so the test is marked failed but keeps running.
func (db *TestDB) Close() {
	if err := db.lockConn.Close(); err != nil {
		db.tb.Error("lockConn.Close() failed", err)
	}
	if err := db.DB.Close(); err != nil {
		db.tb.Error("db.Close() failed", err)
	}
}