diff --git a/database/test.go b/database/test.go index 408c56d48ae9643de40e39f879eca656e6bcf989_ZGF0YWJhc2UvdGVzdC5nbw==..fe82db40aa33ddeec5923dd8a2cf69d5e4d31d54_ZGF0YWJhc2UvdGVzdC5nbw== 100644 --- a/database/test.go +++ b/database/test.go @@ -35,7 +35,9 @@ // GetTestDB creates a db and returns it. It must be closed within the test. // If it fails, t.Fatal() is called func GetTestDB(t *testing.T, sourceDriver source.Driver) *TestDB { + var success bool + dsn := os.Getenv("TEST_DB_DSN") if dsn == "" { t.Fatal("Please define a TEST_DB_DSN environment variable") } @@ -38,9 +40,10 @@ dsn := os.Getenv("TEST_DB_DSN") if dsn == "" { t.Fatal("Please define a TEST_DB_DSN environment variable") } + db, err := Open(dsn, 0) if err != nil { t.Fatal(err) } @@ -42,7 +45,15 @@ db, err := Open(dsn, 0) if err != nil { t.Fatal(err) } + defer func() { + if !success { + if err := db.Close(); err != nil { + t.Log("Error closing db:", err) + } + } + }() + c, err := db.Connx(context.Background()) if err != nil { @@ -47,6 +58,5 @@ c, err := db.Connx(context.Background()) if err != nil { - _ = db.Close() t.Fatal(err) } @@ -50,7 +60,15 @@ t.Fatal(err) } + defer func() { + if !success { + if err := c.Close(); err != nil { + t.Log("Error closing conn:", err) + } + } + }() + if _, err := c.ExecContext( context.Background(), "SELECT pg_advisory_lock($1)", dbLockID, ); err != nil { @@ -53,8 +71,7 @@ if _, err := c.ExecContext( context.Background(), "SELECT pg_advisory_lock($1)", dbLockID, ); err != nil { - _ = db.Close() t.Fatal(err) } @@ -107,6 +124,7 @@ } } + success = true return &TestDB{db, t, c} }