Skip to content
Snippets Groups Projects
generate_db_helpers.go 12.4 KiB
Newer Older
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"log"
	"os"
	"sort"
	"strconv"
	"strings"
	"text/template"

	"github.com/fatih/structtag"
)

// Package describes the Go package being scanned. Name is the package name;
// AllStructs indexes every top-level struct type by type name and is used to
// resolve embedded structs; DBStructs holds the structs selected for helper
// generation.
type Package struct {
	Name string

	AllStructs map[string]*ast.StructType
	DBStructs  []DBStruct
}

// getAllFields returns the named fields of structType in declaration order.
// An embedded field whose type is a bare identifier naming a struct in
// p.AllStructs is expanded, recursively, into that struct's fields; any other
// embedded field (pointer *T, qualified pkg.T, non-struct type) is left out.
func (p Package) getAllFields(structType *ast.StructType) []*ast.Field {
	var collected []*ast.Field
	for _, f := range structType.Fields.List {
		if len(f.Names) != 0 {
			collected = append(collected, f)
			continue
		}
		typeName, isIdent := f.Type.(*ast.Ident)
		if !isIdent {
			continue
		}
		if inner, known := p.AllStructs[typeName.String()]; known {
			collected = append(collected, p.getAllFields(inner)...)
		}
	}
	return collected
}

// DBStruct is a struct type for which database helpers are generated.
// Tablename is empty for structs without a table of their own; PKey is the
// primary-key field (zero value when there is none); Fields lists the
// struct's `db`-tagged fields.
type DBStruct struct {
	Name      string
	Tablename string
	PKey      DBField
	Fields    []DBField
}

// HasTable reports whether the struct is mapped to its own table, i.e.
// whether a table name was set for it.
func (s DBStruct) HasTable() bool {
	return len(s.Tablename) > 0
}

// IsEmbedded reports whether the struct has no table of its own, meaning it is
// intended to be used only embedded in other db structs. It is the negation
// of HasTable.
func (s DBStruct) IsEmbedded() bool {
	return !s.HasTable()
}

// DBField is one `db`-tagged field of a DBStruct: Name is the Go field name,
// Column the database column name taken from the tag, and IsPKey marks the
// primary-key field.
type DBField struct {
	Name   string
	Column string
	IsPKey bool
}

// main generates the database helper file for the Go package in the current
// directory.
//
// Usage: generate_db_helpers [-o OUTPUT]
//
// OUTPUT defaults to db_helpers.go and is excluded from parsing. Every struct
// with at least one `db`-tagged field becomes a DBStruct; the collected
// structs, sorted by name, are rendered through headTmpl into OUTPUT, which is
// created or truncated first. Any failure aborts the program via log.Fatal.
func main() {
	outputname := "db_helpers.go"
	if len(os.Args) > 2 && os.Args[1] == "-o" {
		outputname = os.Args[2]
	}

	fset := token.NewFileSet()
	packages, err := parser.ParseDir(fset, ".", func(info os.FileInfo) bool {
		name := info.Name()
		// Skip our own previous output, and test files: helpers generated for
		// a type declared only in a _test.go file would land in a non-test
		// file and break the regular build.
		return name != outputname && !strings.HasSuffix(name, "_test.go")
	}, parser.ParseComments)
	if err != nil {
		log.Fatal(err)
	}
	// Safeguard: drop an external "<pkg>_test" package if one still shows up.
	if len(packages) == 2 {
		for k := range packages {
			if strings.HasSuffix(k, "_test") {
				delete(packages, k)
			}
		}
	}
	if len(packages) != 1 {
		log.Fatal("Expected to find 1 go package, got: ", len(packages), packages)
	}

	var pkg *ast.Package
	for _, p := range packages {
		pkg = p
	}

	topLevel := Package{Name: pkg.Name, AllStructs: make(map[string]*ast.StructType)}

	// First pass: index every struct of every file, so that embedded structs
	// can be resolved regardless of which file declares them.
	for _, file := range pkg.Files {
		err := walkStructTypes(file, func(name, doc string, structType *ast.StructType) error {
			topLevel.AllStructs[name] = structType
			return nil
		})
		if err != nil {
			log.Fatalf("indexing struct types: %v", err)
		}
	}
	// Second pass: build the DB structs now that AllStructs is complete.
	for _, file := range pkg.Files {
		err := walkDBStructType(&topLevel, file, func(dbstruct DBStruct) error {
			topLevel.DBStructs = append(topLevel.DBStructs, dbstruct)
			return nil
		})
		if err != nil {
			log.Fatalf("collecting db structs: %v", err)
		}
	}

	// pkg.Files is a map, so collection order varies: sort for stable output.
	sort.Slice(topLevel.DBStructs, func(i, j int) bool {
		return topLevel.DBStructs[i].Name < topLevel.DBStructs[j].Name
	})

	out, err := os.OpenFile(outputname, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0666)
	if err != nil {
		log.Fatal(err)
	}
	if err := headTmpl.Execute(out, topLevel); err != nil {
		_ = out.Close() // already failing; the render error is the one to report
		log.Fatalf("rendering %s: %v", outputname, err)
	}
	// Closed explicitly rather than deferred (log.Fatal skips deferred calls),
	// and checked, so an error from Close is reported instead of ignored.
	if err := out.Close(); err != nil {
		log.Fatalf("closing %s: %v", outputname, err)
	}
}

