Newer
Older
package database
import (
"context"
"database/sql"
"strings"
"github.com/Masterminds/squirrel"
"github.com/lann/builder"
"github.com/rs/zerolog"
)
// Mapped is implemented by every struct persisted in the database. It exposes
// the metadata the generic SQL builders in this package need.
type Mapped interface {
	// Table returns the name of the table the struct is stored in.
	Table() string
	// PKeyColumn returns the name of the primary key column.
	PKeyColumn() string
	// Columns returns the mapped column names; withPKey controls whether the
	// primary key column is part of the list.
	Columns(withPKey bool) []string
	// Values returns the values of the requested columns, positionally
	// matching the columns argument.
	Values(columns ...string) []interface{}
}
var (
	// SQ is squirrel's default StatementBuilder, used by the helpers in this
	// package as the starting point for new statements.
	SQ = squirrel.StatementBuilder
)
// SQLTrace logs a SQL query and its arguments at TRACE level.
//
// It does nothing when log is nil, when the logger's level is above
// zerolog.TraceLevel, or when zerolog decides not to emit the TRACE event
// (for instance because of its global level). In those cases the "args" array
// is never built.
func SQLTrace(log *zerolog.Logger, query string, args []interface{}) {
	// query was previously named "sql", which shadowed the database/sql import.
	if log == nil || log.GetLevel() > zerolog.TraceLevel {
		return
	}
	event := log.Trace()
	if !event.Enabled() {
		return
	}
	arr := zerolog.Arr()
	for _, arg := range args {
		arr.Interface(arg)
	}
	// Attach the args to the event itself instead of deriving a child logger
	// (log.With()...Logger()) on every call.
	event.Array("args", arr).Msg("SQL: " + query)
}
// SQLExecutor is the common interface of sqlx.DB and sqlx.Tx that we use in
// the functions below
type SQLExecutor interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error)
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
}
// setPlaceHolderFormat returns query with its PlaceholderFormat field set to
// squirrel.Dollar, so that ToSql renders $1, $2, ... placeholders.
//
// NOTE(review): builder.Set operates on lann/builder-based values (squirrel's
// Select/Insert/Update/Delete builders); other Sqlizer implementations are
// likely to panic here. Confirm before passing e.g. squirrel.Expr.
func setPlaceHolderFormat(query squirrel.Sqlizer) squirrel.Sqlizer {
	dollarQuery := builder.Set(query, "PlaceholderFormat", squirrel.Dollar)
	return dollarQuery.(squirrel.Sqlizer)
}
// Get loads an object into obj. It is GetContext with context.Background().
func Get(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error {
	ctx := context.Background()
	return GetContext(ctx, e, obj, query, log)
}
// GetContext loads an object into obj. The query is rendered with $n
// placeholders and traced through log, then passed to e.GetContext.
// Any rendering error is returned before touching the database.
func GetContext(
	ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger,
) error {
	stmt, params, err := setPlaceHolderFormat(query).ToSql()
	if err == nil {
		SQLTrace(log, stmt, params)
		err = e.GetContext(ctx, obj, stmt, params...)
	}
	return err
}
// Select loads a list of objects into obj. It is SelectContext with
// context.Background().
func Select(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error {
	ctx := context.Background()
	return SelectContext(ctx, e, obj, query, log)
}
// SelectContext loads a list of objects into obj. The query is rendered with
// $n placeholders and traced through log, then passed to e.SelectContext.
func SelectContext(
	ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger,
) error {
	rendered := setPlaceHolderFormat(query)
	text, params, err := rendered.ToSql()
	if err != nil {
		return err
	}
	SQLTrace(log, text, params)
	return e.SelectContext(ctx, obj, text, params...)
}
// Exec runs a squirrel statement on the given db or tx. It is ExecContext
// with context.Background().
func Exec(e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) {
	bg := context.Background()
	return ExecContext(bg, e, query, log)
}
// ExecContext runs a squirrel statement on the given db or tx. The statement
// is rendered with $n placeholders and traced through log before e.ExecContext
// is called.
func ExecContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) {
	stmt, params, err := setPlaceHolderFormat(query).ToSql()
	if err != nil {
		return nil, err
	}
	SQLTrace(log, stmt, params)
	return e.ExecContext(ctx, stmt, params...)
}
// QueryContext runs a squirrel query on the given db or tx and returns the
// resulting rows from e.QueryxContext. The query is rendered with $n
// placeholders and traced through log first.
func QueryContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (*sqlx.Rows, error) {
	stmt, params, err := setPlaceHolderFormat(query).ToSql()
	if err != nil {
		return nil, err
	}
	SQLTrace(log, stmt, params)
	return e.QueryxContext(ctx, stmt, params...)
}
// ValuesMap returns the values of the given columns as a map keyed by column
// name. If a column does not exist, the corresponding value is set to nil.
//
// Mapped implementations are expected to return one value per requested
// column (nil for unknown ones). If m.Values returns fewer items, the
// remaining columns are mapped to nil instead of panicking with an index out
// of range. When a column is listed twice, the last value wins.
func ValuesMap(m Mapped, columns ...string) map[string]interface{} {
	valuesList := m.Values(columns...)
	values := make(map[string]interface{}, len(columns))
	for i, column := range columns {
		var v interface{}
		if i < len(valuesList) {
			v = valuesList[i]
		}
		values[column] = v
	}
	return values
}
// SQLUpdate generates a squirrel UPDATE statement for m. Every non-primary-key
// column is set, and the row is selected by its primary key.
//
// NOTE(review): m.Values returns a slice, so squirrel.Eq renders the filter as
// "pkey IN (...)" with a single element rather than "pkey = ...". The two are
// equivalent for one value.
func SQLUpdate(m Mapped) squirrel.UpdateBuilder {
	table := m.Table()
	changes := ValuesMap(m, m.Columns(false)...)
	pkey := m.PKeyColumn()
	byPKey := squirrel.Eq{pkey: m.Values(pkey)}
	return squirrel.Update(table).SetMap(changes).Where(byPKey)
}
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
// SQLUpsert generates a squirrel "upsert" statement for m. It is an INSERT of
// every column (primary key included) followed by
// "ON CONFLICT (pkey) DO UPDATE SET ..." over the non-primary-key columns.
//
// The SET list is obtained by rendering a throw-away UPDATE with squirrel and
// keeping the text after " SET ". It panics if that UPDATE cannot be rendered
// or split.
//
// NOTE(review): the SET part uses squirrel's default '?' placeholders. This
// relies on the InsertBuilder's Dollar format renumbering the suffix
// placeholders after the VALUES ones.
func SQLUpsert(m Mapped) squirrel.InsertBuilder {
	setSQL, setArgs, err := squirrel.Update(m.Table()).
		SetMap(ValuesMap(m, m.Columns(false)...)).
		ToSql()
	if err != nil {
		panic(err)
	}
	pieces := strings.Split(setSQL, " SET ")
	if len(pieces) != 2 {
		panic("Could not split the UPDATE query: " + setSQL) //nolint:gosec
	}
	conflictClause := "ON CONFLICT (" + m.PKeyColumn() + ") DO UPDATE SET " + pieces[1]

	insertCols := m.Columns(true)
	table := m.Table()
	insertVals := m.Values(insertCols...)
	return squirrel.Insert(table).
		PlaceholderFormat(squirrel.Dollar).
		Columns(insertCols...).
		Values(insertVals...).
		Suffix(conflictClause, setArgs...)
}
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
// SQLUpsertNoPKey generates a squirrel "upsert" statement for m that leaves the
// primary key column out of the INSERT. Conflicts are detected on keyCols:
// "ON CONFLICT (keyCols) DO UPDATE SET ..." over the non-primary-key columns.
//
// The SET list is obtained by rendering a throw-away UPDATE with squirrel and
// keeping the text after " SET ". It panics if that UPDATE cannot be rendered
// or split.
func SQLUpsertNoPKey(keyCols []string, m Mapped) squirrel.InsertBuilder {
	setSQL, setArgs, err := squirrel.Update(m.Table()).
		SetMap(ValuesMap(m, m.Columns(false)...)).
		ToSql()
	if err != nil {
		panic(err)
	}
	pieces := strings.Split(setSQL, " SET ")
	if len(pieces) != 2 {
		panic("Could not split the UPDATE query: " + setSQL) //nolint:gosec
	}
	conflictClause := "ON CONFLICT (" + strings.Join(keyCols, ",") + ") DO UPDATE SET " + pieces[1]

	insertCols := m.Columns(false)
	table := m.Table()
	insertVals := m.Values(insertCols...)
	return squirrel.Insert(table).
		PlaceholderFormat(squirrel.Dollar).
		Columns(insertCols...).
		Values(insertVals...).
		Suffix(conflictClause, setArgs...)
}
// SQLInsert builds an INSERT statement that inserts one or several mapped
// instances, primary key included. All the instances must be of the same
// actual type: the table and the column list are taken from the first one.
//
// Called with no instances, it returns a builder with neither table nor rows
// instead of panicking on instances[0]. squirrel is expected to reject that
// builder with an error when ToSql is called (e.g. through Exec).
func SQLInsert(instances ...Mapped) squirrel.InsertBuilder {
	if len(instances) == 0 {
		return SQ.Insert("")
	}
	first := instances[0]
	allColumns := first.Columns(true)
	q := SQ.Insert(first.Table()).Columns(allColumns...)
	for _, m := range instances {
		q = q.Values(m.Values(allColumns...)...)
	}
	return q
}
// SQLInsertNoPKey builds an INSERT statement that inserts one or several
// mapped instances, but leaves the primary key out so that the database
// generates it. All the instances must be of the same actual type: the table
// and the column list are taken from the first one.
//
// Called with no instances, it returns a builder with neither table nor rows
// instead of panicking on instances[0]. squirrel is expected to reject that
// builder with an error when ToSql is called (e.g. through Exec).
func SQLInsertNoPKey(instances ...Mapped) squirrel.InsertBuilder {
	if len(instances) == 0 {
		return SQ.Insert("")
	}
	first := instances[0]
	allColumns := first.Columns(false)
	q := SQ.Insert(first.Table()).Columns(allColumns...)
	for _, m := range instances {
		q = q.Values(m.Values(allColumns...)...)
	}
	return q
}
// PrefixColumns qualifies every column name with the given table name. For
// example, PrefixColumns("users", "id", "name") returns
// []string{"users.id", "users.name"}. The input slice is left unchanged and a
// new slice is returned.
func PrefixColumns(table string, columns ...string) []string {
	prefixed := make([]string, len(columns))
	for i, name := range columns {
		prefixed[i] = PrefixColumn(table, name)
	}
	// Return the prefixed copy, not the original columns.
	return prefixed
}

// PrefixColumn qualifies a single column name with its table name, producing
// "table.column". No quoting or escaping is applied.
func PrefixColumn(table string, column string) string {
	return table + "." + column
}