Newer
Older
package database
import (
"context"
"database/sql"
"strings"
"github.com/Masterminds/squirrel"
"github.com/lann/builder"
"github.com/rs/zerolog"
)
// Mapped is implemented by every struct that is stored in a database table,
// and is what the SQL-building helpers of this package operate on.
type Mapped interface {
	// Table returns the name of the table the struct is stored in.
	Table() string
	// Columns returns the mapped column names; the primary key column is
	// part of the result only when withPKey is true.
	Columns(withPKey bool) []string
	// PKeyColumn returns the name of the primary key column.
	PKeyColumn() string
	// Values returns the values of the requested columns, position by
	// position.
	Values(columns ...string) []interface{}
}
// SQ is the package-level squirrel statement builder used as the starting
// point for queries built here. It keeps squirrel's default placeholder
// format; the Get/Select/Exec helpers switch queries to "$n" placeholders.
var SQ squirrel.StatementBuilderType = squirrel.StatementBuilder
// SQLTrace logs a SQL query and its arguments at TRACE level.
//
// It returns immediately when log is nil or when the logger's own level is
// above TRACE, so the argument array is only built when it can be emitted.
func SQLTrace(log *zerolog.Logger, query string, args []interface{}) {
	// zerolog.TraceLevel sits just below zerolog.DebugLevel (0).
	if log == nil || log.GetLevel() >= zerolog.DebugLevel {
		return
	}
	arr := zerolog.Arr()
	for _, arg := range args {
		arr.Interface(arg)
	}
	// Attach the arguments to this event directly instead of deriving a
	// child logger with With(), which copies the logger context per query.
	log.Trace().Array("args", arr).Msg("SQL: " + query)
}
// SQLExecutor is the subset of methods shared by sqlx.DB and sqlx.Tx that
// the helpers below use; accepting it lets those helpers run against either
// a database handle or a transaction.
type SQLExecutor interface {
// ExecContext executes a query and returns its sql.Result.
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
// GetContext loads one object into dest (used by GetContext below).
GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
// SelectContext loads an object list into dest (used by SelectContext below).
SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
// QueryxContext runs a query and returns its rows as *sqlx.Rows.
QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error)
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
}
// setPlaceHolderFormat returns a version of query that renders its
// placeholders in "$1, $2, ..." (squirrel.Dollar) style.
//
// squirrel's statement builders get their PlaceholderFormat field replaced
// directly. Any other Sqlizer (squirrel.Expr, squirrel.And, custom types)
// is not backed by builder.Builder, so builder.Set would panic on it; such
// values are wrapped so that their "?" placeholders are rewritten after
// ToSql instead.
func setPlaceHolderFormat(query squirrel.Sqlizer) squirrel.Sqlizer {
	switch query.(type) {
	case squirrel.SelectBuilder, squirrel.InsertBuilder, squirrel.UpdateBuilder, squirrel.DeleteBuilder:
		return builder.Set(query, "PlaceholderFormat", squirrel.Dollar).(squirrel.Sqlizer)
	default:
		return dollarSqlizer{inner: query}
	}
}

// dollarSqlizer wraps a Sqlizer and converts its "?" placeholders to "$n".
type dollarSqlizer struct {
	inner squirrel.Sqlizer
}

