Skip to content
Snippets Groups Projects
test.go 4.82 KiB
Newer Older
package database

import (
	"context"
	"database/sql"
	"slices"
	"testing"

	"github.com/golang-migrate/migrate/v4/source"
	"github.com/jmoiron/sqlx"
	"github.com/stretchr/testify/require"
// dbtablenames is package-level state holding table names of the public
// schema, excluding schema_migrations. It is filled by a query against
// pg_catalog.pg_tables while it is still empty, and is then iterated by the
// TestDB helpers that clear tables and toggle triggers.
var dbtablenames []string
// getDSN returns the data source name stored in the TEST_DB_DSN environment
// variable. When the variable is unset or empty, the test is aborted through
// tb.Fatal.
func getDSN(tb testing.TB) string {
	// Mark this function as a helper so the failure is attributed to the
	// calling test's line rather than to this file.
	tb.Helper()

	dsn := os.Getenv("TEST_DB_DSN")
	if dsn == "" {
		tb.Fatal("Please define a TEST_DB_DSN environment variable")
	}

	return dsn
}

func clearDB(ctx context.Context, t testing.TB, c *sqlx.Conn, dsn string) {
	// Drop all tables, and reinit
	if _, err := c.ExecContext(ctx, "DROP SCHEMA IF EXISTS public CASCADE"); err != nil {
	if _, err := c.ExecContext(ctx, "CREATE SCHEMA public"); err != nil {
// TestDB is a sqlx.DB wrapper that releases a global lock on close.
type TestDB struct {
	*sqlx.DB
	DSN      string
	ctx      context.Context
	lockConn *sqlx.Conn
// SetTB installs tb as the testing.TB used by db and returns a closure that,
// when invoked, puts the previously installed one back.
func (db *TestDB) SetTB(tb testing.TB) func() {
	previous := db.tb
	restore := func() { db.tb = previous }

	db.tb = tb

	return restore
}

// ClearDB empties the database.
func (db *TestDB) ClearDB() {
	tx, err := db.lockConn.BeginTxx(db.ctx, nil)
	require.NoError(db.tb, err)

	defer func() {
		require.NoError(db.tb, tx.Commit())
	}()

	db.DisableConstraints(tx.Exec)
	defer db.EnableConstraints(tx.Exec)

	for _, name := range dbtablenames {
		if _, err := tx.Exec(
		); err != nil {
			db.tb.Fatal(err)
		}
	}

	var alltables []string
	require.NoError(db.tb,
		tx.Select(&alltables,
			"SELECT tablename "+
				"FROM pg_catalog.pg_tables "+
				"WHERE schemaname = 'public' "+
				"AND tablename <> 'schema_migrations';",
		))
	alltables = slices.DeleteFunc(alltables, func(name string) bool {
		return slices.Contains(dbtablenames, name)
	})
	for _, name := range alltables {
		_, err := tx.Exec("DROP TABLE " + name)
		require.NoError(db.tb, err)
func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB {
	var success bool

	dsn := getDSN(tb)

	db, err := Open(dsn, 0)
	require.NoError(tb, err)

	defer func() {
		if !success {
			if err := db.Close(); err != nil {
				tb.Log("Error closing db:", err)
			}
		}
	}()

	c, err := db.Connx(ctx)
	require.NoError(tb, err)

	defer func() {
		if !success {
			if err := c.Close(); err != nil {
				tb.Log("Error closing conn:", err)
			}
		}
	}()

	if _, err := c.ExecContext(
		ctx,
		"SELECT pg_advisory_lock($1)", dbLockID,
	); err != nil {
		tb.Fatal(err)
	}

	return &TestDB{db, dsn, ctx, tb, c}
// GetTestDB obtains a TestDB through GetTestDBNoClear, empties it with ClearDB
// and returns it. The caller must close the returned TestDB within the test.
// On failure, t.Fatal() is called.
func GetTestDB(ctx context.Context, t testing.TB, sourceDriver source.Driver) *TestDB {
	testDB := GetTestDBNoClear(ctx, t, sourceDriver)
	testDB.ClearDB()

	return testDB
}

func GetTestDBNoClear(ctx context.Context, t testing.TB, sourceDriver source.Driver) *TestDB {
	dsn := getDSN(t)
	db, err := Open(dsn, 0)
	if err != nil {
		t.Fatal(err)
	}

			if ctx.Err() == nil {
				if err := db.Close(); err != nil {
					t.Log("Error closing db:", err)
				}
	c, err := db.Connx(ctx)
	if err != nil {
		t.Fatal(err)
	}

	defer func() {
		if !success {
			if err := c.Close(); err != nil {
				t.Log("Error closing conn:", err)
			}
		}
	}()

	if _, err := c.ExecContext(
		"SELECT pg_advisory_lock($1)", dbLockID,
	); err != nil {
		t.Fatal(err)
	}

	if len(dbtablenames) == 0 {
		clearDB(ctx, t, c, dsn)

		m, err := NewMigrate(dsn, sourceDriver)
		if err != nil {
			t.Fatal(err)
		}
		if err := m.Up(); err != nil {
			t.Fatal(err)
		}
		if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil {
			t.Fatal(srcErr, dbErr)
		}
	}

	testDB := TestDB{db, dsn, ctx, t, c}
	if len(dbtablenames) == 0 {
		if err := db.Select(&dbtablenames,
			"SELECT tablename "+
				"FROM pg_catalog.pg_tables "+
				"WHERE schemaname = 'public' "+
				"AND tablename <> 'schema_migrations';",
// DisableConstraints issues "ALTER TABLE <name> DISABLE TRIGGER ALL" through
// exec for every table listed in dbtablenames. Per the PostgreSQL ALTER TABLE
// documentation, ALL also covers the internal triggers that enforce
// foreign-key constraints. An error from exec is reported with db.tb.Fatal.
func (db *TestDB) DisableConstraints(exec func(string, ...any) (sql.Result, error)) {
	for _, table := range dbtablenames {
		stmt := "ALTER TABLE " + table + " DISABLE TRIGGER ALL"

		_, err := exec(stmt)
		if err != nil {
			db.tb.Fatal(err)
		}
	}
}

// EnableConstraints issues "ALTER TABLE <name> ENABLE TRIGGER ALL" through
// exec for every table listed in dbtablenames, reversing what
// DisableConstraints switched off. An error from exec is reported with
// db.tb.Fatal.
func (db *TestDB) EnableConstraints(exec func(string, ...any) (sql.Result, error)) {
	for _, table := range dbtablenames {
		stmt := "ALTER TABLE " + table + " ENABLE TRIGGER ALL"

		_, err := exec(stmt)
		if err != nil {
			db.tb.Fatal(err)
		}
	}
}

// Close closes the lock connection, then the database.
func (db *TestDB) Close() {
	if db.lockConn != nil {
		if err := db.lockConn.Close(); err != nil {
			db.tb.Error("lockConn.Close() failed", err)
		}
		db.lockConn = nil
	if db.DB != nil {
		if err := db.DB.Close(); err != nil {
			db.tb.Error("db.Close() failed", err)
		}
		db.DB = nil