Skip to content
Snippets Groups Projects
test.go 4.9 KiB
Newer Older
package database

import (
	"context"
	"database/sql"
	"slices"
	"testing"

	"github.com/golang-migrate/migrate/v4/source"
	"github.com/jmoiron/sqlx"
	"github.com/stretchr/testify/require"
var dbtablenames []string
func getDSN(tb testing.TB) string {
Axel Prel's avatar
Axel Prel committed
	tb.Helper()
	dsn := os.Getenv("TEST_DB_DSN")
	if dsn == "" {
		tb.Fatal("Please define a TEST_DB_DSN environment variable")
	}
Axel Prel's avatar
Axel Prel committed

Axel Prel's avatar
Axel Prel committed
func clearDB(ctx context.Context, tb testing.TB, c *sqlx.Conn) {
	tb.Helper()
	// Drop all tables, and reinit
	if _, err := c.ExecContext(ctx, "DROP SCHEMA IF EXISTS public CASCADE"); err != nil {
Axel Prel's avatar
Axel Prel committed
		tb.Fatal(err)
	if _, err := c.ExecContext(ctx, "CREATE SCHEMA public"); err != nil {
Axel Prel's avatar
Axel Prel committed
		tb.Fatal(err)
// TestDB is a sqlx.DB wrapper that release a global lock on close.
type TestDB struct {
	*sqlx.DB
	DSN      string
	ctx      context.Context
	lockConn *sqlx.Conn
// SetTB changes the current tb and returns a function to restore the original one.
func (db *TestDB) SetTB(tb testing.TB) func() {
Axel Prel's avatar
Axel Prel committed
	tb.Helper()
	save := db.tb
	db.tb = tb
Axel Prel's avatar
Axel Prel committed

	return func() {
		db.tb = save
	}
}

// ClearDB empty the database.
func (db *TestDB) ClearDB() {
	tx, err := db.lockConn.BeginTxx(db.ctx, nil)
	require.NoError(db.tb, err)

	defer func() {
		require.NoError(db.tb, tx.Commit())
	}()

	db.DisableConstraints(tx.Exec)
	defer db.EnableConstraints(tx.Exec)

	for _, name := range dbtablenames {
		if _, err := tx.Exec(
Axel Prel's avatar
Axel Prel committed
			"DELETE FROM " + name,
		); err != nil {
			db.tb.Fatal(err)
		}
	}

	var alltables []string
	require.NoError(db.tb,
		tx.Select(&alltables,
			"SELECT tablename "+
				"FROM pg_catalog.pg_tables "+
				"WHERE schemaname = 'public' "+
				"AND tablename <> 'schema_migrations';",
		))
	alltables = slices.DeleteFunc(alltables, func(name string) bool {
		return slices.Contains(dbtablenames, name)
	})
	for _, name := range alltables {
		_, err := tx.Exec("DROP TABLE " + name)
		require.NoError(db.tb, err)
func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB {
Axel Prel's avatar
Axel Prel committed
	tb.Helper()
	dsn := getDSN(tb)

	db, err := Open(dsn, 0)
	require.NoError(tb, err)

	defer func() {
		if !success {
			if err := db.Close(); err != nil {
				tb.Log("Error closing db:", err)
			}
		}
	}()

	c, err := db.Connx(ctx)
	require.NoError(tb, err)

	defer func() {
		if !success {
			if err := c.Close(); err != nil {
				tb.Log("Error closing conn:", err)
			}
		}
	}()

	if _, err := c.ExecContext(
		ctx,
		"SELECT pg_advisory_lock($1)", dbLockID,
	); err != nil {
		tb.Fatal(err)
	}

	return &TestDB{db, dsn, ctx, tb, c}
// GetTestDB creates a db and returns it. It must be closed within the test.
// If it fails, t.Fatal() is called.
Axel Prel's avatar
Axel Prel committed
func GetTestDB(ctx context.Context, tb testing.TB, sourceDriver source.Driver) *TestDB {
	tb.Helper()
	db := GetTestDBNoClear(ctx, tb, sourceDriver)
Axel Prel's avatar
Axel Prel committed

Axel Prel's avatar
Axel Prel committed
func GetTestDBNoClear(ctx context.Context, tb testing.TB, sourceDriver source.Driver) *TestDB {
	tb.Helper()
Axel Prel's avatar
Axel Prel committed
	dsn := getDSN(tb)
	db, err := Open(dsn, 0)
	if err != nil {
Axel Prel's avatar
Axel Prel committed
		tb.Fatal(err)
			if ctx.Err() == nil {
				if err := db.Close(); err != nil {
Axel Prel's avatar
Axel Prel committed
					tb.Log("Error closing db:", err)
	c, err := db.Connx(ctx)
	if err != nil {
Axel Prel's avatar
Axel Prel committed
		tb.Fatal(err)
	defer func() {
		if !success {
			if err := c.Close(); err != nil {
Axel Prel's avatar
Axel Prel committed
				tb.Log("Error closing conn:", err)
	if _, err := c.ExecContext(
		"SELECT pg_advisory_lock($1)", dbLockID,
	); err != nil {
Axel Prel's avatar
Axel Prel committed
		tb.Fatal(err)
	if len(dbtablenames) == 0 {
Axel Prel's avatar
Axel Prel committed
		clearDB(ctx, tb, c)

		m, err := NewMigrate(dsn, sourceDriver)
		if err != nil {
Axel Prel's avatar
Axel Prel committed
			tb.Fatal(err)
		}
		if err := m.Up(); err != nil {
Axel Prel's avatar
Axel Prel committed
			tb.Fatal(err)
		}
		if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil {
Axel Prel's avatar
Axel Prel committed
			tb.Fatal(srcErr, dbErr)
Axel Prel's avatar
Axel Prel committed
	testDB := TestDB{db, dsn, ctx, tb, c}
	if len(dbtablenames) == 0 {
		if err := db.Select(&dbtablenames,
			"SELECT tablename "+
				"FROM pg_catalog.pg_tables "+
				"WHERE schemaname = 'public' "+
				"AND tablename <> 'schema_migrations';",
Axel Prel's avatar
Axel Prel committed
			tb.Fatal(err)
Axel Prel's avatar
Axel Prel committed

// DisableConstraints runs "ALTER TABLE <name> DISABLE TRIGGER ALL" through
// exec for every table recorded in dbtablenames.
//
// Per the PostgreSQL ALTER TABLE documentation, "ALL" also covers the
// internally generated triggers that enforce foreign keys, which requires
// superuser privileges on the test database.
//
// Any failing statement aborts the test via db.tb.Fatalf.
func (db *TestDB) DisableConstraints(exec func(string, ...any) (sql.Result, error)) {
	db.tb.Helper()
	db.setTriggers(exec, "DISABLE")
}

// EnableConstraints is the inverse of DisableConstraints. It runs
// "ALTER TABLE <name> ENABLE TRIGGER ALL" through exec for every table
// recorded in dbtablenames. Any failing statement aborts the test via
// db.tb.Fatalf.
func (db *TestDB) EnableConstraints(exec func(string, ...any) (sql.Result, error)) {
	db.tb.Helper()
	db.setTriggers(exec, "ENABLE")
}

// setTriggers issues "ALTER TABLE <name> <action> TRIGGER ALL" via exec for
// each table in dbtablenames. The action is "ENABLE" or "DISABLE".
// The failing statement is included in the fatal message so the offending
// table is visible in the test output.
func (db *TestDB) setTriggers(exec func(string, ...any) (sql.Result, error), action string) {
	db.tb.Helper()
	for _, name := range dbtablenames {
		stmt := "ALTER TABLE " + name + " " + action + " TRIGGER ALL"
		if _, err := exec(stmt); err != nil {
			db.tb.Fatalf("%s: %v", stmt, err)
		}
	}
}

// Close the lock connection, then the database.
func (db *TestDB) Close() {
	if db.lockConn != nil {
		if err := db.lockConn.Close(); err != nil {
			db.tb.Error("lockConn.Close() failed", err)
		}
		db.lockConn = nil
	if db.DB != nil {
		if err := db.DB.Close(); err != nil {
			db.tb.Error("db.Close() failed", err)
		}
		db.DB = nil