Skip to content
Snippets Groups Projects
array_contains.go 1.27 KiB
Newer Older
package database

import (
	"database/sql/driver"
	"fmt"
	"sort"
	"strings"

	"github.com/jackc/pgtype"
)

// ArrayContains builds a conjunction of PostgreSQL array-containment
// predicates of the form "key @> ?", one per map entry. Each key is emitted
// verbatim as the left-hand side of the operator and each value becomes the
// bound argument for its placeholder.
//
// Values are converted before binding:
//   - a string is wrapped into a one-element pgtype.TextArray,
//   - a []string is converted to a pgtype.TextArray,
//   - a driver.Valuer is resolved through its Value method,
//   - any other value is passed through unchanged.
type ArrayContains map[string]interface{}

// ToSQL renders ac as a SQL fragment using "?" placeholders, together with the
// matching argument list. Entries are emitted in ascending key order so the
// output is deterministic. An empty ArrayContains renders "(1=1)", an
// always-true predicate, with no arguments.
//
// Keys are interpolated into the SQL text without quoting or escaping; they
// must be trusted column expressions and never derived from user input.
//
// An error (and no partial output) is returned if converting a value fails, or
// if a value resolves to NULL (a nil value, a nil []string, or a driver.Valuer
// whose Value returns nil).
func (ac ArrayContains) ToSQL() (sql string, args []interface{}, err error) {
	if len(ac) == 0 {
		// Always-true predicate so an empty filter matches every row.
		return "(1=1)", nil, nil
	}

	keys := getSortedKeys(ac)
	exprs := make([]string, 0, len(keys))
	args = make([]interface{}, 0, len(keys))

	for _, key := range keys {
		arg, convErr := arrayContainsArg(ac[key])
		if convErr != nil {
			return "", nil, fmt.Errorf("array contains %q: %w", key, convErr)
		}
		// Previously a panic; library code reports this through err instead.
		if arg == nil {
			return "", nil, fmt.Errorf("array contains %q: cannot handle NULL values", key)
		}
		exprs = append(exprs, fmt.Sprintf("%s @> ?", key))
		args = append(args, arg)
	}

	return strings.Join(exprs, " AND "), args, nil
}

// arrayContainsArg converts a single ArrayContains value into the argument
// bound for its "@>" placeholder. Strings and string slices are turned into
// the driver value of a pgtype.TextArray; driver.Valuer values are resolved;
// all other values are returned unchanged. A nil []string yields a nil result.
func arrayContainsArg(val interface{}) (interface{}, error) {
	var strs []string
	// driver.Valuer is checked first so custom types keep control over their
	// own encoding.
	switch v := val.(type) {
	case driver.Valuer:
		return v.Value()
	case []string:
		strs = v
	case string:
		strs = []string{v}
	default:
		return val, nil
	}

	var arr pgtype.TextArray
	if err := arr.Set(strs); err != nil {
		return nil, err
	}
	return arr.Value()
}

// getSortedKeys returns the keys of exp in ascending lexical order, giving
// callers a deterministic iteration order over the map. The result is always
// a non-nil slice, even when exp is empty or nil.
func getSortedKeys(exp map[string]interface{}) []string {
	keys := make([]string, len(exp))
	i := 0
	for k := range exp {
		keys[i] = k
		i++
	}
	sort.Slice(keys, func(a, b int) bool { return keys[a] < keys[b] })
	return keys
}