// walkStructTypes calls visit for every named struct type declared at the top
// level of f, in source order, passing the type name, the text of its doc
// comment and its *ast.StructType. Both standalone declarations
// (`type T struct{...}`) and grouped ones (`type ( A struct{...}; ... )`) are
// visited.
//
// The doc comment is the one attached to the type spec itself. When the spec
// has none and the declaration holds a single spec, the comment above the
// `type` keyword is used: go/parser attaches the doc of a standalone
// `type T ...` to the GenDecl rather than to the TypeSpec.
//
// The walk stops at, and returns, the first error returned by visit.
func walkStructTypes(
	f *ast.File, visit func(name, doc string, structType *ast.StructType) error,
) error {
	for _, decl := range f.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok {
			continue
		}
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			structType, ok := typeSpec.Type.(*ast.StructType)
			if !ok {
				continue
			}
			docGroup := typeSpec.Doc
			if docGroup == nil && len(genDecl.Specs) == 1 {
				docGroup = genDecl.Doc
			}
			// CommentGroup.Text is nil-safe and yields "" for a nil group.
			if err := visit(typeSpec.Name.String(), docGroup.Text(), structType); err != nil {
				return err
			}
		}
	}
	return nil
}

// walkDBStructType walks the struct types of f (see walkStructTypes) and, for
// each one that newDBStruct turns into a DBStruct (structs with at least one
// `db`-tagged field), passes a copy of it to visit. The struct's doc line
// starting with "dbtable:" is handed to newDBStruct as its tag. The first
// error from newDBStruct or visit aborts the walk and is returned.
func walkDBStructType(
	topLevel *Package, f *ast.File, visit func(DBStruct) error,
) error {
	onStruct := func(name, doc string, structType *ast.StructType) error {
		tagLine := findStructDocLineWithPrefix(doc, "dbtable:")
		candidate, err := newDBStruct(topLevel, name, tagLine, structType)
		switch {
		case err != nil:
			return err
		case candidate == nil:
			return nil // no db fields: nothing to generate
		default:
			return visit(*candidate)
		}
	}
	return walkStructTypes(f, onStruct)
}

// findStructDocLineWithPrefix returns the first "\n"-separated line of doc
// that begins with prefix, or "" when no line does. Lines are compared as-is,
// so a line with leading whitespace before prefix does not match.
func findStructDocLineWithPrefix(doc, prefix string) string {
	remaining := doc
	for {
		line := remaining
		nl := strings.IndexByte(remaining, '\n')
		if nl >= 0 {
			line = remaining[:nl]
		}
		if strings.HasPrefix(line, prefix) {
			return line
		}
		if nl < 0 {
			return "" // last line checked, nothing matched
		}
		remaining = remaining[nl+1:]
	}
}

