Skip to content
Snippets Groups Projects
array_contains.go 1.33 KiB
package database

import (
	"database/sql/driver"
	"fmt"
	"sort"
	"strings"

	"github.com/jackc/pgtype"
)

// ArrayContains is a set of array-containment conditions. Each key is used
// verbatim as the left-hand side of an "@>" expression and the associated
// value is bound as its right-hand side argument.
type ArrayContains map[string]interface{}

// ToSql renders the map as a conjunction of array-containment predicates
// ("<key> @> ?"), one per entry joined with " AND ", and returns the bind
// arguments in the same order.
//
// Keys are emitted in sorted order so the generated SQL is deterministic.
// Values are converted before binding:
//   - driver.Valuer: the result of Value() is bound;
//   - []string: encoded through pgtype.TextArray;
//   - string: encoded through pgtype.TextArray as a single-element array;
//   - any other value is bound unchanged.
//
// An empty map yields "(1=1)" with no arguments. A value that is nil (or that
// converts to nil) is reported as an error, as is a failing conversion; in
// both cases neither SQL nor arguments are returned.
func (ac ArrayContains) ToSql() (sql string, args []interface{}, err error) { //nolint:nonamedreturns,stylecheck,revive
	if len(ac) == 0 {
		// An empty condition set evaluates to true.
		return "(1=1)", nil, nil
	}

	sortedKeys := getSortedKeys(ac)

	exprs := make([]string, 0, len(sortedKeys))
	args = make([]interface{}, 0, len(sortedKeys))

	for _, key := range sortedKeys {
		val := ac[key]

		switch v := val.(type) {
		case driver.Valuer:
			val, err = v.Value()
		case []string:
			val, err = textArrayValue(v)
		case string:
			val, err = textArrayValue([]string{v})
		}

		if err != nil {
			// Discard anything collected so far: results are meaningless on error.
			return "", nil, fmt.Errorf("array contains %q: %w", key, err)
		}

		if val == nil {
			// "@> NULL" is not a usable predicate; report it instead of panicking.
			return "", nil, fmt.Errorf("array contains %q: cannot handle NULL values", key)
		}

		exprs = append(exprs, key+" @> ?")
		args = append(args, val)
	}

	return strings.Join(exprs, " AND "), args, nil
}

// textArrayValue encodes items through pgtype.TextArray and returns the
// resulting driver value to bind.
func textArrayValue(items []string) (driver.Value, error) {
	var data pgtype.TextArray
	if err := data.Set(items); err != nil {
		return nil, err
	}

	return data.Value()
}

// getSortedKeys collects the keys of exp and returns them in ascending
// lexical order. The result is always a non-nil slice, even for an empty or
// nil map.
func getSortedKeys(exp map[string]interface{}) []string {
	keys := make([]string, len(exp))

	// Fill by index; map iteration order is random, so sort afterwards.
	idx := 0
	for name := range exp {
		keys[idx] = name
		idx++
	}

	sort.Strings(keys)

	return keys
}