Skip to content
Snippets Groups Projects
sql.go 8.09 KiB
Newer Older
package database

import (
	"context"
	"database/sql"
	"encoding/json"
	"strings"

	"github.com/Masterminds/squirrel"
	"github.com/jmoiron/sqlx"
	"github.com/lann/builder"
	"github.com/rs/zerolog"
)

// Mapped is implemented by every struct that is mapped to a database table.
// The SQL builders of this package only talk to mapped structs through it.
type Mapped interface {
	// Table returns the name of the table the struct is stored in.
	Table() string
	// PKeyColumn returns the name of the primary-key column.
	PKeyColumn() string

	// Columns lists the mapped columns; withPKey tells whether the
	// primary-key column is part of the list.
	Columns(withPKey bool) []string
	// Values returns the values of the requested columns, in the same order.
	Values(columns ...string) []interface{}
}

// SQ is the package-level squirrel statement builder (squirrel.StatementBuilder),
// used as the starting point for generated statements.
var SQ squirrel.StatementBuilderType = squirrel.StatementBuilder

// SQLTrace logs a sql query and args as TRACE level.
func SQLTrace(log *zerolog.Logger, sql string, args []interface{}) {
	if log != nil && log.GetLevel() < 0 {
		arr := zerolog.Arr()
		for _, arg := range args {
			b, err := json.Marshal(arg)
			if err != nil {
Axel Prel's avatar
Axel Prel committed
				log.Err(err).Msg("error Marshaling arguments")
				arr.Interface(arg)
			} else {
				if len(b) > 512 {
Axel Prel's avatar
Axel Prel committed
					b, err = json.Marshal(map[string]interface{}{
						"arg first 512 bytes": b[:512],
						"len":                 len(b),
					})
Axel Prel's avatar
Axel Prel committed
					if err != nil {
						log.Err(err).Msg("error Marshaling first 512 bytes of arguments")
					}
		}
		logger := log.With().Array("args", arr).Logger()
		logger.Trace().Msg("SQL: " + sql)
	}
}

// SQLExecutor is the common subset of the *sqlx.DB and *sqlx.Tx methods used
// by the query helpers of this package. They can therefore run either directly
// on the database or inside a transaction.
type SQLExecutor interface {
	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
	GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
	SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
	QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error)
	QueryRowxContext(context.Context, string, ...interface{}) *sqlx.Row
}

func setPlaceHolderFormat(query squirrel.Sqlizer) squirrel.Sqlizer {
Axel Prel's avatar
Axel Prel committed
	sqlizer, ok := builder.Set(query, "PlaceholderFormat", squirrel.Dollar).(squirrel.Sqlizer)
	if !ok {
		panic("could not assert sqlizer from builder")
	}

	return sqlizer
func Get(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error {
	return GetContext(context.Background(), e, obj, query, log)
}

// GetContext loads an object.
func GetContext(
	ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger,
) error {
	sqlQuery, args, err := setPlaceHolderFormat(query).ToSql()
	if err != nil {
		return err
	}
	SQLTrace(log, sqlQuery, args)
Axel Prel's avatar
Axel Prel committed

	return e.GetContext(ctx, obj, sqlQuery, args...)
}

// Select loads an object list; it is SelectContext with context.Background().
func Select(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error {
	ctx := context.Background()
	return SelectContext(ctx, e, obj, query, log)
}

// SelectContext loads an object list.
func SelectContext(
	ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger,
) error {
	sqlQuery, args, err := setPlaceHolderFormat(query).ToSql()
	if err != nil {
		return err
	}
	SQLTrace(log, sqlQuery, args)
Axel Prel's avatar
Axel Prel committed

	return e.SelectContext(ctx, obj, sqlQuery, args...)
}

// Exec runs a squirrel query on the given db or tx using
// context.Background(); see ExecContext.
func Exec(e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) {
	ctx := context.Background()
	return ExecContext(ctx, e, query, log)
}

// ExecContext runs a squirrel query on the given db or tx.
func ExecContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) {
	query = setPlaceHolderFormat(query)
	sqlQuery, args, err := query.
		ToSql()
	if err != nil {
		return nil, err
	}
	SQLTrace(log, sqlQuery, args)
Axel Prel's avatar
Axel Prel committed

	return e.ExecContext(ctx, sqlQuery, args...)
}

// QueryContext runs a squirrel query on the given db or tx.
func QueryContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (*sqlx.Rows, error) {
	query = setPlaceHolderFormat(query)
	sqlQuery, args, err := query.
		ToSql()
	if err != nil {
		return nil, err
	}
	SQLTrace(log, sqlQuery, args)
Axel Prel's avatar
Axel Prel committed

	return e.QueryxContext(ctx, sqlQuery, args...)
}

// QueryRowContext runs a squirrel query on the given db or tx and returns a
// single Row.
//
// The query is rendered with $n placeholders and traced with SQLTrace before
// e.QueryRowxContext is called. A rendering error is not returned separately;
// it is stored in the returned *Row.
func QueryRowContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) *Row {
	rendered, params, buildErr := setPlaceHolderFormat(query).ToSql()
	if buildErr != nil {
		return &Row{err: buildErr}
	}

	SQLTrace(log, rendered, params)
	row := e.QueryRowxContext(ctx, rendered, params...)

	return &Row{Row: row}
}

// ValuesMap returns the values for a list of columns as a map. If a column does
// not exits, the corresponding value is set to nil.
func ValuesMap(m Mapped, columns ...string) map[string]interface{} {
	valuesList := m.Values(columns...)
	values := make(map[string]interface{})
	for i, column := range columns {
		values[column] = valuesList[i]
	}
Axel Prel's avatar
Axel Prel committed

Florent Aide's avatar
Florent Aide committed
// SQLUpdate generates a squirrel "update" statement
// for the given mapped instance (auto-selecting by its pkey).
func SQLUpdate(m Mapped, columns ...string) squirrel.UpdateBuilder {
	if len(columns) == 0 {
		columns = m.Columns(false)
	}
Florent Aide's avatar
Florent Aide committed
	q := squirrel.
		Update(m.Table()).
		SetMap(ValuesMap(m, columns...)).Where(
Florent Aide's avatar
Florent Aide committed
		squirrel.Eq{m.PKeyColumn(): m.Values(m.PKeyColumn())})
Axel Prel's avatar
Axel Prel committed

// SQLUpsert generates a squirrel "upsert" statement for m. It inserts the
// columns returned by m.Columns(true). On a primary-key conflict, it updates
// the columns returned by m.Columns(false). See SQLUpsertColumns.
func SQLUpsert(m Mapped) squirrel.InsertBuilder {
	insertCols := m.Columns(true)
	updateCols := m.Columns(false)

	return SQLUpsertColumns(m, insertCols, updateCols)
}

func SQLUpsertColumns(
	m Mapped, insertColumns []string, updateColumns []string,
) squirrel.InsertBuilder {
	updateSQL, updateArgs, err := squirrel.
		Update(m.Table()).
		SetMap(ValuesMap(m, updateColumns...)).
		ToSql()
	if err != nil {
		panic(err)
	}

	updateParts := strings.Split(updateSQL, " SET ")
	if len(updateParts) != 2 {
Axel Prel's avatar
Axel Prel committed
		panic("Could not split the UPDATE query: " + updateSQL)
	}
	suffix := "ON CONFLICT (" + m.PKeyColumn() + ") DO UPDATE SET " + updateParts[1]

	q := squirrel.
		Insert(m.Table()).
		PlaceholderFormat(squirrel.Dollar).
		Columns(insertColumns...).
		Values(m.Values(insertColumns...)...).
		Suffix(suffix, updateArgs...)
Axel Prel's avatar
Axel Prel committed

func SQLUpsertBase(table string, keyCols []string, cols []string) squirrel.InsertBuilder {
	uq := squirrel.Update(table)
	for _, col := range cols {
		uq = uq.Set(col, squirrel.Expr("EXCLUDED."+col))
	}
	updateSQL, updateArgs, err := uq.ToSql()
Christophe de Vienne's avatar
Christophe de Vienne committed
	if err != nil {
		panic(err)
	}

	updateParts := strings.Split(updateSQL, " SET ")
	if len(updateParts) != 2 {
Axel Prel's avatar
Axel Prel committed
		panic("Could not split the UPDATE query: " + updateSQL)
Christophe de Vienne's avatar
Christophe de Vienne committed
	}
	suffix := "ON CONFLICT (" + strings.Join(keyCols, ",") + ") DO UPDATE SET " + updateParts[1]

	q := squirrel.
		Insert(table).
Christophe de Vienne's avatar
Christophe de Vienne committed
		PlaceholderFormat(squirrel.Dollar).
		Columns(cols...).
Christophe de Vienne's avatar
Christophe de Vienne committed
		Suffix(suffix, updateArgs...)
Axel Prel's avatar
Axel Prel committed

// SQLUpsertNoPKey generates a squirrel "upsert" statement for m and any
// additional instances mm, using keyCols as the ON CONFLICT target.
//
// Only m.Columns(false) are written. That column list (and m's table) is
// reused for every instance in mm, so they are expected to share m's
// concrete type.
func SQLUpsertNoPKey(keyCols []string, m Mapped, mm ...Mapped) squirrel.InsertBuilder {
	cols := m.Columns(false)
	stmt := SQLUpsertBase(m.Table(), keyCols, cols)

	for _, inst := range append([]Mapped{m}, mm...) {
		stmt = stmt.Values(inst.Values(cols...)...)
	}

	return stmt
}

// SQLInsert build a Insert statement to insert one or several mapped instances
// in the database. All the instances must be of the same actual type.
func SQLInsert(instances ...Mapped) squirrel.InsertBuilder {
	allColumns := instances[0].Columns(true)
	q := SQ.Insert(instances[0].Table()).
		Columns(allColumns...)
	for _, m := range instances {
		q = q.Values(m.Values(allColumns...)...)
	}
Axel Prel's avatar
Axel Prel committed

	return q
}

// SQLInsertNoPKey build a Insert statement to insert one or several mapped instances
// in the database, but leave the pkey undefined so it gets auto-generated
// by the database. All the instances must be of the same actual type.
func SQLInsertNoPKey(instances ...Mapped) squirrel.InsertBuilder {
	allColumns := instances[0].Columns(false)
	q := SQ.Insert(instances[0].Table()).
		Columns(allColumns...)
	for _, m := range instances {
		q = q.Values(m.Values(allColumns...)...)
	}
Axel Prel's avatar
Axel Prel committed


// PrefixColumns qualifies every column name with table and returns a new
// slice of "table.column" strings. The input slice is not modified.
func PrefixColumns(table string, columns ...string) []string {
	prefixed := make([]string, len(columns))
	for i, name := range columns {
		prefixed[i] = PrefixColumn(table, name)
	}

	return prefixed
}

// PrefixColumn qualifies a column name with a table name: "table.column".
func PrefixColumn(table string, column string) string {
	return table + "." + column
}