// newDBStruct builds the DBStruct for the struct type called name.
//
// tag is the struct's doc-comment line in struct-tag syntax, e.g.
// `dbtable:"users" dbpkey:"id"`, or "" when there is none. dbtable sets the
// table name; dbpkey names the primary-key column and is only honoured
// together with dbtable. The field whose Column equals that key is flagged
// IsPKey and copied into PKey; if no field matches, PKey stays zero.
//
// It returns (nil, nil) when the struct has no `db`-tagged field (fields of
// embedded package structs included), and an error when tag or a field tag
// cannot be parsed.
func newDBStruct(topLevel *Package, name string, tag string, structType *ast.StructType) (*DBStruct, error) {
	dbstruct := DBStruct{
		Name: name,
	}
	tags, err := structtag.Parse(tag)
	if err != nil {
		return nil, fmt.Errorf("error parsing tag `%s` of struct %s: %w", tag, name, err)
	}

	// structtag's Tags.Get reports a missing key through its error, so a
	// non-nil error is treated as "not set" instead of matching the library's
	// error message text.
	pkey := ""
	if tags != nil {
		if tableTag, err := tags.Get("dbtable"); err == nil {
			dbstruct.Tablename = tableTag.Name
			if pkeyTag, err := tags.Get("dbpkey"); err == nil {
				pkey = pkeyTag.Name
			}
		}
	}

	if dbstruct.Fields, err = getDBFields(topLevel, structType); err != nil {
		return nil, err
	}

	if len(dbstruct.Fields) == 0 {
		return nil, nil
	}

	// An empty pkey means "no primary key". Without this guard a field whose
	// db tag name is empty (e.g. `db:""`) would be flagged as the key.
	if pkey != "" {
		for i := range dbstruct.Fields {
			if dbstruct.Fields[i].Column == pkey {
				dbstruct.Fields[i].IsPKey = true
				dbstruct.PKey = dbstruct.Fields[i]
			}
		}
	}

	return &dbstruct, nil
}

// getDBFields returns one DBField per field of structType that carries a `db`
// struct tag, in declaration order, with fields of embedded package structs
// flattened in (see Package.getAllFields). Name is the Go field name and
// Column the name part of the `db` tag (structtag's Tag.Name). Fields with no
// tag, or no `db` key, are skipped. When a field declares several names
// (`A, B int`), only the first one is used.
//
// It returns an error when a field's tag cannot be unquoted or parsed.
func getDBFields(topLevel *Package, structType *ast.StructType) ([]DBField, error) {
	var dbfields []DBField

	for _, field := range topLevel.getAllFields(structType) {
		if field.Tag == nil {
			continue
		}
		// field.Tag.Value is the tag's source literal, either raw (`...`) or
		// interpreted ("..."). Unquoting handles both; stripping the first
		// and last byte would leave the escape sequences of the latter in.
		rawTag, err := strconv.Unquote(field.Tag.Value)
		if err != nil {
			return nil, fmt.Errorf("field %s: unquoting tag %s: %w", field.Names[0], field.Tag.Value, err)
		}
		tags, err := structtag.Parse(rawTag)
		if err != nil {
			return nil, fmt.Errorf("field %s: parsing tag %s: %w", field.Names[0], field.Tag.Value, err)
		}
		// NOTE(review): structtag.Parse may return nil Tags for a blank
		// (whitespace-only) tag; guard so Get is never called on nil.
		if tags == nil {
			continue
		}

		dbtag, err := tags.Get("db")
		if err != nil {
			continue // no `db` key on this field
		}

		dbfields = append(dbfields, DBField{
			Name:   field.Names[0].String(),
			Column: dbtag.Name,
		})
	}

	return dbfields, nil
}

