Skip to content
Snippets Groups Projects
sql_helper.go 6.48 KiB
package database

import (
	"context"
	"database/sql"

	sq "github.com/Masterminds/squirrel"
	"github.com/jmoiron/sqlx"
	"github.com/rs/zerolog"
)

// NewSQLHelper returns a SQLHelper bound to ctx, the executor sqle and the
// logger log. The helper keeps references to ctx and sqle, so it must not be
// used after either of them is no longer valid.
func NewSQLHelper(ctx context.Context, sqle SQLExecutor, log zerolog.Logger) SQLHelper {
	helper := SQLHelper{
		ctx:  ctx,
		sqle: sqle,
		log:  log,
	}

	return helper
}

// SQLHelper is a short lived helper to easily get/select Mapped structures.
// It bundles a context, an executor and a logger so that each query helper
// method does not need them passed explicitly. Build it with NewSQLHelper;
// its methods use pointer receivers.
type SQLHelper struct {
	// ctx is the context handed to every query issued through the helper.
	ctx  context.Context
	// sqle runs the generated SQL (its GetContext/SelectContext/ExecContext/
	// QueryxContext methods are called directly by the *Raw helpers).
	// NOTE(review): presumably a *sqlx.DB or *sqlx.Tx — confirm against SQLExecutor.
	sqle SQLExecutor
	// log is passed by pointer to the query helpers and to SQLTrace.
	log  zerolog.Logger
}

// Get executes query and loads its result into obj by delegating to
// GetContext with the helper's context, executor and logger.
func (h *SQLHelper) Get(obj interface{}, query sq.Sqlizer) error {
	logger := &h.log

	return GetContext(h.ctx, h.sqle, obj, query, logger)
}

// GetRaw runs the raw SQL statement query with args and loads the result
// into obj through the executor's GetContext. The statement and its args are
// traced with SQLTrace before execution.
//
// The statement parameter is intentionally not called "sql" so that it does
// not shadow the imported database/sql package.
func (h *SQLHelper) GetRaw(obj interface{}, query string, args ...interface{}) error {
	SQLTrace(&h.log, query, args)

	return h.sqle.GetContext(h.ctx, obj, query, args...)
}

// GetByPKey loads obj from its table by matching the column reported by
// obj.PKeyColumn() against value.
func (h *SQLHelper) GetByPKey(obj Mapped, value interface{}) error {
	pkeyColumn := obj.PKeyColumn()

	return h.GetBy(obj, pkeyColumn, value)
}

// GetByPKeyForUpdate loads obj by matching obj.PKeyColumn() against value and
// locks the selected row with a FOR UPDATE clause
// (see https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE).
func (h *SQLHelper) GetByPKeyForUpdate(obj Mapped, value interface{}) error {
	pkeyColumn := obj.PKeyColumn()

	return h.GetByForUpdate(obj, pkeyColumn, value)
}

// GetBy loads obj from its table using an equality predicate on column.
func (h *SQLHelper) GetBy(obj Mapped, column string, value interface{}) error {
	predicate := sq.Eq{}
	predicate[column] = value

	return h.GetWhere(obj, predicate)
}

// GetByForUpdate loads obj using an equality predicate on column and locks the
// selected row with a FOR UPDATE clause
// (see https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE).
func (h *SQLHelper) GetByForUpdate(obj Mapped, column string, value interface{}) error {
	predicate := sq.Eq{}
	predicate[column] = value

	return h.GetWhereForUpdate(obj, predicate)
}

// GetWhere selects obj's columns from obj's table filtered by pred/args and
// loads the result into obj.
// ex: GetWhere(account, squirrel.Eq{models.AccountDBNameColumn: "somename"})
func (h *SQLHelper) GetWhere(obj Mapped, pred interface{}, args ...interface{}) error {
	columns := obj.Columns(true)
	table := obj.Table()

	return h.Get(obj, SQ.Select(columns...).From(table).Where(pred, args...))
}

