package database import ( "context" "database/sql" "encoding/json" "strings" "github.com/Masterminds/squirrel" "github.com/jmoiron/sqlx" "github.com/lann/builder" "github.com/rs/zerolog" ) // Mapped is the common interface of all structs that are mapped in the database. type Mapped interface { Table() string PKeyColumn() string Columns(withPKey bool) []string Values(columns ...string) []interface{} } // SQ is a squirrel StatementBuilder. var SQ = squirrel.StatementBuilder // SQLTrace logs a sql query and args as TRACE level. func SQLTrace(log *zerolog.Logger, sql string, args []interface{}) { if log != nil && log.GetLevel() < 0 { arr := zerolog.Arr() for _, arg := range args { b, err := json.Marshal(arg) if err != nil { log.Err(err).Msg("error Marshaling arguments") arr.Interface(arg) } else { if len(b) > 512 { b, err = json.Marshal(map[string]interface{}{ "arg first 512 bytes": b[:512], "len": len(b), }) if err != nil { log.Err(err).Msg("error Marshaling first 512 bytes of arguments") } } arr.RawJSON(b) } } logger := log.With().Array("args", arr).Logger() logger.Trace().Msg("SQL: " + sql) } } // SQLExecutor is the common interface of sqlx.DB and sqlx.Tx that we use in // the functions below. type SQLExecutor interface { ExecContext(context.Context, string, ...interface{}) (sql.Result, error) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error) QueryRowxContext(context.Context, string, ...interface{}) *sqlx.Row } func setPlaceHolderFormat(query squirrel.Sqlizer) squirrel.Sqlizer { sqlizer, ok := builder.Set(query, "PlaceholderFormat", squirrel.Dollar).(squirrel.Sqlizer) if !ok { panic("could not assert sqlizer from builder") } return sqlizer } // Get loads an object. 
func Get(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error { return GetContext(context.Background(), e, obj, query, log) } // GetContext loads an object. func GetContext( ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger, ) error { sqlQuery, args, err := setPlaceHolderFormat(query).ToSql() if err != nil { return err } SQLTrace(log, sqlQuery, args) return e.GetContext(ctx, obj, sqlQuery, args...) } // Select load an object list. func Select(e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger) error { return SelectContext(context.Background(), e, obj, query, log) } // SelectContext loads an object list. func SelectContext( ctx context.Context, e SQLExecutor, obj interface{}, query squirrel.Sqlizer, log *zerolog.Logger, ) error { sqlQuery, args, err := setPlaceHolderFormat(query).ToSql() if err != nil { return err } SQLTrace(log, sqlQuery, args) return e.SelectContext(ctx, obj, sqlQuery, args...) } // Exec runs a squirrel query on the given db or tx. func Exec(e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) { return ExecContext(context.Background(), e, query, log) } // ExecContext runs a squirrel query on the given db or tx. func ExecContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (sql.Result, error) { query = setPlaceHolderFormat(query) sqlQuery, args, err := query. ToSql() if err != nil { return nil, err } SQLTrace(log, sqlQuery, args) return e.ExecContext(ctx, sqlQuery, args...) } // QueryContext runs a squirrel query on the given db or tx. func QueryContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) (*sqlx.Rows, error) { query = setPlaceHolderFormat(query) sqlQuery, args, err := query. ToSql() if err != nil { return nil, err } SQLTrace(log, sqlQuery, args) return e.QueryxContext(ctx, sqlQuery, args...) 
} // QueryRowContext runs a squirrel query on the given db or tx and return a single Row. func QueryRowContext(ctx context.Context, e SQLExecutor, query squirrel.Sqlizer, log *zerolog.Logger) *Row { query = setPlaceHolderFormat(query) sqlQuery, args, err := query. ToSql() if err != nil { return &Row{err: err} } SQLTrace(log, sqlQuery, args) return &Row{Row: e.QueryRowxContext(ctx, sqlQuery, args...)} } // ValuesMap returns the values for a list of columns as a map. If a column does // not exits, the corresponding value is set to nil. func ValuesMap(m Mapped, columns ...string) map[string]interface{} { valuesList := m.Values(columns...) values := make(map[string]interface{}) for i, column := range columns { values[column] = valuesList[i] } return values } // SQLUpdate generates a squirrel "update" statement // for the given mapped instance (auto-selecting by its pkey). func SQLUpdate(m Mapped, columns ...string) squirrel.UpdateBuilder { if len(columns) == 0 { columns = m.Columns(false) } q := squirrel. Update(m.Table()). SetMap(ValuesMap(m, columns...)).Where( squirrel.Eq{m.PKeyColumn(): m.Values(m.PKeyColumn())}) return q } // SQLUpsert generates a squirrel "upsert" statement. func SQLUpsert(m Mapped) squirrel.InsertBuilder { return SQLUpsertColumns(m, m.Columns(true), m.Columns(false)) } func SQLUpsertColumns( m Mapped, insertColumns []string, updateColumns []string, ) squirrel.InsertBuilder { updateSQL, updateArgs, err := squirrel. Update(m.Table()). SetMap(ValuesMap(m, updateColumns...)). ToSql() if err != nil { panic(err) } updateParts := strings.Split(updateSQL, " SET ") if len(updateParts) != 2 { panic("Could not split the UPDATE query: " + updateSQL) } suffix := "ON CONFLICT (" + m.PKeyColumn() + ") DO UPDATE SET " + updateParts[1] q := squirrel. Insert(m.Table()). PlaceholderFormat(squirrel.Dollar). Columns(insertColumns...). Values(m.Values(insertColumns...)...). Suffix(suffix, updateArgs...) 
return q } func SQLUpsertBase(table string, keyCols []string, cols []string) squirrel.InsertBuilder { uq := squirrel.Update(table) for _, col := range cols { uq = uq.Set(col, squirrel.Expr("EXCLUDED."+col)) } updateSQL, updateArgs, err := uq.ToSql() if err != nil { panic(err) } updateParts := strings.Split(updateSQL, " SET ") if len(updateParts) != 2 { panic("Could not split the UPDATE query: " + updateSQL) } suffix := "ON CONFLICT (" + strings.Join(keyCols, ",") + ") DO UPDATE SET " + updateParts[1] q := squirrel. Insert(table). PlaceholderFormat(squirrel.Dollar). Columns(cols...). Suffix(suffix, updateArgs...) return q } // SQLUpsertNoPKey generates a squirrel "upsert" statement. func SQLUpsertNoPKey(keyCols []string, m Mapped, mm ...Mapped) squirrel.InsertBuilder { allColumns := m.Columns(false) q := SQLUpsertBase(m.Table(), keyCols, allColumns). Values(m.Values(allColumns...)...) for _, m := range mm { q = q.Values(m.Values(allColumns...)...) } return q } // SQLInsert build a Insert statement to insert one or several mapped instances // in the database. All the instances must be of the same actual type. func SQLInsert(instances ...Mapped) squirrel.InsertBuilder { allColumns := instances[0].Columns(true) q := SQ.Insert(instances[0].Table()). Columns(allColumns...) for _, m := range instances { q = q.Values(m.Values(allColumns...)...) } return q } // SQLInsertNoPKey build a Insert statement to insert one or several mapped instances // in the database, but leave the pkey undefined so it gets auto-generated // by the database. All the instances must be of the same actual type. func SQLInsertNoPKey(instances ...Mapped) squirrel.InsertBuilder { allColumns := instances[0].Columns(false) q := SQ.Insert(instances[0].Table()). Columns(allColumns...) for _, m := range instances { q = q.Values(m.Values(allColumns...)...) } return q } // PrefixColumns ... 
func PrefixColumns(table string, columns ...string) []string {
	// Returns a new slice in which each name is qualified as "table.column"
	// through PrefixColumn. The caller's columns slice is left untouched.
	prefixed := make([]string, len(columns))
	for i, name := range columns {
		prefixed[i] = PrefixColumn(table, name)
	}
	return prefixed
}

// PrefixColumn qualifies column with table, returning "table.column".
func PrefixColumn(table string, column string) string {
	return table + "." + column
}