Skip to content
Snippets Groups Projects
sql.go 6.66 KiB
Newer Older
package database

import (
	"context"
	"database/sql"
	"strings"

	"github.com/Masterminds/squirrel"
	"github.com/jmoiron/sqlx"
	"github.com/lann/builder"
	"github.com/rs/zerolog"
)

// Mapped is implemented by every struct that has a database mapping. It
// exposes what the SQL builders in this file need to generate statements for
// an instance: its table, its primary-key column and its column values.
type Mapped interface {
	// Table returns the name of the table the instance is stored in.
	Table() string
	// PKeyColumn returns the name of the primary-key column.
	PKeyColumn() string
	// Columns returns the mapped column names; withPKey selects whether the
	// primary-key column is part of the returned list.
	Columns(withPKey bool) []string
	// Values is expected to return one value per requested column, in the
	// same order as columns (ValuesMap relies on that alignment).
	Values(columns ...string) []interface{}
}

// SQ is the package-wide squirrel statement builder; SQLInsert and
// SQLInsertNoPKey start their statements from it.
var SQ squirrel.StatementBuilderType = squirrel.StatementBuilder

// SQLTrace logs a SQL statement together with its bound arguments at TRACE
// level.
//
// It is a no-op when log is nil or when the logger's own level is >= 0 (i.e.
// more severe than zerolog's TRACE level), so the argument array is only
// built when the logger itself is configured for TRACE.
func SQLTrace(log *zerolog.Logger, sql string, args []interface{}) {
	if log == nil || log.GetLevel() >= 0 {
		return
	}
	values := zerolog.Arr()
	for i := range args {
		values.Interface(args[i])
	}
	traced := log.With().Array("args", values).Logger()
	traced.Trace().Msg("SQL: " + sql)
}

// SQLExecutor is the subset of methods shared by sqlx.DB and sqlx.Tx that the
// query helpers in this file rely on. Accepting it lets the same helper run
// either directly on the database or inside a transaction.
type SQLExecutor interface {
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
	SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
	QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error)
}

// setPlaceHolderFormat returns a version of query that renders squirrel's
// Dollar ("$1, $2, ...") placeholders instead of "?".
//
// The four squirrel statement builders carry a PlaceholderFormat setting,
// which is overridden with squirrel.Dollar through builder.Set. Any other
// Sqlizer (squirrel.Expr, squirrel.Eq, a hand-written type, ...) has no such
// setting; builder.Set expects a lann/builder-based value and panics on
// anything else. Those queries are therefore wrapped, and their "?"
// placeholders are rewritten after ToSql.
func setPlaceHolderFormat(query squirrel.Sqlizer) squirrel.Sqlizer {
	switch query.(type) {
	case squirrel.SelectBuilder, squirrel.InsertBuilder, squirrel.UpdateBuilder, squirrel.DeleteBuilder:
		return builder.Set(query, "PlaceholderFormat", squirrel.Dollar).(squirrel.Sqlizer)
	default:
		return dollarSqlizer{query}
	}
}

// dollarSqlizer wraps a Sqlizer that has no PlaceholderFormat setting and
// converts the "?" placeholders of its output to squirrel.Dollar format.
type dollarSqlizer struct {
	squirrel.Sqlizer
}

// ToSql renders the wrapped Sqlizer and rewrites its placeholders. Errors from
// either step are returned with an empty query and nil arguments.
func (d dollarSqlizer) ToSql() (string, []interface{}, error) {
	query, args, err := d.Sqlizer.ToSql()
	if err != nil {
		return "", nil, err
	}
	query, err = squirrel.Dollar.ReplacePlaceholders(query)
	if err != nil {
		return "", nil, err
	}
	return query, args, nil
}