// DeleteWhere deletes the rows of obj's table that match pred/args. Only
// obj.Table() is used; obj's field values are ignored.
// ex: DeleteWhere(account, squirrel.Eq{models.AccountDBNameColumn: "somename"})
func (h *SQLHelper) DeleteWhere(obj Mapped, pred interface{}, args ...interface{}) error {
	stmt := SQ.Delete(obj.Table()).Where(pred, args...)
	if _, err := h.Exec(stmt); err != nil {
		return err
	}

	return nil
}

// GetWhereForUpdate behaves like GetWhere but appends a FOR UPDATE clause so
// the selected row is locked
// (see https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE).
func (h *SQLHelper) GetWhereForUpdate(obj Mapped, pred interface{}, args ...interface{}) error {
	columns := obj.Columns(true)
	table := obj.Table()
	locked := SQ.Select(columns...).From(table).Where(pred, args...).Suffix("FOR UPDATE")

	return h.Get(obj, locked)
}

// Select executes query and loads every resulting row into obj (a slice
// destination) by delegating to SelectContext.
func (h *SQLHelper) Select(obj interface{}, query sq.Sqlizer) error {
	ctx, executor := h.ctx, h.sqle

	return SelectContext(ctx, executor, obj, query, &h.log)
}

// SelectRaw runs the raw SQL statement query with args and loads all
// resulting rows into obj through the executor's SelectContext. The
// statement and its args are traced with SQLTrace before execution.
//
// The statement parameter is intentionally not called "sql" so that it does
// not shadow the imported database/sql package.
func (h *SQLHelper) SelectRaw(obj interface{}, query string, args ...interface{}) error {
	SQLTrace(&h.log, query, args)

	return h.sqle.SelectContext(h.ctx, obj, query, args...)
}

// Exec runs query (used in this file for insert, update and delete builders)
// through ExecContext with the helper's context, executor and logger.
func (h *SQLHelper) Exec(query sq.Sqlizer) (sql.Result, error) {
	result, err := ExecContext(h.ctx, h.sqle, query, &h.log)

	return result, err
}

// ExecRaw runs the raw SQL statement query with args through the executor's
// ExecContext, after tracing it with SQLTrace.
//
// The statement parameter is intentionally not called "sql": that name would
// shadow the imported database/sql package whose sql.Result this method
// returns.
func (h *SQLHelper) ExecRaw(query string, args ...interface{}) (sql.Result, error) {
	SQLTrace(&h.log, query, args)

	return h.sqle.ExecContext(h.ctx, query, args...)
}

// Query runs query through QueryContext and returns the resulting rows.
// Callers are expected to Close the returned rows when done.
func (h *SQLHelper) Query(query sq.Sqlizer) (*sqlx.Rows, error) {
	rows, err := QueryContext(h.ctx, h.sqle, query, &h.log)

	return rows, err
}

// QueryRaw runs the raw SQL statement query with args through the executor's
// QueryxContext, after tracing it with SQLTrace. Callers are expected to
// Close the returned rows when done.
//
// The statement parameter is intentionally not called "sql" so that it does
// not shadow the imported database/sql package.
func (h *SQLHelper) QueryRaw(query string, args ...interface{}) (*sqlx.Rows, error) {
	SQLTrace(&h.log, query, args)

	return h.sqle.QueryxContext(h.ctx, query, args...)
}

// QueryRow runs query through QueryRowContext and returns the single-row
// handle it produces.
func (h *SQLHelper) QueryRow(query sq.Sqlizer) *Row {
	row := QueryRowContext(h.ctx, h.sqle, query, &h.log)

	return row
}

// Insert builds one insert query for all instances with SQLInsert and
// executes it once.
func (h *SQLHelper) Insert(instances ...Mapped) (sql.Result, error) {
	return h.Exec(SQLInsert(instances...))
}

// InsertNoPKey builds one insert query for all instances with SQLInsertNoPKey
// and executes it once.
// NOTE(review): presumably the primary-key column is omitted so the database
// assigns it — confirm against SQLInsertNoPKey.
func (h *SQLHelper) InsertNoPKey(instances ...Mapped) (sql.Result, error) {
	return h.Exec(SQLInsertNoPKey(instances...))
}

