# HG changeset patch # User Christophe de Vienne <christophe@cdevienne.info> # Date 1706007683 -3600 # Tue Jan 23 12:01:23 2024 +0100 # Node ID c357c7fe204244566f1af396050f9581c2a7a83a # Parent 7209cf17ed9ac7fa5c8aff8dedfab1b9445a7447 Cleanup the ClearDB api diff --git a/database/test.go b/database/test.go --- a/database/test.go +++ b/database/test.go @@ -2,7 +2,9 @@ import ( "context" + "database/sql" "os" + "slices" "testing" "github.com/golang-migrate/migrate/v4/source" @@ -51,17 +53,18 @@ } } -// Clear empty the database -func (db *TestDB) Clear() { +// ClearDB empty the database +func (db *TestDB) ClearDB() { tx, err := db.lockConn.BeginTxx(db.ctx, nil) require.NoError(db.tb, err) - for _, name := range dbtablenames { - if _, err := tx.Exec( - "ALTER TABLE " + name + " DISABLE TRIGGER ALL", - ); err != nil { - db.tb.Fatal(err) - } - } + + defer func() { + require.NoError(db.tb, tx.Commit()) + }() + + db.DisableConstraints(tx.Exec) + defer db.EnableConstraints(tx.Exec) + for _, name := range dbtablenames { if _, err := tx.Exec( "DELETE FROM " + name, //nolint:gosec @@ -69,14 +72,22 @@ db.tb.Fatal(err) } } - for _, name := range dbtablenames { - if _, err := tx.Exec( - "ALTER TABLE " + name + " ENABLE TRIGGER ALL", - ); err != nil { - db.tb.Fatal(err) - } + + var alltables []string + require.NoError(db.tb, + tx.Select(&alltables, + "SELECT tablename "+ + "FROM pg_catalog.pg_tables "+ + "WHERE schemaname = 'public' "+ + "AND tablename <> 'schema_migrations';", + )) + alltables = slices.DeleteFunc(alltables, func(name string) bool { + return slices.Contains(dbtablenames, name) + }) + for _, name := range alltables { + _, err := tx.Exec("DROP TABLE " + name) + require.NoError(db.tb, err) } - require.NoError(db.tb, tx.Commit()) } func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB { @@ -197,9 +208,9 @@ return &testDB } -func (db *TestDB) DisableConstraints() { +func (db *TestDB) DisableConstraints(exec func(string, ...any) (sql.Result, error)) { for _, name := range dbtablenames { - if _, err := db.Exec( + if _, err := exec( "ALTER TABLE " + name + " DISABLE TRIGGER ALL", ); err != nil { db.tb.Fatal(err) @@ -207,9 +218,9 @@ } } -func (db *TestDB) EnableConstraints() { +func (db *TestDB) EnableConstraints(exec func(string, ...any) (sql.Result, error)) { for _, name := range dbtablenames { - if _, err := db.Exec( + if _, err := exec( "ALTER TABLE " + name + " ENABLE TRIGGER ALL", ); err != nil { db.tb.Fatal(err) @@ -217,18 +228,6 @@ } } -func (db *TestDB) ClearDB() { - db.DisableConstraints() - defer db.EnableConstraints() - for _, name := range dbtablenames { - if _, err := db.Exec( - "DELETE FROM " + name, //nolint:gosec - ); err != nil { - db.tb.Fatal(err) - } - } -} - // Close the lock connection, then the database func (db *TestDB) Close() { if db.lockConn != nil {