Newer
Older
package database
import (
	"context"
	"database/sql"
	"encoding/json"
	"strings"

	"github.com/Masterminds/squirrel"
	"github.com/lann/builder"
	"github.com/rs/zerolog"
)
// Mapped is the common interface of all structs that are mapped in the database.
//
// An implementation tells the generic statement builders of this package
// which table it lives in, which column is its primary key, and how to read
// its column names and values.
type Mapped interface {
	// Table returns the name of the table the type is stored in.
	Table() string
	// Columns returns the mapped column names; withPKey selects whether the
	// primary key column is part of the list.
	Columns(withPKey bool) []string
	// PKeyColumn returns the name of the primary key column.
	PKeyColumn() string
	// Values returns the values of the requested columns, positionally
	// matching the columns argument.
	Values(columns ...string) []interface{}
}
// SQ is the squirrel StatementBuilder this package starts statements from
// (SQLInsert and SQLInsertNoPKey build their INSERTs with it).
var SQ squirrel.StatementBuilderType = squirrel.StatementBuilder
// SQLTrace logs a sql query and args as TRACE level.
func SQLTrace(log *zerolog.Logger, sql string, args []interface{}) {
if log != nil && log.GetLevel() < 0 {
arr := zerolog.Arr()
for _, arg := range args {
b, err := json.Marshal(arg)
if err != nil {
arr.Interface(arg)
} else {
if len(b) > 512 {
"arg first 512 bytes": b[:512],
"len": len(b),
})
if err != nil {
log.Err(err).Msg("error Marshaling first 512 bytes of arguments")
}
}
arr.RawJSON(b)
}
}
logger := log.With().Array("args", arr).Logger()
logger.Trace().Msg("SQL: " + sql)
}
}
// SQLExecutor is the subset of methods shared by sqlx.DB and sqlx.Tx that the
// helpers of this package need, so they can run on either a database handle
// or a transaction.
type SQLExecutor interface {
	// GetContext is what GetContext (and Get) delegate to.
	GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
	// SelectContext is what SelectContext (and Select) delegate to.
	SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
	// ExecContext is what ExecContext (and Exec) delegate to.
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	// QueryxContext is what QueryContext delegates to.
	QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error)
	// QueryRowxContext is what QueryRowContext delegates to.
	QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row
}
func setPlaceHolderFormat(query squirrel.Sqlizer) squirrel.Sqlizer {
sqlizer, ok := builder.Set(query, "PlaceholderFormat", squirrel.Dollar).(squirrel.Sqlizer)
if !ok {
panic("could not assert sqlizer from builder")
}
return sqlizer
// Get loads a single object into obj, using context.Background().
// It is a shorthand for GetContext.
func Get(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error {
	ctx := context.Background()
	return GetContext(ctx, e, obj, query, log)
}
// GetContext loads a single object into obj.
//
// query is rendered with $n placeholders, traced through SQLTrace, and then
// run with e.GetContext. A rendering error is returned as-is, before anything
// is sent to the database.
func GetContext(
	ctx context.Context,
	e SQLExecutor,
	obj interface{},
	query squirrel.Sqlizer,
	log *zerolog.Logger,
) error {
	dollarQuery := setPlaceHolderFormat(query)
	statement, params, err := dollarQuery.ToSql()
	if err != nil {
		return err
	}
	SQLTrace(log, statement, params)
	return e.GetContext(ctx, obj, statement, params...)
}
// Select loads an object list into obj, using context.Background().
// It is a shorthand for SelectContext.
func Select(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error {
	ctx := context.Background()
	return SelectContext(ctx, e, obj, query, log)
}
// SelectContext loads an object list into obj.
//
// query is rendered with $n placeholders, traced through SQLTrace, and then
// run with e.SelectContext. A rendering error is returned as-is, before
// anything is sent to the database.
func SelectContext(
	ctx context.Context,
	e SQLExecutor,
	obj interface{},
	query squirrel.Sqlizer,
	log *zerolog.Logger,
) error {
	dollarQuery := setPlaceHolderFormat(query)
	statement, params, err := dollarQuery.ToSql()
	if err != nil {
		return err
	}
	SQLTrace(log, statement, params)
	return e.SelectContext(ctx, obj, statement, params...)
}
// Exec runs a squirrel statement on the given db or tx, using
// context.Background(). It is a shorthand for ExecContext.
func Exec(e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) {
	ctx := context.Background()
	return ExecContext(ctx, e, query, log)
}
// ExecContext runs a squirrel statement on the given db or tx.
//
// The statement is rendered with $n placeholders and traced through SQLTrace
// before being handed to e.ExecContext. If rendering fails, nothing is
// executed and a nil result is returned with the error.
func ExecContext(
	ctx context.Context,
	e SQLExecutor,
	query squirrel.Sqlizer,
	log *zerolog.Logger,
) (sql.Result, error) {
	statement, params, err := setPlaceHolderFormat(query).ToSql()
	if err != nil {
		return nil, err
	}
	SQLTrace(log, statement, params)
	return e.ExecContext(ctx, statement, params...)
}
// QueryContext runs a squirrel query on the given db or tx and returns the
// resulting rows.
//
// The query is rendered with $n placeholders and traced through SQLTrace
// before being handed to e.QueryxContext. If rendering fails, nothing is run
// and nil rows are returned with the error.
func QueryContext(
	ctx context.Context,
	e SQLExecutor,
	query squirrel.Sqlizer,
	log *zerolog.Logger,
) (*sqlx.Rows, error) {
	statement, params, err := setPlaceHolderFormat(query).ToSql()
	if err != nil {
		return nil, err
	}
	SQLTrace(log, statement, params)
	return e.QueryxContext(ctx, statement, params...)
}
// QueryRowContext runs a squirrel query on the given db or tx and returns a
// single Row.
//
// The query is rendered with $n placeholders and traced through SQLTrace. If
// rendering fails, no query is run and the returned Row carries the rendering
// error (in its err field) instead of a database row.
func QueryRowContext(
	ctx context.Context,
	e SQLExecutor,
	query squirrel.Sqlizer,
	log *zerolog.Logger,
) *Row {
	statement, params, err := setPlaceHolderFormat(query).ToSql()
	if err != nil {
		return &Row{err: err}
	}
	SQLTrace(log, statement, params)
	row := e.QueryRowxContext(ctx, statement, params...)
	return &Row{Row: row}
}
// ValuesMap returns the values for a list of columns as a map. If a column does
// not exits, the corresponding value is set to nil.
func ValuesMap(m Mapped, columns ...string) map[string]interface{} {
valuesList := m.Values(columns...)
values := make(map[string]interface{})
for i, column := range columns {
values[column] = valuesList[i]
}
// SQLUpdate generates a squirrel "update" statement
// for the given mapped instance (auto-selecting by its pkey).
func SQLUpdate(m Mapped, columns ...string) squirrel.UpdateBuilder {
if len(columns) == 0 {
columns = m.Columns(false)
}
squirrel.Eq{m.PKeyColumn(): m.Values(m.PKeyColumn())})
// SQLUpsert generates a squirrel "upsert" statement for m. It delegates to
// SQLUpsertColumns, inserting every column including the primary key
// (m.Columns(true)) and, on a primary key conflict, updating the non-pkey
// columns (m.Columns(false)).
func SQLUpsert(m Mapped) squirrel.InsertBuilder {
	insertColumns := m.Columns(true)
	updateColumns := m.Columns(false)
	return SQLUpsertColumns(m, insertColumns, updateColumns)
}
// SQLUpsertColumns generates a squirrel "upsert" statement for m: an INSERT of
// insertColumns followed by "ON CONFLICT (<pkey>) DO UPDATE SET ...", which
// assigns m's current values to updateColumns.
//
// The SET clause is obtained by rendering a squirrel UPDATE and keeping what
// follows its first " SET ". Its arguments are passed as suffix arguments, so
// they come after the inserted values, and the whole statement uses $n
// placeholders. It panics if the UPDATE cannot be rendered (presumably e.g.
// when updateColumns is empty) or has no SET clause; both are programming
// errors, since the return type leaves no room for an error.
func SQLUpsertColumns(
	m Mapped, insertColumns []string, updateColumns []string,
) squirrel.InsertBuilder {
	updateSQL, updateArgs, err := squirrel.
		Update(m.Table()).
		SetMap(ValuesMap(m, updateColumns...)).
		ToSql()
	if err != nil {
		panic(err)
	}
	// SplitN only cuts at the first " SET ", so the keyword cannot be
	// confused with a later occurrence inside the assignments.
	updateParts := strings.SplitN(updateSQL, " SET ", 2)
	if len(updateParts) != 2 {
		panic("database: no SET clause in generated update statement: " + updateSQL)
	}
	suffix := "ON CONFLICT (" + m.PKeyColumn() + ") DO UPDATE SET " + updateParts[1]
	return squirrel.
		Insert(m.Table()).
		PlaceholderFormat(squirrel.Dollar).
		Columns(insertColumns...).
		Values(m.Values(insertColumns...)...).
		Suffix(suffix, updateArgs...)
}
// SQLUpsertBase generates the skeleton of a squirrel "upsert" statement on
// table: an INSERT into cols followed by
// "ON CONFLICT (<keyCols>) DO UPDATE SET col = EXCLUDED.col" for every column
// in cols. No rows are added; callers append them with Values, as
// SQLUpsertNoPKey does. The statement uses $n placeholders.
//
// It panics if the UPDATE used to render the SET clause cannot be rendered
// (presumably e.g. when cols is empty) or has no SET clause.
func SQLUpsertBase(table string, keyCols []string, cols []string) squirrel.InsertBuilder {
	uq := squirrel.Update(table)
	for _, col := range cols {
		// EXCLUDED.<col> names the value proposed for insertion, so the
		// conflicting row takes the incoming value without extra arguments.
		uq = uq.Set(col, squirrel.Expr("EXCLUDED."+col))
	}
	updateSQL, updateArgs, err := uq.ToSql()
	if err != nil {
		panic(err)
	}
	// SplitN only cuts at the first " SET ".
	updateParts := strings.SplitN(updateSQL, " SET ", 2)
	if len(updateParts) != 2 {
		panic("database: no SET clause in generated update statement: " + updateSQL)
	}
	suffix := "ON CONFLICT (" + strings.Join(keyCols, ",") + ") DO UPDATE SET " + updateParts[1]
	return squirrel.
		Insert(table).
		PlaceholderFormat(squirrel.Dollar).
		Columns(cols...).
		Suffix(suffix, updateArgs...)
}
// SQLUpsertNoPKey generates a squirrel "upsert" statement for one or more
// mapped instances, leaving the primary key out of the inserted columns.
// Conflicts are detected on keyCols, and the conflicting row then takes the
// incoming value of every inserted column. All instances are encoded with the
// columns of m, so they should share m's type.
func SQLUpsertNoPKey(keyCols []string, m Mapped, mm ...Mapped) squirrel.InsertBuilder {
	columns := m.Columns(false)
	rows := append([]Mapped{m}, mm...)
	q := SQLUpsertBase(m.Table(), keyCols, columns)
	for _, row := range rows {
		q = q.Values(row.Values(columns...)...)
	}
	return q
}
// SQLInsert builds an INSERT statement that stores one or several mapped
// instances in the database, primary key included. The columns are taken from
// the first instance, so all the instances must be of the same actual type.
// It panics when called without any instance.
func SQLInsert(instances ...Mapped) squirrel.InsertBuilder {
	first := instances[0]
	columns := first.Columns(true)
	q := SQ.Insert(first.Table()).Columns(columns...)
	for i := range instances {
		q = q.Values(instances[i].Values(columns...)...)
	}
	return q
}
// SQLInsertNoPKey build a Insert statement to insert one or several mapped instances
// in the database, but leave the pkey undefined so it gets auto-generated
// by the database. All the instances must be of the same actual type.
func SQLInsertNoPKey(instances ...Mapped) squirrel.InsertBuilder {
allColumns := instances[0].Columns(false)
q := SQ.Insert(instances[0].Table()).
Columns(allColumns...)
for _, m := range instances {
q = q.Values(m.Values(allColumns...)...)
}
// PrefixColumns qualifies every column name with the given table name, e.g.
// PrefixColumns("user", "id", "name") returns []string{"user.id", "user.name"}.
// A new slice is returned; columns itself is left untouched.
func PrefixColumns(table string, columns ...string) []string {
	prefixed := make([]string, len(columns))
	for i, name := range columns {
		prefixed[i] = PrefixColumn(table, name)
	}
	return prefixed
}

// PrefixColumn qualifies a column name with a table name, as "table.column".
func PrefixColumn(table string, column string) string {
	return table + "." + column
}