package database

import (
	"context"
	"database/sql"
	"os"
	"slices"
	"testing"

	"github.com/golang-migrate/migrate/v4/source"
	"github.com/jmoiron/sqlx"
	"github.com/stretchr/testify/require"
)

// dbtablenames caches the names of the tables of the public schema
// (schema_migrations excluded). GetTestDBNoClear fills it, while it is still
// empty, right after running the migrations; ClearDB, DisableConstraints and
// EnableConstraints iterate over it.
var dbtablenames []string

// dbLockID is the key passed to pg_advisory_lock by GetTestDBNoInit and
// GetTestDBNoClear.
// NOTE(review): presumably an arbitrary value whose only requirement is to be
// shared by every test using the same database — confirm.
const dbLockID = 15104

// getDSN reads the test database DSN from the TEST_DB_DSN environment
// variable and fails the test through tb.Fatal when it is unset or empty.
func getDSN(tb testing.TB) string {
	tb.Helper()

	if value := os.Getenv("TEST_DB_DSN"); value != "" {
		return value
	}

	tb.Fatal("Please define a TEST_DB_DSN environment variable")

	return ""
}

// clearDB wipes the database by dropping the public schema (with CASCADE) and
// creating it again, using connection c. Any error fails the test via tb.Fatal.
func clearDB(ctx context.Context, tb testing.TB, c *sqlx.Conn) {
	tb.Helper()

	// Order matters: the schema is dropped first, then recreated empty.
	statements := [...]string{
		"DROP SCHEMA IF EXISTS public CASCADE",
		"CREATE SCHEMA public",
	}

	for _, stmt := range statements {
		_, err := c.ExecContext(ctx, stmt)
		if err != nil {
			tb.Fatal(err)
		}
	}
}

// TestDB is a sqlx.DB wrapper that releases a global lock on close.
//
// The lock is taken with pg_advisory_lock on lockConn by the GetTestDB*
// constructors; Close closes lockConn and then the embedded DB.
type TestDB struct {
	*sqlx.DB
	DSN      string          // data source name the database was opened with
	ctx      context.Context // context given to the constructor; used by ClearDB to begin its transaction
	tb       testing.TB      // receives errors and failures; can be swapped with SetTB
	lockConn *sqlx.Conn      // connection on which the advisory lock was acquired
}

// SetTB makes tb the testing.TB used by db to report errors, and hands back a
// function that reinstalls the one that was in place before the call.
func (db *TestDB) SetTB(tb testing.TB) func() {
	tb.Helper()

	previous := db.tb
	restore := func() { db.tb = previous }
	db.tb = tb

	return restore
}

// ClearDB empties the database.
//
// Inside a single transaction opened on the lock connection it:
//   - disables the triggers of every table listed in dbtablenames,
//   - deletes all rows from those tables,
//   - drops every other table of the public schema (schema_migrations
//     excluded), i.e. tables that were created after dbtablenames was filled,
//   - re-enables the triggers and commits.
//
// Any error fails the test through db.tb. In that case the transaction is
// rolled back instead of committed, so that only the original error is
// reported and no further statement is issued on a broken transaction.
func (db *TestDB) ClearDB() {
	db.tb.Helper()

	tx, err := db.lockConn.BeginTxx(db.ctx, nil)
	require.NoError(db.tb, err)

	// finished is set once Commit has been attempted. Fatal/FailNow stop the
	// function through runtime.Goexit, which still runs deferred calls, so
	// this is where a failed run gets rolled back.
	var finished bool

	defer func() {
		if finished {
			return
		}

		if err := tx.Rollback(); err != nil {
			db.tb.Log("Error rolling back:", err)
		}
	}()

	db.DisableConstraints(tx.Exec)

	for _, name := range dbtablenames {
		if _, err := tx.Exec(
			"DELETE FROM " + name,
		); err != nil {
			db.tb.Fatal(err)
		}
	}

	var alltables []string
	require.NoError(db.tb,
		tx.Select(&alltables,
			"SELECT tablename "+
				"FROM pg_catalog.pg_tables "+
				"WHERE schemaname = 'public' "+
				"AND tablename <> 'schema_migrations';",
		))

	// Keep only the tables that are not part of the known set: those are
	// leftovers to drop rather than to empty.
	alltables = slices.DeleteFunc(alltables, func(name string) bool {
		return slices.Contains(dbtablenames, name)
	})
	for _, name := range alltables {
		_, err := tx.Exec("DROP TABLE " + name)
		require.NoError(db.tb, err)
	}

	db.EnableConstraints(tx.Exec)

	err = tx.Commit()
	finished = true // committed or not, the transaction is over: no rollback

	require.NoError(db.tb, err)
}

// GetTestDBNoInit opens the test database designated by TEST_DB_DSN, reserves
// a dedicated connection and acquires the advisory lock dbLockID on it. It
// neither clears the database nor runs any migration.
//
// On any error the test is failed through tb and whatever was opened so far
// is closed. On success the caller owns the returned TestDB and must Close it.
func GetTestDBNoInit(ctx context.Context, tb testing.TB) *TestDB {
	tb.Helper()

	// locked flips to true only when everything succeeded; until then the
	// deferred cleanups below release what has been opened.
	locked := false
	dsn := getDSN(tb)

	db, err := Open(dsn, 0)
	require.NoError(tb, err)

	defer func() {
		if locked {
			return
		}

		if closeErr := db.Close(); closeErr != nil {
			tb.Log("Error closing db:", closeErr)
		}
	}()

	conn, err := db.Connx(ctx)
	require.NoError(tb, err)

	// Registered after the db cleanup, hence executed before it.
	defer func() {
		if locked {
			return
		}

		if closeErr := conn.Close(); closeErr != nil {
			tb.Log("Error closing conn:", closeErr)
		}
	}()

	_, err = conn.ExecContext(ctx, "SELECT pg_advisory_lock($1)", dbLockID)
	if err != nil {
		tb.Fatal(err)
	}

	locked = true

	return &TestDB{
		DB:       db,
		DSN:      dsn,
		ctx:      ctx,
		tb:       tb,
		lockConn: conn,
	}
}

