Newer
Older
package database
import (
"context"
"os"
"testing"
"github.com/golang-migrate/migrate/v4/source"
"github.com/jmoiron/sqlx"
)
// dbtablenames holds the names of the tables found in the "public" schema.
// It is populated from pg_catalog.pg_tables while it is still empty and is
// then reused to empty every table between tests.
var dbtablenames []string

// dbLockID is the key handed to pg_advisory_lock when a test database is
// acquired.
const dbLockID = 15104
func clearDB(t *testing.T, c *sqlx.Conn, dsn string) {
// Drop all tables, and reinit
if _, err := c.ExecContext(context.Background(), "DROP SCHEMA IF EXISTS public CASCADE"); err != nil {
if _, err := c.ExecContext(context.Background(), "CREATE SCHEMA public"); err != nil {
t.Fatal(err)
}
}
// TestDB is a sqlx.DB wrapper that releases a global lock on close
type TestDB struct {
*sqlx.DB
tb testing.TB
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// SetTB installs tb as the testing.TB used by db for reporting and hands back
// a closure that, when called, puts the previously installed one back.
func (db *TestDB) SetTB(tb testing.TB) func() {
	previous := db.tb
	restore := func() {
		db.tb = previous
	}
	db.tb = tb
	return restore
}
// Clear empties the database: it deletes every row of each table listed in
// dbtablenames, inside a single transaction opened on the lock connection.
//
// Triggers are disabled on all tables before the deletes and re-enabled on
// all tables afterwards. Any failure is reported through db.tb (Fatal or
// require.NoError), which aborts the calling test; the transaction is then
// rolled back so the lock connection is not left inside an open transaction.
func (db *TestDB) Clear() {
	tx, err := db.lockConn.BeginTxx(context.Background(), nil)
	require.NoError(db.tb, err)
	// Fatal/FailNow stop the test through runtime.Goexit, which still runs
	// deferred calls, so this releases the transaction on every failure path.
	// After a successful Commit the Rollback only reports that the
	// transaction is already finished; that error is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	// Each step is applied to every table before the next step starts:
	// all triggers off, then all deletes, then all triggers back on.
	steps := []struct{ prefix, suffix string }{
		{"ALTER TABLE ", " DISABLE TRIGGER ALL"},
		{"DELETE FROM ", ""},
		{"ALTER TABLE ", " ENABLE TRIGGER ALL"},
	}
	for _, step := range steps {
		for _, name := range dbtablenames {
			// Table names come from dbtablenames, not from user input.
			if _, err := tx.Exec(step.prefix + name + step.suffix); err != nil { //nolint:gosec
				db.tb.Fatal(err)
			}
		}
	}

	require.NoError(db.tb, tx.Commit())
}
// GetTestDB creates a db and returns it. It must be closed within the test.
// If it fails, t.Fatal() is called
func GetTestDB(t *testing.T, sourceDriver source.Driver) *TestDB {
var success bool
dsn := os.Getenv("TEST_DB_DSN")
if dsn == "" {
t.Fatal("Please define a TEST_DB_DSN environment variable")
}
db, err := Open(dsn, 0)
if err != nil {
t.Fatal(err)
}
defer func() {
if !success {
if err := db.Close(); err != nil {
t.Log("Error closing db:", err)
}
}
}()
c, err := db.Connx(context.Background())
if err != nil {
t.Fatal(err)
}
defer func() {
if !success {
if err := c.Close(); err != nil {
t.Log("Error closing conn:", err)
}
}
}()
if _, err := c.ExecContext(
context.Background(),
"SELECT pg_advisory_lock($1)", dbLockID,
); err != nil {
t.Fatal(err)
}
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
m, err := NewMigrate(dsn, sourceDriver)
if err != nil {
t.Fatal(err)
}
if err := m.Up(); err != nil {
t.Fatal(err)
}
if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil {
t.Fatal(srcErr, dbErr)
}
}
if len(dbtablenames) == 0 {
if err := db.Select(&dbtablenames,
"SELECT tablename "+
"FROM pg_catalog.pg_tables "+
"WHERE schemaname = 'public'",
); err != nil {
t.Fatal(err)
}
} else {
for _, name := range dbtablenames {
if _, err := db.Exec(
"ALTER TABLE " + name + " DISABLE TRIGGER ALL",
); err != nil {
t.Fatal(err)
return nil
}
}
for _, name := range dbtablenames {
if _, err := db.Exec(
"DELETE FROM " + name, //nolint:gosec
); err != nil {
t.Fatal(err)
return nil
}
}
for _, name := range dbtablenames {
if _, err := db.Exec(
"ALTER TABLE " + name + " ENABLE TRIGGER ALL",
); err != nil {
t.Fatal(err)
return nil
}
}
}
success = true
return &TestDB{db, t, c}
}
// Close the lock connection, then the database
if err := db.lockConn.Close(); err != nil {
db.tb.Error("lockConn.Close() failed", err)
}
if err := db.DB.Close(); err != nil {
db.tb.Error("db.Close() failed", err)
}