// ToSql implements squirrel.Sqlizer.
func (d dollarSqlizer) ToSql() (string, []interface{}, error) {
	sqlStr, args, err := d.inner.ToSql()
	if err != nil {
		return "", nil, err
	}
	sqlStr, err = squirrel.Dollar.ReplacePlaceholders(sqlStr)
	if err != nil {
		return "", nil, err
	}
	return sqlStr, args, nil
}
// Get loads one object into obj by running query through e. It is the
// context-free form of GetContext and uses context.Background().
func Get(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error {
	ctx := context.Background()
	return GetContext(ctx, e, obj, query, log)
}
// GetContext loads one object into obj.
//
// The query is rendered with "$n" placeholders and traced through log, then
// handed to e.GetContext. A query build error is returned without calling e.
func GetContext(
	ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger,
) error {
	stmt, params, err := setPlaceHolderFormat(query).ToSql()
	if err == nil {
		SQLTrace(log, stmt, params)
		err = e.GetContext(ctx, obj, stmt, params...)
	}
	return err
}
// Select loads an object list into obj by running query through e. It is
// the context-free form of SelectContext and uses context.Background().
func Select(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error {
	ctx := context.Background()
	return SelectContext(ctx, e, obj, query, log)
}
// SelectContext loads an object list into obj.
//
// The query is rendered with "$n" placeholders and traced through log, then
// handed to e.SelectContext. A query build error is returned without calling e.
func SelectContext(
	ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger,
) error {
	dollarQuery := setPlaceHolderFormat(query)
	text, params, buildErr := dollarQuery.ToSql()
	if buildErr != nil {
		return buildErr
	}
	SQLTrace(log, text, params)
	return e.SelectContext(ctx, obj, text, params...)
}
// Exec runs a squirrel query on the given db or tx with no deadline or
// cancellation, by delegating to ExecContext with context.Background().
func Exec(e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) {
	ctx := context.Background()
	return ExecContext(ctx, e, query, log)
}
// ExecContext runs a squirrel query on the given db or tx.
//
// The query is rendered with "$n" placeholders and traced through log before
// e.ExecContext is called. If the query cannot be built, the build error is
// returned together with a nil sql.Result.
func ExecContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) {
	stmt, params, buildErr := setPlaceHolderFormat(query).ToSql()
	if buildErr != nil {
		return nil, buildErr
	}
	SQLTrace(log, stmt, params)
	return e.ExecContext(ctx, stmt, params...)
}
// QueryContext runs a squirrel query on the given db or tx and returns the
// resulting rows.
//
// The query is rendered with "$n" placeholders and traced through log before
// e.QueryxContext is called. If the query cannot be built, the build error is
// returned together with nil rows.
func QueryContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (*sqlx.Rows, error) {
	stmt, params, buildErr := setPlaceHolderFormat(query).ToSql()
	if buildErr != nil {
		return nil, buildErr
	}
	SQLTrace(log, stmt, params)
	return e.QueryxContext(ctx, stmt, params...)
}
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// ValuesMap returns the values for a list of columns as a map keyed by
// column name, using m.Values(columns...) matched position by position.
//
// If a column does not exist, the corresponding value is set to nil
// (implementations of Values are expected to return nil for it). If Values
// returns fewer entries than columns were requested, the remaining columns
// are also mapped to nil instead of panicking with an index out of range.
func ValuesMap(m Mapped, columns ...string) map[string]interface{} {
	valuesList := m.Values(columns...)
	values := make(map[string]interface{}, len(columns))
	for i, column := range columns {
		var v interface{}
		if i < len(valuesList) {
			v = valuesList[i]
		}
		values[column] = v
	}
	return values
}
// SQLUpsert generates a squirrel "upsert" statement for m:
// INSERT ... ON CONFLICT (<pkey>) DO UPDATE SET ..., with "$n" placeholders.
//
// All columns, primary key included, are inserted; on a primary key conflict
// every column returned by m.Columns(false) is overwritten with m's value.
// It panics if squirrel cannot build the UPDATE used for the SET clause, or
// if that UPDATE does not have the expected "UPDATE <table> SET " shape.
func SQLUpsert(m Mapped) squirrel.InsertBuilder {
	// Let squirrel render the assignments so values are handled exactly as
	// SetMap handles them, then keep only the part after "SET ".
	updateSQL, updateArgs, err := squirrel.
		Update(m.Table()).
		SetMap(ValuesMap(m, m.Columns(false)...)).
		ToSql()
	if err != nil {
		panic(err)
	}
	// Strip the known prefix rather than splitting on every " SET ": a table
	// name, column name or rendered value containing " SET " must not cause a
	// spurious panic.
	setPrefix := "UPDATE " + m.Table() + " SET "
	if !strings.HasPrefix(updateSQL, setPrefix) {
		panic("Could not split the UPDATE query: " + updateSQL) //nolint:gosec
	}
	assignments := strings.TrimPrefix(updateSQL, setPrefix)
	suffix := "ON CONFLICT (" + m.PKeyColumn() + ") DO UPDATE SET " + assignments
	allColumns := m.Columns(true)
	return squirrel.
		Insert(m.Table()).
		PlaceholderFormat(squirrel.Dollar).
		Columns(allColumns...).
		Values(m.Values(allColumns...)...).
		Suffix(suffix, updateArgs...)
}
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
// SQLUpsertNoPKey generates a squirrel "upsert" statement for m that
// resolves conflicts on keyCols instead of the primary key:
// INSERT ... ON CONFLICT (<keyCols>) DO UPDATE SET ..., with "$n"
// placeholders.
//
// Only the columns returned by m.Columns(false) are inserted and updated,
// so the primary key is left to the database. keyCols should not be empty:
// "ON CONFLICT ()" is not valid SQL. It panics if squirrel cannot build the
// UPDATE used for the SET clause, or if that UPDATE does not have the
// expected "UPDATE <table> SET " shape.
func SQLUpsertNoPKey(keyCols []string, m Mapped) squirrel.InsertBuilder {
	// Let squirrel render the assignments so values are handled exactly as
	// SetMap handles them, then keep only the part after "SET ".
	updateSQL, updateArgs, err := squirrel.
		Update(m.Table()).
		SetMap(ValuesMap(m, m.Columns(false)...)).
		ToSql()
	if err != nil {
		panic(err)
	}
	// Strip the known prefix rather than splitting on every " SET ": a table
	// name, column name or rendered value containing " SET " must not cause a
	// spurious panic.
	setPrefix := "UPDATE " + m.Table() + " SET "
	if !strings.HasPrefix(updateSQL, setPrefix) {
		panic("Could not split the UPDATE query: " + updateSQL) //nolint:gosec
	}
	assignments := strings.TrimPrefix(updateSQL, setPrefix)
	suffix := "ON CONFLICT (" + strings.Join(keyCols, ",") + ") DO UPDATE SET " + assignments
	allColumns := m.Columns(false)
	return squirrel.
		Insert(m.Table()).
		PlaceholderFormat(squirrel.Dollar).
		Columns(allColumns...).
		Values(m.Values(allColumns...)...).
		Suffix(suffix, updateArgs...)
}
// SQLInsert builds an INSERT statement that inserts one or several mapped
// instances in the database, primary key included. All the instances must be
// of the same actual type: the table and column list come from the first one.
//
// Called with no instances, it returns a builder with neither table nor
// values, so the problem surfaces as a ToSql error instead of an
// index-out-of-range panic here.
func SQLInsert(instances ...Mapped) squirrel.InsertBuilder {
	if len(instances) == 0 {
		return SQ.Insert("")
	}
	allColumns := instances[0].Columns(true)
	q := SQ.Insert(instances[0].Table()).
		Columns(allColumns...)
	for _, m := range instances {
		q = q.Values(m.Values(allColumns...)...)
	}
	return q
}
// SQLInsertNoPKey builds an INSERT statement that inserts one or several
// mapped instances in the database, but leaves the pkey undefined so it gets
// auto-generated by the database. All the instances must be of the same
// actual type: the table and column list come from the first one.
//
// Called with no instances, it returns a builder with neither table nor
// values, so the problem surfaces as a ToSql error instead of an
// index-out-of-range panic here.
func SQLInsertNoPKey(instances ...Mapped) squirrel.InsertBuilder {
	if len(instances) == 0 {
		return SQ.Insert("")
	}
	allColumns := instances[0].Columns(false)
	q := SQ.Insert(instances[0].Table()).
		Columns(allColumns...)
	for _, m := range instances {
		q = q.Values(m.Values(allColumns...)...)
	}
	return q
}
// PrefixColumns qualifies every column name with the given table name, e.g.
// PrefixColumns("users", "id", "name") returns ["users.id", "users.name"].
// A new slice is returned; columns itself is left untouched.
func PrefixColumns(table string, columns ...string) []string {
	prefixed := make([]string, len(columns))
	for i, name := range columns {
		prefixed[i] = PrefixColumn(table, name)
	}
	return prefixed
}

// PrefixColumn qualifies a single column name with a table name, producing
// "table.column".
func PrefixColumn(table string, column string) string {
	return table + "." + column
}