Newer
Older
package database
import (
"context"
"os"
"testing"
"github.com/golang-migrate/migrate/v4/source"
"github.com/jmoiron/sqlx"
)
var (
dbtablenames []string
)
const dbLockID = 15104
// GetDSN returns the data source name of the test database, read from the
// TEST_DB_DSN environment variable.
//
// If the variable is unset or empty, tb.Fatal is called. GetDSN is marked as a
// test helper so the failure is reported at the caller's line.
func GetDSN(tb testing.TB) string {
	tb.Helper()

	dsn := os.Getenv("TEST_DB_DSN")
	if dsn == "" {
		tb.Fatal("Please define a TEST_DB_DSN environment variable")
	}
	return dsn
}
// clearDB issues schema-reset statements on the given connection.
//
// NOTE(review): this definition looks truncated in this view. As written, the
// CREATE SCHEMA statement only runs when DROP SCHEMA has failed, a DROP
// failure is never reported through t, and the function's closing brace is
// not visible. Confirm against the complete file before relying on it.
func clearDB(ctx context.Context, t testing.TB, c *sqlx.Conn, dsn string) {
// Drop all tables, and reinit
if _, err := c.ExecContext(ctx, "DROP SCHEMA IF EXISTS public CASCADE"); err != nil {
if _, err := c.ExecContext(ctx, "CREATE SCHEMA public"); err != nil {
t.Fatal(err)
}
}
// TestDB is a sqlx.DB wrapper that release a global lock on close
type TestDB struct {
*sqlx.DB
// SetTB installs tb as the testing.TB used by db and hands back a closure
// that, when called, puts the previously installed one back.
func (db *TestDB) SetTB(tb testing.TB) func() {
	previous := db.tb
	restore := func() { db.tb = previous }

	db.tb = tb
	return restore
}
// Clear empties the database: every table listed in dbtablenames has its rows
// deleted, with all triggers disabled for the duration of the deletes.
//
// The work runs in a single transaction on the lock connection. Any failure is
// reported through the current testing.TB (Fatal / require.NoError), and the
// transaction is rolled back so the long-lived connection is not left inside
// an aborted transaction.
//
// Table names are concatenated into the SQL unquoted; they come from
// dbtablenames, not from user input.
func (db *TestDB) Clear() {
	db.tb.Helper()

	tx, err := db.lockConn.BeginTxx(db.ctx, nil)
	require.NoError(db.tb, err)

	// Fatal and require.NoError stop the test with runtime.Goexit, which still
	// runs deferred calls, so this is where a failed run gets rolled back.
	finished := false
	defer func() {
		if finished {
			return
		}
		if rbErr := tx.Rollback(); rbErr != nil {
			db.tb.Log("Error rolling back:", rbErr)
		}
	}()

	// forEachTable runs prefix+<table>+suffix for every known table.
	forEachTable := func(prefix, suffix string) {
		for _, name := range dbtablenames {
			query := prefix + name + suffix //nolint:gosec
			if _, execErr := tx.ExecContext(db.ctx, query); execErr != nil {
				db.tb.Fatal(execErr)
			}
		}
	}

	// Triggers are disabled on every table before any row is removed, and
	// re-enabled only after all tables are empty.
	forEachTable("ALTER TABLE ", " DISABLE TRIGGER ALL")
	forEachTable("DELETE FROM ", "")
	forEachTable("ALTER TABLE ", " ENABLE TRIGGER ALL")

	err = tx.Commit()
	// Commit ends the transaction whether or not it succeeded; no rollback needed.
	finished = true
	require.NoError(db.tb, err)
}
// GetTestDBNoInit opens the test database (DSN taken from GetDSN), reserves a
// dedicated connection on it and wraps both in a TestDB, without running any
// initialisation on the database.
//
// Failures are reported through require.NoError on tb. If the dedicated
// connection cannot be obtained, the freshly opened database is closed again
// before the function unwinds.
func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB {
	handedOver := false

	source := GetDSN(tb)
	pool, err := Open(source, 0)
	require.NoError(tb, err)

	// Release the pool unless ownership has passed to the returned TestDB.
	defer func() {
		if handedOver {
			return
		}
		if closeErr := pool.Close(); closeErr != nil {
			tb.Log("Error closing db:", closeErr)
		}
	}()

	conn, err := pool.Connx(ctx)
	require.NoError(tb, err)

	handedOver = true
	return &TestDB{pool, ctx, tb, conn}
}
// GetTestDB creates a db and returns it. It must be closed within the test.
// If it fails, t.Fatal() is called
//
// NOTE(review): several statements are missing from this view. dsn and c are
// used below but never declared here, the second "if err != nil" has no
// preceding call, the pg_advisory_lock statement has lost its opening line,
// and a run of bare numbers sits in the middle of the body. Confirm against
// the complete file before changing anything here.
func GetTestDB(ctx context.Context, t testing.TB, sourceDriver source.Driver) *TestDB {
// success stays false until the very end; the deferred closures below use it
// to release resources when the function unwinds early.
var success bool
db, err := Open(dsn, 0)
if err != nil {
t.Fatal(err)
}
// Close the database again unless it is handed to the returned TestDB.
defer func() {
if !success {
if err := db.Close(); err != nil {
t.Log("Error closing db:", err)
}
}
}()
if err != nil {
t.Fatal(err)
}
// Close the connection c again unless it is handed to the returned TestDB.
defer func() {
if !success {
if err := c.Close(); err != nil {
t.Log("Error closing conn:", err)
}
}
}()
// Tail of a statement that takes the Postgres advisory lock identified by
// dbLockID; the start of the call is not visible in this view.
"SELECT pg_advisory_lock($1)", dbLockID,
); err != nil {
t.Fatal(err)
}
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
// Build a migrator from sourceDriver, apply all up migrations, then close it;
// any error is fatal.
m, err := NewMigrate(dsn, sourceDriver)
if err != nil {
t.Fatal(err)
}
if err := m.Up(); err != nil {
t.Fatal(err)
}
if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil {
t.Fatal(srcErr, dbErr)
}
}
// When the package-level table list is still empty, fill it with the names of
// all tables in the public schema. Otherwise empty every known table.
if len(dbtablenames) == 0 {
if err := db.Select(&dbtablenames,
"SELECT tablename "+
"FROM pg_catalog.pg_tables "+
"WHERE schemaname = 'public'",
); err != nil {
t.Fatal(err)
}
} else {
// Disable triggers on every table, delete all rows, then re-enable the
// triggers. Each phase covers all tables before the next one starts.
for _, name := range dbtablenames {
if _, err := db.Exec(
"ALTER TABLE " + name + " DISABLE TRIGGER ALL",
); err != nil {
t.Fatal(err)
// Only reached if t.Fatal returns; a regular *testing.T stops here.
return nil
}
}
for _, name := range dbtablenames {
if _, err := db.Exec(
"DELETE FROM " + name, //nolint:gosec
); err != nil {
t.Fatal(err)
return nil
}
}
for _, name := range dbtablenames {
if _, err := db.Exec(
"ALTER TABLE " + name + " ENABLE TRIGGER ALL",
); err != nil {
t.Fatal(err)
return nil
}
}
}
// From here on the caller owns db and c through the returned TestDB, so the
// deferred cleanups above must not close them.
success = true
return &TestDB{db, ctx, t, c}
}
// Close the lock connection, then the database
if err := db.lockConn.Close(); err != nil {
db.tb.Error("lockConn.Close() failed", err)
}
if err := db.DB.Close(); err != nil {
db.tb.Error("db.Close() failed", err)
}