package database import ( "context" "database/sql" "strings" "github.com/Masterminds/squirrel" "github.com/jmoiron/sqlx" "github.com/lann/builder" "github.com/rs/zerolog" ) // Mapped is the common interface of all structs that are mapped in the database type Mapped interface { Table() string PKeyColumn() string Columns(withPKey bool) []string Values(columns ...string) []interface{} } // SQ is a squirrel StatementBuilder var SQ = squirrel.StatementBuilder // SQLTrace logs a sql query and args as TRACE level func SQLTrace(log *zerolog.Logger, sql string, args []interface{}) { if log != nil && log.GetLevel() < 0 { arr := zerolog.Arr() for _, arg := range args { arr.Interface(arg) } logger := log.With().Array("args", arr).Logger() logger.Trace().Msg("SQL: " + sql) } } // SQLExecutor is the common interface of sqlx.DB and sqlx.Tx that we use in // the functions below type SQLExecutor interface { ExecContext(context.Context, string, ...interface{}) (sql.Result, error) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error) } func setPlaceHolderFormat(query squirrel.Sqlizer) squirrel.Sqlizer { return builder.Set(query, "PlaceholderFormat", squirrel.Dollar).(squirrel.Sqlizer) } // Get loads an object func Get(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error { return GetContext(context.Background(), e, obj, query, log) } // GetContext loads an object func GetContext( ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger, ) error { sqlQuery, args, err := setPlaceHolderFormat(query).ToSql() if err != nil { return err } SQLTrace(log, sqlQuery, args) return e.GetContext(ctx, obj, sqlQuery, args...) } // Select load an object list func Select(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error { return SelectContext(context.Background(), e, obj, query, log) } // SelectContext loads an object list func SelectContext( ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger, ) error { sqlQuery, args, err := setPlaceHolderFormat(query).ToSql() if err != nil { return err } SQLTrace(log, sqlQuery, args) return e.SelectContext(ctx, obj, sqlQuery, args...) } // Exec runs a squirrel query on the given db or tx func Exec(e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) { return ExecContext(context.Background(), e, query, log) } // ExecContext runs a squirrel query on the given db or tx func ExecContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) { query = setPlaceHolderFormat(query) sqlQuery, args, err := query. ToSql() if err != nil { return nil, err } SQLTrace(log, sqlQuery, args) return e.ExecContext(ctx, sqlQuery, args...) } // QueryContext runs a squirrel query on the given db or tx func QueryContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (*sqlx.Rows, error) { query = setPlaceHolderFormat(query) sqlQuery, args, err := query. ToSql() if err != nil { return nil, err } SQLTrace(log, sqlQuery, args) return e.QueryxContext(ctx, sqlQuery, args...) } // ValuesMap returns the values for a list of columns as a map. If a column does // not exits, the corresponding value is set to nil func ValuesMap(m Mapped, columns ...string) map[string]interface{} { valuesList := m.Values(columns...) values := make(map[string]interface{}) for i, column := range columns { values[column] = valuesList[i] } return values } // SQLUpdate generates a squirrel "update" statement // for the given mapped instance (auto-selecting by its pkey) func SQLUpdate(m Mapped, columns ...string) squirrel.UpdateBuilder { if len(columns) == 0 { columns = m.Columns(false) } q := squirrel. Update(m.Table()). SetMap(ValuesMap(m, columns...)).Where( squirrel.Eq{m.PKeyColumn(): m.Values(m.PKeyColumn())}) return q } // SQLUpsert generates a squirrel "upsert" statement func SQLUpsert(m Mapped) squirrel.InsertBuilder { return SQLUpsertColumns(m, m.Columns(true), m.Columns(false)) } func SQLUpsertColumns( m Mapped, insertColumns []string, updateColumns []string, ) squirrel.InsertBuilder { updateSQL, updateArgs, err := squirrel. Update(m.Table()). SetMap(ValuesMap(m, updateColumns...)). ToSql() if err != nil { panic(err) } updateParts := strings.Split(updateSQL, " SET ") if len(updateParts) != 2 { panic("Could not split the UPDATE query: " + updateSQL) //nolint:gosec } suffix := "ON CONFLICT (" + m.PKeyColumn() + ") DO UPDATE SET " + updateParts[1] q := squirrel. Insert(m.Table()). PlaceholderFormat(squirrel.Dollar). Columns(insertColumns...). Values(m.Values(insertColumns...)...). Suffix(suffix, updateArgs...) return q } // SQLUpsertNoPKey generates a squirrel "upsert" statement func SQLUpsertNoPKey(keyCols []string, m Mapped) squirrel.InsertBuilder { updateSQL, updateArgs, err := squirrel. Update(m.Table()). SetMap(ValuesMap(m, m.Columns(false)...)). ToSql() if err != nil { panic(err) } updateParts := strings.Split(updateSQL, " SET ") if len(updateParts) != 2 { panic("Could not split the UPDATE query: " + updateSQL) //nolint:gosec } suffix := "ON CONFLICT (" + strings.Join(keyCols, ",") + ") DO UPDATE SET " + updateParts[1] allColumns := m.Columns(false) q := squirrel. Insert(m.Table()). PlaceholderFormat(squirrel.Dollar). Columns(allColumns...). Values(m.Values(allColumns...)...). Suffix(suffix, updateArgs...) return q } // SQLInsert build a Insert statement to insert one or several mapped instances // in the database. All the instances must be of the same actual type func SQLInsert(instances ...Mapped) squirrel.InsertBuilder { allColumns := instances[0].Columns(true) q := SQ.Insert(instances[0].Table()). Columns(allColumns...) for _, m := range instances { q = q.Values(m.Values(allColumns...)...) } return q } // SQLInsertNoPKey build a Insert statement to insert one or several mapped instances // in the database, but leave the pkey undefined so it gets auto-generated // by the database. All the instances must be of the same actual type func SQLInsertNoPKey(instances ...Mapped) squirrel.InsertBuilder { allColumns := instances[0].Columns(false) q := SQ.Insert(instances[0].Table()). Columns(allColumns...) for _, m := range instances { q = q.Values(m.Values(allColumns...)...) } return q } // PrefixColumns ... func PrefixColumns(table string, columns ...string) []string { var prefixed = make([]string, len(columns)) for i, name := range columns { prefixed[i] = PrefixColumn(table, name) } return columns } // PrefixColumn ... func PrefixColumn(table string, column string) string { return table + "." + column }