diff --git a/database/test.go b/database/test.go index 7209cf17ed9ac7fa5c8aff8dedfab1b9445a7447_ZGF0YWJhc2UvdGVzdC5nbw==..c357c7fe204244566f1af396050f9581c2a7a83a_ZGF0YWJhc2UvdGVzdC5nbw== 100644 --- a/database/test.go +++ b/database/test.go @@ -2,4 +2,5 @@ import ( "context" + "database/sql" "os" @@ -5,4 +6,5 @@ "os" + "slices" "testing" "github.com/golang-migrate/migrate/v4/source" @@ -51,7 +53,7 @@ } } -// Clear empty the database -func (db *TestDB) Clear() { +// ClearDB empty the database +func (db *TestDB) ClearDB() { tx, err := db.lockConn.BeginTxx(db.ctx, nil) require.NoError(db.tb, err) @@ -56,12 +58,13 @@ tx, err := db.lockConn.BeginTxx(db.ctx, nil) require.NoError(db.tb, err) - for _, name := range dbtablenames { - if _, err := tx.Exec( - "ALTER TABLE " + name + " DISABLE TRIGGER ALL", - ); err != nil { - db.tb.Fatal(err) - } - } + + defer func() { + require.NoError(db.tb, tx.Commit()) + }() + + db.DisableConstraints(tx.Exec) + defer db.EnableConstraints(tx.Exec) + for _, name := range dbtablenames { if _, err := tx.Exec( "DELETE FROM " + name, //nolint:gosec @@ -69,10 +72,19 @@ db.tb.Fatal(err) } } - for _, name := range dbtablenames { - if _, err := tx.Exec( - "ALTER TABLE " + name + " ENABLE TRIGGER ALL", - ); err != nil { - db.tb.Fatal(err) - } + + var alltables []string + require.NoError(db.tb, + tx.Select(&alltables, + "SELECT tablename "+ + "FROM pg_catalog.pg_tables "+ + "WHERE schemaname = 'public' "+ + "AND tablename <> 'schema_migrations';", + )) + alltables = slices.DeleteFunc(alltables, func(name string) bool { + return slices.Contains(dbtablenames, name) + }) + for _, name := range alltables { + _, err := tx.Exec("DROP TABLE " + name) + require.NoError(db.tb, err) } @@ -78,5 +90,4 @@ } - require.NoError(db.tb, tx.Commit()) } func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB { @@ -197,5 +208,5 @@ return &testDB } -func (db *TestDB) DisableConstraints() { +func (db *TestDB) DisableConstraints(exec func(string, ...any) (sql.Result, error)) { for _, name := range dbtablenames { @@ -201,5 +212,5 @@ for _, name := range dbtablenames { - if _, err := db.Exec( + if _, err := exec( "ALTER TABLE " + name + " DISABLE TRIGGER ALL", ); err != nil { db.tb.Fatal(err) @@ -207,5 +218,5 @@ } } -func (db *TestDB) EnableConstraints() { +func (db *TestDB) EnableConstraints(exec func(string, ...any) (sql.Result, error)) { for _, name := range dbtablenames { @@ -211,5 +222,5 @@ for _, name := range dbtablenames { - if _, err := db.Exec( + if _, err := exec( "ALTER TABLE " + name + " ENABLE TRIGGER ALL", ); err != nil { db.tb.Fatal(err) @@ -217,18 +228,6 @@ } } -func (db *TestDB) ClearDB() { - db.DisableConstraints() - defer db.EnableConstraints() - for _, name := range dbtablenames { - if _, err := db.Exec( - "DELETE FROM " + name, //nolint:gosec - ); err != nil { - db.tb.Fatal(err) - } - } -} - // Close the lock connection, then the database func (db *TestDB) Close() { if db.lockConn != nil {