package database

import (
	"context"
	"os"
	"testing"

	"github.com/golang-migrate/migrate/v4/source"
	"github.com/jmoiron/sqlx"
	"github.com/stretchr/testify/require"
)

// dbtablenames caches the names of the tables in the public schema. It is
// filled by the first successful GetTestDB call in this process and is used
// by the helpers that clear or alter every table. While it is empty,
// GetTestDB rebuilds the schema from scratch.
var dbtablenames []string

// dbLockID is the key of the PostgreSQL advisory lock that GetTestDB and
// GetTestDBNoInit acquire before handing out a TestDB.
const dbLockID = 15104

// getDSN returns the connection string stored in the TEST_DB_DSN environment
// variable. When the variable is unset or empty, tb.Fatal is called; should
// tb not abort, the empty string is returned.
func getDSN(tb testing.TB) string {
	if value, present := os.LookupEnv("TEST_DB_DSN"); present && value != "" {
		return value
	}
	tb.Fatal("Please define a TEST_DB_DSN environment variable")
	return ""
}

// clearDB wipes the database reachable through c. It drops the public schema
// together with everything in it (CASCADE) and then recreates it empty. Each
// error is reported with t.Fatal. The dsn parameter is currently unused.
func clearDB(ctx context.Context, t testing.TB, c *sqlx.Conn, dsn string) {
	statements := [...]string{
		"DROP SCHEMA IF EXISTS public CASCADE",
		"CREATE SCHEMA public",
	}
	for _, stmt := range statements {
		if _, err := c.ExecContext(ctx, stmt); err != nil {
			t.Fatal(err)
		}
	}
}

// TestDB is a sqlx.DB wrapper that releases a global lock on close.
//
// The lock is the PostgreSQL advisory lock keyed by dbLockID, taken on
// lockConn by the constructors. pg_advisory_lock waits until the lock is
// free, so every user of these helpers against the same database is
// serialized. NOTE(review): keep the field order stable; positional
// TestDB{...} literals may depend on it.
type TestDB struct {
	// DB is the connection pool opened from DSN.
	*sqlx.DB
	// DSN is the connection string the pool was opened with (TEST_DB_DSN).
	DSN      string
	// ctx is the context passed to the constructor; Clear runs its
	// transaction with it.
	ctx      context.Context
	// tb receives failures; SetTB can swap it temporarily.
	tb       testing.TB
	// lockConn is the dedicated connection the advisory lock was taken on.
	// Close sets it to nil.
	lockConn *sqlx.Conn
}

// SetTB makes tb the testing.TB used for reporting failures from now on.
// It returns a closure that puts the previous testing.TB back, which is
// convenient with defer:
//
//	defer db.SetTB(t)()
func (db *TestDB) SetTB(tb testing.TB) func() {
	previous := db.tb
	db.tb = tb
	return func() { db.tb = previous }
}

// Clear empties the database. Inside one transaction on the lock connection
// it does the following for every table in dbtablenames:
//   - disables all triggers,
//   - deletes every row,
//   - re-enables the triggers.
//
// It then commits. Any failure is reported through db.tb.
//
// The transaction is always ended. If a statement fails, the deferred
// Rollback runs while tb.Fatal unwinds the goroutine. Without it the open
// transaction would keep lockConn busy, and a later lockConn.Close (from
// Close) could block waiting for the transaction to finish, because
// database/sql's Conn.Close waits for in-flight use of the connection.
func (db *TestDB) Clear() {
	tx, err := db.lockConn.BeginTxx(db.ctx, nil)
	require.NoError(db.tb, err)
	// After a successful Commit this returns sql.ErrTxDone and does nothing.
	defer func() { _ = tx.Rollback() }()

	for _, name := range dbtablenames {
		if _, err := tx.ExecContext(db.ctx,
			"ALTER TABLE "+name+" DISABLE TRIGGER ALL",
		); err != nil {
			db.tb.Fatal(err)
		}
	}
	for _, name := range dbtablenames {
		if _, err := tx.ExecContext(db.ctx,
			"DELETE FROM "+name, //nolint:gosec
		); err != nil {
			db.tb.Fatal(err)
		}
	}
	for _, name := range dbtablenames {
		if _, err := tx.ExecContext(db.ctx,
			"ALTER TABLE "+name+" ENABLE TRIGGER ALL",
		); err != nil {
			db.tb.Fatal(err)
		}
	}
	require.NoError(db.tb, tx.Commit())
}

// GetTestDBNoInit opens the database named by TEST_DB_DSN and acquires the
// package-wide advisory lock (dbLockID) on a dedicated connection. Unlike
// GetTestDB, it leaves the schema and data untouched. The caller must Close
// the returned TestDB. Failures abort the test through tb, and whatever was
// opened before the failure is closed again.
func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB {
	done := false

	dsn := getDSN(tb)

	db, err := Open(dsn, 0)
	require.NoError(tb, err)
	defer func() {
		if done {
			return
		}
		if closeErr := db.Close(); closeErr != nil {
			tb.Log("Error closing db:", closeErr)
		}
	}()

	conn, err := db.Connx(ctx)
	require.NoError(tb, err)
	defer func() {
		if done {
			return
		}
		if closeErr := conn.Close(); closeErr != nil {
			tb.Log("Error closing conn:", closeErr)
		}
	}()

	const lockQuery = "SELECT pg_advisory_lock($1)"
	if _, err = conn.ExecContext(ctx, lockQuery, dbLockID); err != nil {
		tb.Fatal(err)
	}

	done = true
	return &TestDB{
		DB:       db,
		DSN:      dsn,
		ctx:      ctx,
		tb:       tb,
		lockConn: conn,
	}
}

