# HG changeset patch # User Christophe de Vienne <christophe@cdevienne.info> # Date 1603722480 -3600 # Mon Oct 26 15:28:00 2020 +0100 # Node ID 2e7fd36e7588f706e20ec0b1a1c5e815cb08a1d5 # Parent 1ed99df25b0ed1e16c0dab292edbcde5c2e01627 TestDB: add SetDB and Clear diff --git a/database/test.go b/database/test.go --- a/database/test.go +++ b/database/test.go @@ -7,6 +7,7 @@ "github.com/golang-migrate/migrate/v4/source" "github.com/jmoiron/sqlx" + "github.com/stretchr/testify/require" ) var ( @@ -32,6 +33,43 @@ lockConn *sqlx.Conn } +// SetTB changes the current tb and returns a function to restore the original one +func (db *TestDB) SetTB(tb testing.TB) func() { + save := db.tb + db.tb = tb + return func() { + db.tb = save + } +} + +// Clear empty the database +func (db *TestDB) Clear() { + tx, err := db.lockConn.BeginTxx(context.Background(), nil) + require.NoError(db.tb, err) + for _, name := range dbtablenames { + if _, err := tx.Exec( + "ALTER TABLE " + name + " DISABLE TRIGGER ALL", + ); err != nil { + db.tb.Fatal(err) + } + } + for _, name := range dbtablenames { + if _, err := tx.Exec( + "DELETE FROM " + name, //nolint:gosec + ); err != nil { + db.tb.Fatal(err) + } + } + for _, name := range dbtablenames { + if _, err := tx.Exec( + "ALTER TABLE " + name + " ENABLE TRIGGER ALL", + ); err != nil { + db.tb.Fatal(err) + } + } + require.NoError(db.tb, tx.Commit()) +} + // GetTestDB creates a db and returns it. It must be closed within the test. // If it fails, t.Fatal() is called func GetTestDB(t *testing.T, sourceDriver source.Driver) *TestDB {