diff --git a/database/test.go b/database/test.go index 1a687f2d452daab915c2d2e9322060e7d83d2f51_ZGF0YWJhc2UvdGVzdC5nbw==..cbb304751ba80580514ead1e2ecc488e482213af_ZGF0YWJhc2UvdGVzdC5nbw== 100644 --- a/database/test.go +++ b/database/test.go @@ -16,7 +16,7 @@ const dbLockID = 15104 -func GetDSN(tb testing.TB) string { +func getDSN(tb testing.TB) string { dsn := os.Getenv("TEST_DB_DSN") if dsn == "" { tb.Fatal("Please define a TEST_DB_DSN environment variable") @@ -37,6 +37,7 @@ // TestDB is a sqlx.DB wrapper that release a global lock on close type TestDB struct { *sqlx.DB + DSN string ctx context.Context tb testing.TB lockConn *sqlx.Conn @@ -82,7 +83,7 @@ func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB { var success bool - dsn := GetDSN(tb) + dsn := getDSN(tb) db, err := Open(dsn, 0) require.NoError(tb, err) @@ -98,5 +99,20 @@ c, err := db.Connx(ctx) require.NoError(tb, err) + defer func() { + if !success { + if err := c.Close(); err != nil { + tb.Log("Error closing conn:", err) + } + } + }() + + if _, err := c.ExecContext( + ctx, + "SELECT pg_advisory_lock($1)", dbLockID, + ); err != nil { + tb.Fatal(err) + } + success = true @@ -101,6 +117,6 @@ success = true - return &TestDB{db, ctx, tb, c} + return &TestDB{db, dsn, ctx, tb, c} } // GetTestDB creates a db and returns it. It must be closed within the test. @@ -108,7 +124,7 @@ func GetTestDB(ctx context.Context, t testing.TB, sourceDriver source.Driver) *TestDB { var success bool - dsn := GetDSN(t) + dsn := getDSN(t) db, err := Open(dsn, 0) if err != nil { @@ -117,8 +133,10 @@ defer func() { if !success { - if err := db.Close(); err != nil { - t.Log("Error closing db:", err) + if ctx.Err() == nil { + if err := db.Close(); err != nil { + t.Log("Error closing db:", err) + } } } }() @@ -157,6 +175,8 @@ t.Fatal(srcErr, dbErr) } } + + testDB := TestDB{db, dsn, ctx, t, c} if len(dbtablenames) == 0 { if err := db.Select(&dbtablenames, "SELECT tablename "+ @@ -166,30 +186,7 @@ t.Fatal(err) } } else { - for _, name := range dbtablenames { - if _, err := db.Exec( - "ALTER TABLE " + name + " DISABLE TRIGGER ALL", - ); err != nil { - t.Fatal(err) - return nil - } - } - for _, name := range dbtablenames { - if _, err := db.Exec( - "DELETE FROM " + name, //nolint:gosec - ); err != nil { - t.Fatal(err) - return nil - } - } - for _, name := range dbtablenames { - if _, err := db.Exec( - "ALTER TABLE " + name + " ENABLE TRIGGER ALL", - ); err != nil { - t.Fatal(err) - return nil - } - } + testDB.ClearDB() } success = true @@ -193,8 +190,40 @@ } success = true - return &TestDB{db, ctx, t, c} + return &testDB +} + +func (db *TestDB) DisableConstraints() { + for _, name := range dbtablenames { + if _, err := db.Exec( + "ALTER TABLE " + name + " DISABLE TRIGGER ALL", + ); err != nil { + db.tb.Fatal(err) + } + } +} + +func (db *TestDB) EnableConstraints() { + for _, name := range dbtablenames { + if _, err := db.Exec( + "ALTER TABLE " + name + " ENABLE TRIGGER ALL", + ); err != nil { + db.tb.Fatal(err) + } + } +} + +func (db *TestDB) ClearDB() { + db.DisableConstraints() + defer db.EnableConstraints() + for _, name := range dbtablenames { + if _, err := db.Exec( + "DELETE FROM " + name, //nolint:gosec + ); err != nil { + db.tb.Fatal(err) + } + } } // Close the lock connection, then the database func (db *TestDB) Close() { @@ -197,8 +226,11 @@ } // Close the lock connection, then the database func (db *TestDB) Close() { - if err := db.lockConn.Close(); err != nil { - db.tb.Error("lockConn.Close() failed", err) + if db.lockConn != nil { + if err := db.lockConn.Close(); err != nil { + db.tb.Error("lockConn.Close() failed", err) + } + db.lockConn = nil } @@ -203,6 +235,9 @@ } - if err := db.DB.Close(); err != nil { - db.tb.Error("db.Close() failed", err) + if db.DB != nil { + if err := db.DB.Close(); err != nil { + db.tb.Error("db.Close() failed", err) + } + db.DB = nil } }