// Upsert executes one SQLUpsert query per instance, in order, and stops at
// the first failure. Earlier upserts are not rolled back by this method.
func (h *SQLHelper) Upsert(instances ...Mapped) error {
	for i := range instances {
		if _, err := h.Exec(SQLUpsert(instances[i])); err != nil {
			return err
		}
	}

	return nil
}

// UpsertColumns executes one SQLUpsertColumns query per instance, inserting
// insertColumns and updating updateColumns on conflict. It stops at the first
// failure without undoing earlier upserts.
func (h *SQLHelper) UpsertColumns(insertColumns, updateColumns []string, instances ...Mapped) error {
	for i := range instances {
		stmt := SQLUpsertColumns(instances[i], insertColumns, updateColumns)
		if _, err := h.Exec(stmt); err != nil {
			return err
		}
	}

	return nil
}

// Update executes one SQLUpdate query per instance, each using the
// instance's primary key as the WHERE predicate. It stops at the first
// failure without undoing earlier updates.
func (h *SQLHelper) Update(instances ...Mapped) error {
	for i := range instances {
		stmt := SQLUpdate(instances[i])
		if _, err := h.Exec(stmt); err != nil {
			return err
		}
	}

	return nil
}

// UpdateColumns updates instance, restricting the SET clause to columns, by
// executing the query built by SQLUpdate(instance, columns...).
func (h *SQLHelper) UpdateColumns(instance Mapped, columns ...string) error {
	_, err := h.Exec(SQLUpdate(instance, columns...))

	return err
}

// UpsertNoPKey executes one SQLUpsertNoPKey query per instance, using keyCols
// as the conflict key. It stops at the first failure without undoing earlier
// upserts.
func (h *SQLHelper) UpsertNoPKey(keyCols []string, instances ...Mapped) error {
	for i := range instances {
		stmt := SQLUpsertNoPKey(keyCols, instances[i])
		if _, err := h.Exec(stmt); err != nil {
			return err
		}
	}

	return nil
}

// SyncRelationStrings makes the (colFrom, colTo) pairs stored in table for
// colFromValue match colToValues exactly:
//   - one row (colFromValue, v) is inserted for every v in colToValues, with
//     ON CONFLICT DO NOTHING so already-present pairs are skipped;
//   - every remaining row with colFrom = colFromValue whose colTo is not in
//     colToValues is deleted.
//
// When colToValues is empty, no insert is issued and every row with
// colFrom = colFromValue is deleted.
//
// table, colFrom and colTo are spliced into the SQL as identifiers without
// escaping, so they must come from trusted code, never from user input.
//
// NOTE(review): skipping existing pairs relies on a unique constraint over
// (colFrom, colTo), and the insert and delete are two separate statements.
// They are only atomic if the helper's executor is a transaction — confirm at
// the call site.
func (h *SQLHelper) SyncRelationStrings(
	table string, colFrom string, colTo string,
	colFromValue string, colToValues []string,
) error {
	filter := sq.And{sq.Eq{colFrom: colFromValue}}

	if len(colToValues) != 0 {
		// Upsert all wanted pairs in a single statement. Use the package-level
		// builder SQ, like the other helpers in this file, so the statement
		// gets the same placeholder format.
		insert := SQ.Insert(table).Columns(colFrom, colTo)
		for _, value := range colToValues {
			insert = insert.Values(colFromValue, value)
		}
		insert = insert.Suffix("ON CONFLICT DO NOTHING")
		if _, err := h.Exec(insert); err != nil {
			return err
		}

		// A slice value makes squirrel emit "colTo NOT IN (...)". That is
		// equivalent to chaining one "colTo <> v" per value. It is only added
		// for a non-empty list, so the empty case keeps the plain colFrom
		// filter.
		filter = append(filter, sq.NotEq{colTo: colToValues})
	}

	// Delete the pairs that are no longer wanted.
	if _, err := h.Exec(SQ.Delete(table).Where(filter)); err != nil {
		return err
	}

	return nil
}