// Get loads a single object into obj by running query on e with
// context.Background(). It is GetContext without a caller-supplied context.
func Get(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error {
	ctx := context.Background()
	return GetContext(ctx, e, obj, query, log)
}

// GetContext loads a single object into obj.
//
// The query is rendered with squirrel.Dollar placeholders and traced through
// log (which may be nil). It is then handed to e.GetContext. An error while
// building the SQL is returned before anything is sent to the database.
func GetContext(
	ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger,
) error {
	dollarQuery := setPlaceHolderFormat(query)
	statement, params, err := dollarQuery.ToSql()
	if err != nil {
		return err
	}
	SQLTrace(log, statement, params)
	return e.GetContext(ctx, obj, statement, params...)
}

// Select loads an object list into obj by running query on e with
// context.Background(). It is SelectContext without a caller-supplied context.
func Select(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error {
	ctx := context.Background()
	return SelectContext(ctx, e, obj, query, log)
}

// SelectContext loads an object list into obj.
//
// The query is rendered with squirrel.Dollar placeholders and traced through
// log (which may be nil). It is then handed to e.SelectContext. An error while
// building the SQL is returned before anything is sent to the database.
func SelectContext(
	ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger,
) error {
	dollarQuery := setPlaceHolderFormat(query)
	statement, params, err := dollarQuery.ToSql()
	if err != nil {
		return err
	}
	SQLTrace(log, statement, params)
	return e.SelectContext(ctx, obj, statement, params...)
}

// Exec executes query on e (a database handle or a transaction) with
// context.Background(). It is ExecContext without a caller-supplied context.
func Exec(e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) {
	ctx := context.Background()
	return ExecContext(ctx, e, query, log)
}

// ExecContext executes query on e, which may be a database handle or a
// transaction.
//
// The query is rendered with squirrel.Dollar placeholders and traced through
// log (which may be nil) before execution. An error while building the SQL is
// returned with a nil result and nothing is executed.
func ExecContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) {
	statement, params, err := setPlaceHolderFormat(query).ToSql()
	if err != nil {
		return nil, err
	}
	SQLTrace(log, statement, params)
	return e.ExecContext(ctx, statement, params...)
}

// QueryContext runs query on e, which may be a database handle or a
// transaction, and returns the resulting rows through e.QueryxContext.
//
// The query is rendered with squirrel.Dollar placeholders and traced through
// log (which may be nil). An error while building the SQL is returned with nil
// rows. NOTE(review): as with sqlx, the caller presumably owns the returned
// rows and must Close them.
func QueryContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (*sqlx.Rows, error) {
	statement, params, err := setPlaceHolderFormat(query).ToSql()
	if err != nil {
		return nil, err
	}
	SQLTrace(log, statement, params)
	return e.QueryxContext(ctx, statement, params...)
}

// ValuesMap returns the values of the given columns of m as a map keyed by
// column name.
//
// A column for which m.Values yields no value maps to nil. This covers both a
// nil entry and a result slice shorter than columns; previously a short result
// made the function panic with an index out of range. When a column name is
// repeated, the last occurrence wins.
func ValuesMap(m Mapped, columns ...string) map[string]interface{} {
	valuesList := m.Values(columns...)
	values := make(map[string]interface{}, len(columns))
	for i, column := range columns {
		var value interface{}
		if i < len(valuesList) {
			value = valuesList[i]
		}
		values[column] = value
	}
	return values
}

Florent Aide's avatar
Florent Aide committed
// SQLUpdate builds an UPDATE statement for m. It assigns every
// non-primary-key column (m.Columns(false)) and restricts the update, through
// a squirrel.Eq on m.PKeyColumn(), to the row holding m's primary key. No
// placeholder format is set here; the Exec helpers apply it.
//
// NOTE(review): the Eq value is the whole slice returned by m.Values, not its
// single element, so squirrel presumably renders it as "pkey IN (...)".
// Confirm this is intended before changing it.
func SQLUpdate(m Mapped) squirrel.UpdateBuilder {
	table := m.Table()
	assignments := ValuesMap(m, m.Columns(false)...)
	pkey := m.PKeyColumn()
	byPKey := squirrel.Eq{pkey: m.Values(pkey)}
	return squirrel.Update(table).SetMap(assignments).Where(byPKey)
}

// SQLUpsert builds an "upsert" for m. The statement inserts every column,
// primary key included. On a conflict on the primary key it updates every
// non-primary-key column instead:
//
//	INSERT INTO <table> (...) VALUES (...) ON CONFLICT (<pkey>) DO UPDATE SET ...
//
// The SET clause is cut out of a squirrel UPDATE statement, and its arguments
// are passed to the INSERT's suffix. The INSERT uses squirrel.Dollar
// placeholders. It panics if the UPDATE cannot be rendered or if its SQL does
// not contain exactly one " SET ".
func SQLUpsert(m Mapped) squirrel.InsertBuilder {
	table := m.Table()
	assignments := ValuesMap(m, m.Columns(false)...)
	updateSQL, updateArgs, err := squirrel.Update(table).SetMap(assignments).ToSql()
	if err != nil {
		panic(err)
	}

	const setKeyword = " SET "
	if strings.Count(updateSQL, setKeyword) != 1 {
		panic("Could not split the UPDATE query: " + updateSQL) //nolint:gosec
	}
	setClause := updateSQL[strings.Index(updateSQL, setKeyword)+len(setKeyword):]
	conflict := "ON CONFLICT (" + m.PKeyColumn() + ") DO UPDATE SET " + setClause

	insertCols := m.Columns(true)
	return squirrel.Insert(table).
		PlaceholderFormat(squirrel.Dollar).
		Columns(insertCols...).
		Values(m.Values(insertCols...)...).
		Suffix(conflict, updateArgs...)
}

Christophe de Vienne's avatar
Christophe de Vienne committed
// SQLUpsertNoPKey builds an "upsert" for m that leaves the primary key out.
// The statement inserts the non-primary-key columns (m.Columns(false)). On a
// conflict over keyCols it updates those same columns instead:
//
//	INSERT INTO <table> (...) VALUES (...) ON CONFLICT (<keyCols>) DO UPDATE SET ...
//
// keyCols are joined with "," as given. The SET clause is cut out of a
// squirrel UPDATE statement, and its arguments are passed to the INSERT's
// suffix. The INSERT uses squirrel.Dollar placeholders. It panics if the
// UPDATE cannot be rendered or if its SQL does not contain exactly one " SET ".
func SQLUpsertNoPKey(keyCols []string, m Mapped) squirrel.InsertBuilder {
	table := m.Table()
	assignments := ValuesMap(m, m.Columns(false)...)
	updateSQL, updateArgs, err := squirrel.Update(table).SetMap(assignments).ToSql()
	if err != nil {
		panic(err)
	}

	const setKeyword = " SET "
	if strings.Count(updateSQL, setKeyword) != 1 {
		panic("Could not split the UPDATE query: " + updateSQL) //nolint:gosec
	}
	setClause := updateSQL[strings.Index(updateSQL, setKeyword)+len(setKeyword):]
	conflict := "ON CONFLICT (" + strings.Join(keyCols, ",") + ") DO UPDATE SET " + setClause

	insertCols := m.Columns(false)
	return squirrel.Insert(table).
		PlaceholderFormat(squirrel.Dollar).
		Columns(insertCols...).
		Values(m.Values(insertCols...)...).
		Suffix(conflict, updateArgs...)
}

// SQLInsert builds a single INSERT statement that inserts one or several
// mapped instances, primary key included.
//
// All instances must be of the same concrete type: the table and the column
// list are taken from the first one.
//
// With no instance it returns an INSERT without a table instead of panicking
// on instances[0]. squirrel's ToSql rejects such a statement, so Exec reports
// an error rather than crashing.
func SQLInsert(instances ...Mapped) squirrel.InsertBuilder {
	if len(instances) == 0 {
		return SQ.Insert("")
	}
	first := instances[0]
	columns := first.Columns(true)
	q := SQ.Insert(first.Table()).Columns(columns...)
	for _, m := range instances {
		q = q.Values(m.Values(columns...)...)
	}
	return q
}

// SQLInsertNoPKey builds a single INSERT statement that inserts one or several
// mapped instances. It leaves the primary-key column out so the database can
// generate it.
//
// All instances must be of the same concrete type: the table and the column
// list (m.Columns(false)) are taken from the first one.
//
// With no instance it returns an INSERT without a table instead of panicking
// on instances[0]. squirrel's ToSql rejects such a statement, so Exec reports
// an error rather than crashing.
func SQLInsertNoPKey(instances ...Mapped) squirrel.InsertBuilder {
	if len(instances) == 0 {
		return SQ.Insert("")
	}
	first := instances[0]
	columns := first.Columns(false)
	q := SQ.Insert(first.Table()).Columns(columns...)
	for _, m := range instances {
		q = q.Values(m.Values(columns...)...)
	}
	return q
}

// PrefixColumns qualifies every column name with table using PrefixColumn
// ("table.column"). It returns the results in a new slice of the same length
// and order, and leaves columns untouched.
//
// Previously the prefixed slice was built but the unmodified input was
// returned.
func PrefixColumns(table string, columns ...string) []string {
	prefixed := make([]string, len(columns))
	for i, name := range columns {
		prefixed[i] = PrefixColumn(table, name)
	}
	return prefixed
}

// PrefixColumn qualifies column with table, producing "table.column".
func PrefixColumn(table string, column string) string {
	return strings.Join([]string{table, column}, ".")
}