var (
	headTmpl = template.Must(template.New("head").Parse(`// This file is generated by 'gendbfiles' - DO NOT EDIT

package {{.Name}}

// This file contains constants for all db names involved in a mapped struct
// It also add some accessors on the struct types so they implement a 'Mapped' interface

// Mapped is the common interface of all structs that are mapped in the database
type Mapped interface {
	Table() string
	PKeyColumn() string
	Columns(withPKey bool) []string
	Values(columns ...string) []interface{}
	ValuesMap(columns ...string) map[string]interface{}
func NewColumn(table TableSchema, name string) Column {
	return Column{table, name, ""}
}

	alias string
	sql :=  c.table.GetName()+"."+c.name
	if c.alias != "" {
		sql += " AS " + c.alias
	}
	return sql
}

func (c Column) ToSql() (string, []interface{}, error ) {
	return c.Sql(), nil, nil
}

func (c Column) Join(col Column) Join {
	return Join{
		from: c,
		to: col,
	}
}

func (c Column) Eq(value interface{}) squirrel.Eq{
	return squirrel.Eq{c.Sql(): value}
}

func (c Column) Ne(value interface{}) squirrel.NotEq{
	return squirrel.NotEq{c.Sql(): value}
}

func (c Column) Gt(value interface{}) squirrel.Gt{
	return squirrel.Gt{c.Sql(): value}
}

func (c Column) Gte(value interface{}) squirrel.GtOrEq{
	return squirrel.GtOrEq{c.Sql(): value}
}

func (c Column) Lt(value interface{}) squirrel.Lt{
	return squirrel.Lt{c.Sql(): value}
}

func (c Column) Lte(value interface{}) squirrel.LtOrEq{
	return squirrel.LtOrEq{c.Sql(): value}
}

func (c Column) Like(value interface{}) squirrel.Like{
	return squirrel.Like{c.Sql(): value}
}

func (c Column) NotLike(value interface{}) squirrel.NotLike{
	return squirrel.NotLike{c.Sql(): value}
}

func (c Column) As(name string) Column {
	c.alias = name
	return c
}

func (c Column) Name() string {
	if c.alias != "" {
		return c.alias
	}
	return c.table.GetName() + "." + c.name
}

func (c Column) ColName() string {
	return c.name
}

type Join struct {
	from Column
	to Column
}

func (j Join) Sql() string {
	return j.to.table.Sql() + " ON " + j.to.Sql() + " = " + j.from.Sql()
}


const (
	// table and column names
{{- range $i, $dbstruct := .DBStructs}}
	{{- if $dbstruct.HasTable}}

	// {{.Name}}Table is the name of the table where {{.Name}} are stored
	{{.Name}}Table = "{{.Tablename}}"
	{{- end}}

	{{- if .PKey.Name}}

	// {{.Name}}PKeyColumn is the name of the primary key
	{{.Name}}PKeyColumn = {{.Name}}{{.PKey.Name}}Column
	{{- end}}

	{{- range .Fields}}

	// {{$dbstruct.Name}}{{.Name}}Column is the name of the column containing field "{{.Name}}" data
	{{$dbstruct.Name}}{{.Name}}Column = "{{.Column}}"
	{{- end}}
{{- end}}
)

var (
	// DBAllTables is the list of all the database table names
	DBAllTables = []string{
{{- range .DBStructs}}
	{{- if .HasTable}}
		{{.Name}}Table,
	{{- end}}
{{- end}}
	}
{{- range $dbstruct := .DBStructs}}

	{{- if .HasTable}}
	// {{.Name}}DataColumns is the list of the columns for the {{.Name}} structure, expect its primary key
	{{.Name}}DataColumns = []string{
	{{- range .Fields}}
		{{- if not .IsPKey}}
		{{$dbstruct.Name}}{{.Name}}Column,
		{{- end}}
	{{- end}}
	}

	// {{.Name}}Columns is the list of the columns for the {{.Name}} structure
	{{- if .PKey.Name}}
	{{.Name}}Columns = append(
		[]string{ {{$dbstruct.Name}}{{.PKey.Name}}Column },
		{{.Name}}DataColumns...,
	)
	{{- else }}
	{{.Name}}Columns = {{.Name}}DataColumns
	{{- end}}
	{{- else}}
	// {{.Name}}Columns is the list of the columns for the {{.Name}} structure
	{{.Name}}Columns = []string{
	{{- range .Fields}}
		{{$dbstruct.Name}}{{.Name}}Column,
	{{- end}}
	}
	{{- end}}
{{- end}}
)

{{- range $dbstruct := .DBStructs}}

{{- if .HasTable}}

// Table returns the database table name
func (s {{.Name}}) Table() string {
	return {{.Name}}Table
}
{{- end}}

{{- if .PKey.Name}}

// PKeyColumn returns the database table primary key column name
func (s {{.Name}}) PKeyColumn() string {
	return {{.Name}}PKeyColumn
}
{{- end}}

{{- if .HasTable}}

// Columns returns the database table column names
func (s {{.Name}}) Columns(withPKey bool) []string {
	if withPKey {
		return {{.Name}}Columns
	}
	return {{.Name}}DataColumns
}
{{- else}}

// Columns returns the database table column names
func (s {{.Name}}) Columns() []string {
	return {{.Name}}Columns
}
{{- end}}

// Values returns the values for a list of columns. If a column does not exits,
// the corresponding value is left empty
func (s {{.Name}}) Values(columns ...string) []interface{} {
	values := make([]interface{}, len(columns))
	for i, column := range columns {
		switch column {
		{{- range .Fields}}
		case "{{.Column}}":
			values[i] = s.{{.Name}}
		{{- end}}
		}
	}
	return values
}
// ValuesMap returns the values map for a list of columns. If a column does not
// exits, the corresponding value is left empty
func (s {{.Name}}) ValuesMap(columns ...string) map[string]interface{} {
	values := make(map[string]interface{})
	for _, column := range columns {
		switch column {
		{{- range .Fields}}
		case "{{.Column}}":
			values["{{.Column}}"] = s.{{.Name}}
		{{- end}}
		}
	}
	return values
}
{{- if .HasTable}}

func New{{.Name}}TableSchema() *{{.Name}}TableSchema {
	t := {{.Name}}TableSchema{}
{{- range .Fields}}
	t.{{.Name}} = NewColumn(&t, "{{.Column}}")
	return &t
}

type {{.Name}}TableSchema struct {
	alias string

{{- range .Fields}}
	{{.Name}} Column
{{- end}}
}

// Columns returns the database table column names
func (t {{.Name}}TableSchema) Columns(withPKey bool) []string {
	if withPKey {
		return {{.Name}}Columns
	}
	return {{.Name}}DataColumns
}

// FQColumns returns the database table column names prefixed with the table alias
func (t {{.Name}}TableSchema) FQColumns(withPKey bool) []string {
	var colList []string
    colList = {{.Name}}DataColumns

	if withPKey {
		colList = {{.Name}}Columns	
	}

    var cols []string
	for _, col := range colList {
			cols = append(cols, t.GetName()+"."+col)
	}
	return cols
}

func (t {{.Name}}TableSchema) As(name string) *{{.Name}}TableSchema {
	t.alias = name
{{- range .Fields}}
	t.{{.Name}} = NewColumn(&t, "{{.Column}}")
{{- end}}
	return &t
}

func (t {{.Name}}TableSchema) GetName() string {
	if t.alias == "" {
		return {{.Name}}Table
	}
	return t.alias
}

func (t {{.Name}}TableSchema) Sql() string {
	if t.alias == "" {
		return {{.Name}}Table
	}
	return {{.Name}}Table + " AS " + t.alias
}
func (t {{.Name}}TableSchema) ToSql() (string, []interface{}, error) {
	return t.Sql(), nil, nil
}

func (t {{.Name}}TableSchema) Select() squirrel.SelectBuilder {
	return squirrel.Select(t.Columns(true)...).From(t.Sql())
}

{{- end}}


{{- end}}

func NewDBSchema() *DBSchema {
	return &DBSchema{
{{- range $dbstruct := .DBStructs}}
{{- if .HasTable}}
		{{.Name}}: New{{.Name}}TableSchema(),
{{- end}}
{{- end}}
	}
}

type DBSchema struct {
{{- range $dbstruct := .DBStructs}}
{{- if .HasTable}}
	{{.Name}} *{{.Name}}TableSchema
{{- end}}
{{- end}}
}

var Schema = NewDBSchema()