// GetTestDB returns a test database obtained from GetTestDBNoClear and then
// emptied with ClearDB. The caller must close it within the test.
// Failures are reported through tb, which stops the test.
func GetTestDB(ctx context.Context, tb testing.TB, sourceDriver source.Driver) *TestDB {
	tb.Helper()

	testDB := GetTestDBNoClear(ctx, tb, sourceDriver)
	testDB.ClearDB()

	return testDB
}

// GetTestDBNoClear opens the test database designated by TEST_DB_DSN, reserves
// a dedicated connection and acquires the advisory lock dbLockID on it.
//
// While dbtablenames is still empty, it also recreates the public schema,
// applies the migrations provided by sourceDriver and records the resulting
// table names in dbtablenames. Existing rows are not removed otherwise; use
// GetTestDB for an emptied database.
//
// On any error the test is failed through tb and what was opened so far is
// closed. On success the caller owns the returned TestDB and must Close it.
func GetTestDBNoClear(ctx context.Context, tb testing.TB, sourceDriver source.Driver) *TestDB {
	tb.Helper()
	var success bool

	dsn := getDSN(tb)

	db, err := Open(dsn, 0)
	if err != nil {
		tb.Fatal(err)
	}

	defer func() {
		// The pool is deliberately left alone when ctx is already done.
		// NOTE(review): GetTestDBNoInit has no such guard — confirm which
		// behavior is intended.
		if !success && ctx.Err() == nil {
			if err := db.Close(); err != nil {
				tb.Log("Error closing db:", err)
			}
		}
	}()

	c, err := db.Connx(ctx)
	if err != nil {
		tb.Fatal(err)
	}

	// Registered after the db cleanup, hence executed before it.
	defer func() {
		if !success {
			if err := c.Close(); err != nil {
				tb.Log("Error closing conn:", err)
			}
		}
	}()

	if _, err := c.ExecContext(
		ctx,
		"SELECT pg_advisory_lock($1)", dbLockID,
	); err != nil {
		tb.Fatal(err)
	}

	if len(dbtablenames) == 0 {
		clearDB(ctx, tb, c)

		m, err := NewMigrate(dsn, sourceDriver)
		if err != nil {
			tb.Fatal(err)
		}

		// Close the migrate instance before reporting anything: tb.Fatal
		// stops the function, which would otherwise leave it open when Up
		// fails.
		upErr := m.Up()
		srcErr, dbErr := m.Close()

		if upErr != nil {
			tb.Fatal(upErr)
		}
		if srcErr != nil || dbErr != nil {
			tb.Fatal(srcErr, dbErr)
		}

		// Remember the tables created by the migrations for ClearDB,
		// DisableConstraints and EnableConstraints.
		if err := db.Select(&dbtablenames,
			"SELECT tablename "+
				"FROM pg_catalog.pg_tables "+
				"WHERE schemaname = 'public' "+
				"AND tablename <> 'schema_migrations';",
		); err != nil {
			tb.Fatal(err)
		}
	}

	success = true

	return &TestDB{DB: db, DSN: dsn, ctx: ctx, tb: tb, lockConn: c}
}

// DisableConstraints runs "ALTER TABLE <name> DISABLE TRIGGER ALL" through
// exec for every table listed in dbtablenames. The first error fails the test
// via db.tb.Fatal.
//
// exec is typically the Exec method of a transaction or connection, so that
// the statements run in the caller's session.
func (db *TestDB) DisableConstraints(exec func(string, ...any) (sql.Result, error)) {
	for _, table := range dbtablenames {
		stmt := "ALTER TABLE " + table + " DISABLE TRIGGER ALL"

		_, err := exec(stmt)
		if err != nil {
			db.tb.Fatal(err)
		}
	}
}

// EnableConstraints runs "ALTER TABLE <name> ENABLE TRIGGER ALL" through exec
// for every table listed in dbtablenames. The first error fails the test via
// db.tb.Fatal.
//
// exec is typically the Exec method of a transaction or connection, so that
// the statements run in the caller's session.
func (db *TestDB) EnableConstraints(exec func(string, ...any) (sql.Result, error)) {
	for _, table := range dbtablenames {
		stmt := "ALTER TABLE " + table + " ENABLE TRIGGER ALL"

		_, err := exec(stmt)
		if err != nil {
			db.tb.Fatal(err)
		}
	}
}

// Close closes the lock connection first, then the database itself. Errors
// are reported with db.tb.Error and do not prevent the remaining step.
// Each handle is set to nil once closed, so a second call does nothing.
func (db *TestDB) Close() {
	if conn := db.lockConn; conn != nil {
		err := conn.Close()
		if err != nil {
			db.tb.Error("lockConn.Close() failed", err)
		}

		db.lockConn = nil
	}

	if pool := db.DB; pool != nil {
		err := pool.Close()
		if err != nil {
			db.tb.Error("db.Close() failed", err)
		}

		db.DB = nil
	}
}