# HG changeset patch # User Christophe de Vienne <christophe.devienne@orus.io> # Date 1641310474 -3600 # Tue Jan 04 16:34:34 2022 +0100 # Node ID 938b5b90d3868910456337685c20280da081b724 # Parent d9d9948c52fac189d77597c6087c5058994eea56 Add a 'ArrayContains' operator ('@>' in PG) diff --git a/database/array_contains.go b/database/array_contains.go new file mode 100644 --- /dev/null +++ b/database/array_contains.go @@ -0,0 +1,70 @@ +package database + +import ( + "database/sql/driver" + "fmt" + "sort" + "strings" + + "github.com/jackc/pgtype" +) + +type ArrayContains map[string]interface{} + +func (ac ArrayContains) ToSQL() (sql string, args []interface{}, err error) { + if len(ac) == 0 { + // Empty Sql{} evaluates to true. + sql = "(1=1)" + return + } + + sortedKeys := getSortedKeys(ac) + var exprs []string + + for _, key := range sortedKeys { + var expr string + val := ac[key] + + switch v := val.(type) { + case driver.Valuer: + if val, err = v.Value(); err != nil { + return + } + case []string: + var data pgtype.TextArray + if err = data.Set(v); err != nil { + return + } + if val, err = data.Value(); err != nil { + return + } + case string: + var data pgtype.TextArray + if err = data.Set([]string{v}); err != nil { + return + } + if val, err = data.Value(); err != nil { + return + } + } + if val == nil { + panic("cannot handle NULL values") + } else { + expr = fmt.Sprintf("%s @> ?", key) + args = append(args, val) + } + exprs = append(exprs, expr) + } + + sql = strings.Join(exprs, " AND ") + return +} + +func getSortedKeys(exp map[string]interface{}) []string { + sortedKeys := make([]string, 0, len(exp)) + for k := range exp { + sortedKeys = append(sortedKeys, k) + } + sort.Strings(sortedKeys) + return sortedKeys +} diff --git a/scripts/generate_db_helpers.go b/scripts/generate_db_helpers.go --- a/scripts/generate_db_helpers.go +++ b/scripts/generate_db_helpers.go @@ -344,6 +344,10 @@ return squirrel.NotLike{c.Sql(): value} } +func (c Column) ArrayContains(value interface{}) database.ArrayContains{ + return database.ArrayContains{c.Sql(): value} +} + func (c Column) As(name string) Column { c.alias = name return c