Skip to content
Snippets Groups Projects
test.go 2.72 KiB
Newer Older
package database

import (
	"context"
	"os"
	"testing"

	"github.com/golang-migrate/migrate/v4/source"
	"github.com/jmoiron/sqlx"
)

// dbtablenames caches the names of the tables in the public schema, read from
// pg_catalog.pg_tables right after the migrations are applied. While it is
// empty, GetTestDB rebuilds the schema; once filled, GetTestDB only deletes
// the rows of these tables.
var dbtablenames []string

// dbLockID is the key GetTestDB passes to pg_advisory_lock so that callers
// take turns using the shared test database.
const dbLockID = 15104

func clearDB(t *testing.T, c *sqlx.Conn, dsn string) {
	// Drop all tables, and reinit
	if _, err := c.ExecContext(context.Background(), "DROP SCHEMA IF EXISTS public CASCADE"); err != nil {
	if _, err := c.ExecContext(context.Background(), "CREATE SCHEMA public"); err != nil {
		t.Fatal(err)
	}
}

// TestDB is a sqlx.DB wrapper that releases a global lock on close.
//
// GetTestDB builds it with a positional composite literal, so the field
// order below must not change.
type TestDB struct {
	// Embedded so a TestDB can be used directly wherever the *sqlx.DB
	// methods are needed.
	*sqlx.DB
	// tb receives the errors reported by Close.
	tb       testing.TB
	// lockConn is the connection on which GetTestDB executed
	// pg_advisory_lock; Close closes it before closing the embedded DB.
	lockConn *sqlx.Conn
}

// GetTestDB opens the database named by the TEST_DB_DSN environment variable
// and returns it ready for a test. It must be closed within the test.
//
// It executes pg_advisory_lock(dbLockID) on a dedicated connection, which is
// handed to the returned TestDB, so callers take turns on the database. On
// the first call in this process (dbtablenames still empty), the public
// schema is dropped, recreated and migrated with sourceDriver, and the
// resulting table names are cached. Later calls instead delete every row
// from the cached tables.
//
// If it fails, t.Fatal() is called and everything opened so far is closed.
func GetTestDB(t *testing.T, sourceDriver source.Driver) *TestDB {
	t.Helper()

	dsn := os.Getenv("TEST_DB_DSN")
	if dsn == "" {
		t.Fatal("Please define a TEST_DB_DSN environment variable")
	}
	db, err := Open(dsn, 0)
	if err != nil {
		t.Fatal(err)
	}

	// success is set only right before the successful return. Until then the
	// deferred cleanups below close what has been opened. t.Fatal unwinds
	// through deferred calls, so this covers the fatal paths as well.
	success := false
	defer func() {
		if !success {
			if err := db.Close(); err != nil {
				t.Log("Error closing db:", err)
			}
		}
	}()

	c, err := db.Connx(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	defer func() {
		if !success {
			if err := c.Close(); err != nil {
				t.Log("Error closing conn:", err)
			}
		}
	}()

	// The advisory lock belongs to c's session; c is kept in the returned
	// TestDB so the lock is held until TestDB.Close.
	if _, err := c.ExecContext(
		context.Background(),
		"SELECT pg_advisory_lock($1)", dbLockID,
	); err != nil {
		t.Fatal(err)
	}

	if len(dbtablenames) == 0 {
		initTestSchema(t, db, c, dsn, sourceDriver)
	} else {
		emptyTestTables(t, db)
	}

	success = true
	return &TestDB{db, t, c}
}

// initTestSchema rebuilds the public schema via clearDB, applies every
// migration from sourceDriver, and caches the resulting public table names
// in dbtablenames.
func initTestSchema(t *testing.T, db *sqlx.DB, c *sqlx.Conn, dsn string, sourceDriver source.Driver) {
	t.Helper()

	clearDB(t, c, dsn)

	m, err := NewMigrate(dsn, sourceDriver)
	if err != nil {
		t.Fatal(err)
	}
	if err := m.Up(); err != nil {
		// Best-effort release of the migrator; the Up error is the one reported.
		_, _ = m.Close()
		t.Fatal(err)
	}
	if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil {
		t.Fatal(srcErr, dbErr)
	}

	if err := db.Select(&dbtablenames,
		"SELECT tablename "+
			"FROM pg_catalog.pg_tables "+
			"WHERE schemaname = 'public'",
	); err != nil {
		t.Fatal(err)
	}
}

// emptyTestTables deletes every row from the tables cached in dbtablenames.
// All triggers are disabled on every table before any row is deleted and
// re-enabled afterwards. Presumably this lets rows go regardless of
// foreign-key order (verify against the schema).
func emptyTestTables(t *testing.T, db *sqlx.DB) {
	t.Helper()

	steps := []struct{ prefix, suffix string }{
		{"ALTER TABLE ", " DISABLE TRIGGER ALL"},
		{"DELETE FROM ", ""},
		{"ALTER TABLE ", " ENABLE TRIGGER ALL"},
	}
	for _, step := range steps {
		for _, name := range dbtablenames {
			// Table names come from pg_catalog, not from user input.
			if _, err := db.Exec(step.prefix + name + step.suffix); err != nil { //nolint:gosec
				t.Fatal(err)
			}
		}
	}
}

// Close the lock connection, then the database
func (db *TestDB) Close() {
	if err := db.lockConn.Close(); err != nil {
		db.tb.Error("lockConn.Close() failed", err)
	}

	if err := db.DB.Close(); err != nil {
		db.tb.Error("db.Close() failed", err)
	}