// GetTestDB creates a db and returns it. It must be closed within the test.
// If it fails, t.Fatal() is called, and resources acquired so far are
// released.
//
// It opens a pool on TEST_DB_DSN and takes the advisory lock dbLockID on a
// dedicated connection, waiting until the lock is available. What happens
// next depends on whether this is the first call in the process:
//   - First call (dbtablenames is still empty): it drops and recreates the
//     public schema, applies every migration from sourceDriver and records
//     the public table names in dbtablenames.
//   - Later calls: it deletes all rows from those tables via ClearDB.
func GetTestDB(ctx context.Context, t testing.TB, sourceDriver source.Driver) *TestDB {
	var success bool

	dsn := getDSN(t)

	db, err := Open(dsn, 0)
	if err != nil {
		t.Fatal(err)
	}

	defer func() {
		// NOTE(review): the pool is deliberately left open when ctx is
		// already done, presumably to avoid blocking in db.Close. Confirm
		// before changing.
		if !success && ctx.Err() == nil {
			if err := db.Close(); err != nil {
				t.Log("Error closing db:", err)
			}
		}
	}()

	c, err := db.Connx(ctx)
	if err != nil {
		t.Fatal(err)
	}

	defer func() {
		if !success {
			if err := c.Close(); err != nil {
				t.Log("Error closing conn:", err)
			}
		}
	}()

	if _, err := c.ExecContext(
		ctx,
		"SELECT pg_advisory_lock($1)", dbLockID,
	); err != nil {
		t.Fatal(err)
	}

	if len(dbtablenames) == 0 {
		clearDB(ctx, t, c, dsn)

		m, err := NewMigrate(dsn, sourceDriver)
		if err != nil {
			t.Fatal(err)
		}
		if err := m.Up(); err != nil {
			// Release the migration's resources before aborting. The close
			// errors are secondary to err, so they are ignored.
			_, _ = m.Close()
			t.Fatal(err)
		}
		if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil {
			t.Fatal(srcErr, dbErr)
		}
	}

	testDB := &TestDB{DB: db, DSN: dsn, ctx: ctx, tb: t, lockConn: c}
	if len(dbtablenames) == 0 {
		// Scan into a local slice and publish it only on success. A failed
		// query must not leave a partial list cached, or later calls would
		// skip the migrations and clear only some of the tables.
		var names []string
		if err := db.SelectContext(ctx, &names,
			"SELECT tablename "+
				"FROM pg_catalog.pg_tables "+
				"WHERE schemaname = 'public'",
		); err != nil {
			t.Fatal(err)
		}
		dbtablenames = names
	} else {
		testDB.ClearDB()
	}

	success = true
	return testDB
}

// DisableConstraints runs "ALTER TABLE <name> DISABLE TRIGGER ALL" through
// the connection pool on every table in dbtablenames. In PostgreSQL, ALL
// includes the internal triggers that enforce foreign keys, so FK checks are
// suspended too; altering those requires superuser rights. Any failure is
// reported with tb.Fatal.
func (db *TestDB) DisableConstraints() {
	for i := range dbtablenames {
		stmt := "ALTER TABLE " + dbtablenames[i] + " DISABLE TRIGGER ALL"
		if _, err := db.Exec(stmt); err != nil {
			db.tb.Fatal(err)
		}
	}
}

// EnableConstraints is the counterpart of DisableConstraints. It runs
// "ALTER TABLE <name> ENABLE TRIGGER ALL" through the connection pool on
// every table in dbtablenames and reports any failure with tb.Fatal.
func (db *TestDB) EnableConstraints() {
	for i := range dbtablenames {
		stmt := "ALTER TABLE " + dbtablenames[i] + " ENABLE TRIGGER ALL"
		if _, err := db.Exec(stmt); err != nil {
			db.tb.Fatal(err)
		}
	}
}

// ClearDB deletes every row from each table in dbtablenames through the
// connection pool. Triggers are disabled first so the deletes can run in any
// order, and they are re-enabled by a deferred call. That call also runs when
// a delete fails and tb.Fatal unwinds. Unlike Clear, no transaction is used.
func (db *TestDB) ClearDB() {
	db.DisableConstraints()
	defer db.EnableConstraints()

	for i := range dbtablenames {
		if _, err := db.Exec(
			"DELETE FROM " + dbtablenames[i], //nolint:gosec
		); err != nil {
			db.tb.Fatal(err)
		}
	}
}

// Close releases what the TestDB holds: first the connection the advisory
// lock was taken on, then the pool itself. Errors are reported with tb.Error,
// so the test keeps running. Each resource is set to nil once handled, which
// makes a second Close a no-op.
//
// The advisory lock is session-level. It disappears once the underlying
// connection is really closed; Conn.Close only hands the connection back to
// the pool, so at the latest this happens when the pool is closed.
func (db *TestDB) Close() {
	if conn := db.lockConn; conn != nil {
		if err := conn.Close(); err != nil {
			db.tb.Error("lockConn.Close() failed", err)
		}
		db.lockConn = nil
	}

	if pool := db.DB; pool != nil {
		if err := pool.Close(); err != nil {
			db.tb.Error("db.Close() failed", err)
		}
		db.DB = nil
	}
}