diff --git a/scripts/generate_db_helpers.go b/scripts/generate_db_helpers.go new file mode 100644 index 0000000000000000000000000000000000000000..d6032c71d396d912c0a8137745a737a45356f5ce_c2NyaXB0cy9nZW5lcmF0ZV9kYl9oZWxwZXJzLmdv --- /dev/null +++ b/scripts/generate_db_helpers.go @@ -0,0 +1,383 @@ +package main + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "log" + "os" + "sort" + "strings" + "text/template" + + "github.com/fatih/structtag" +) + +// Package is the top-level package +type Package struct { + Name string + + AllStructs map[string]*ast.StructType + DBStructs []DBStruct +} + +func (p Package) getAllFields(structType *ast.StructType) []*ast.Field { + var fields []*ast.Field + + for _, field := range structType.Fields.List { + if len(field.Names) == 0 { + if ident, ok := field.Type.(*ast.Ident); ok { + if sub, ok := p.AllStructs[ident.String()]; ok { + fields = append(fields, p.getAllFields(sub)...) + } + } + } else { + fields = append(fields, field) + } + } + return fields +} + +// DBStruct is a struct mapped to a table +type DBStruct struct { + Name string + Tablename string + PKey DBField + Fields []DBField +} + +// HasTable returns true if this structure is only a top level db struct that +// has is associated to a table +func (s DBStruct) HasTable() bool { + return s.Tablename != "" +} + +// IsEmbedded returns true if this structure is only embedded in other db structs +func (s DBStruct) IsEmbedded() bool { + return s.Tablename == "" +} + +// DBField is a field of a DBStruct +type DBField struct { + Name string + Column string + IsPKey bool +} + +func main() { + var ( + outputname = "db_helpers.go" + ) + + if len(os.Args) > 2 && os.Args[1] == "-o" { + outputname = os.Args[2] + } + + fset := token.NewFileSet() + packages, err := parser.ParseDir(fset, ".", func(info os.FileInfo) bool { + return info.Name() != outputname + }, parser.ParseComments) + if err != nil { + log.Fatal(err) + } + if len(packages) != 1 { + log.Fatal("Expected to find 1 go package, got: ", len(packages), packages) + } + + var pkg *ast.Package + for _, p := range packages { + pkg = p + } + + topLevel := Package{Name: pkg.Name, AllStructs: make(map[string]*ast.StructType)} + + for _, file := range pkg.Files { + if err := walkStructTypes( + file, + func(name, doc string, structType *ast.StructType) error { + topLevel.AllStructs[name] = structType + return nil + }); err != nil { + panic(err) + } + } + for _, file := range pkg.Files { + if err := walkDBStructType( + &topLevel, + file, + func(dbstruct DBStruct) error { + topLevel.DBStructs = append(topLevel.DBStructs, dbstruct) + return nil + }); err != nil { + panic(err) + } + } + + sort.Slice(topLevel.DBStructs, func(i, j int) bool { + return strings.Compare(topLevel.DBStructs[i].Name, topLevel.DBStructs[j].Name) < 0 + }) + + out, err := os.OpenFile(outputname, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0666) + if err != nil { + panic(err) + } + + if err := headTmpl.Execute(out, topLevel); err != nil { + panic(err) + } +} + +func walkStructTypes( + f *ast.File, visit func(name, doc string, structType *ast.StructType) error, +) error { + for _, decl := range f.Decls { + if genDecl, ok := decl.(*ast.GenDecl); ok { + if len(genDecl.Specs) == 1 { + spec := genDecl.Specs[0] + if typeSpec, ok := spec.(*ast.TypeSpec); ok { + if structType, ok := typeSpec.Type.(*ast.StructType); ok { + if err := visit( + typeSpec.Name.String(), genDecl.Doc.Text(), structType, + ); err != nil { + return err + } + } + } + } + } + } + return nil +} + +func walkDBStructType( + topLevel *Package, f *ast.File, visit func(DBStruct) error, +) error { + return walkStructTypes( + f, + func(name, doc string, structType *ast.StructType) error { + dbtag := findStructDocLineWithPrefix(doc, "dbtable:") + dbstruct, err := newDBStruct(topLevel, name, dbtag, structType) + if err != nil { + return err + } + if dbstruct != nil { + return visit(*dbstruct) + } + return nil + }) +} + +func findStructDocLineWithPrefix(doc, prefix string) string { + for _, line := range strings.Split(doc, "\n") { + if strings.HasPrefix(line, prefix) { + return line + } + } + return "" +} + +func newDBStruct(topLevel *Package, name string, tag string, structType *ast.StructType) (*DBStruct, error) { + dbstruct := DBStruct{ + Name: name, + } + tags, err := structtag.Parse(tag) + if err != nil { + return nil, fmt.Errorf("error parsing tag `%s`: %s", tag, err) + } + + tableTag, err := tags.Get("dbtable") + if err != nil && err.Error() != "tag does not exist" { + return nil, err + } + + pkeyTag, err := tags.Get("dbpkey") + if err != nil && err.Error() != "tag does not exist" { + return nil, err + } + + pkey := "" + if tableTag != nil { + dbstruct.Tablename = tableTag.Name + if pkeyTag != nil { + pkey = pkeyTag.Name + } + } + + if dbstruct.Fields, err = getDBFields(topLevel, structType); err != nil { + return nil, err + } + + if len(dbstruct.Fields) == 0 { + return nil, nil + } + + for i, field := range dbstruct.Fields { + if field.Column == pkey { + dbstruct.Fields[i].IsPKey = true + dbstruct.PKey = dbstruct.Fields[i] + } + } + + return &dbstruct, nil +} + +func getDBFields(topLevel *Package, structType *ast.StructType) ([]DBField, error) { + var dbfields []DBField + + for _, field := range topLevel.getAllFields(structType) { + if field.Tag == nil || len(field.Tag.Value) < 2 { + continue + } + tags, err := structtag.Parse( + field.Tag.Value[1 : len(field.Tag.Value)-1]) + if err != nil { + return nil, err + } + + dbtag, err := tags.Get("db") + if err != nil { + continue + } + + dbfields = append(dbfields, DBField{ + Name: field.Names[0].String(), + Column: dbtag.Name, + }) + } + + return dbfields, nil +} + +var ( + headTmpl = template.Must(template.New("head").Parse(`// This file is generated by 'gendbfiles' - DO NOT EDIT + +package {{.Name}} + +// This file contains constants for all db names involved in a mapped struct +// It also add some accessors on the struct types so they implement a 'Mapped' interface + +// Mapped is the common interface of all structs that are mapped in the database +type Mapped interface { + Table() string + PKeyColumn() string + Columns(withPKey bool) []string + Values(columns ...string) []interface{} +} + +const ( + // table and column names +{{- range $i, $dbstruct := .DBStructs}} + {{- if $dbstruct.HasTable}} + + // {{.Name}}Table is the name of the table where {{.Name}} are stored + {{.Name}}Table = "{{.Tablename}}" + {{- end}} + + {{- if .PKey.Name}} + + // {{.Name}}PKeyColumn is the name of the primary key + {{.Name}}PKeyColumn = {{.Name}}{{.PKey.Name}}Column + {{- end}} + + {{- range .Fields}} + + // {{$dbstruct.Name}}{{.Name}}Column is the name of the column containing field "{{.Name}}" data + {{$dbstruct.Name}}{{.Name}}Column = "{{.Column}}" + {{- end}} +{{- end}} +) + +var ( + // DBAllTables is the list of all the database table names + DBAllTables = []string{ +{{- range .DBStructs}} + {{- if .HasTable}} + {{.Name}}Table, + {{- end}} +{{- end}} + } +{{- range $dbstruct := .DBStructs}} + + {{- if .HasTable}} + // {{.Name}}DataColumns is the list of the columns for the {{.Name}} structure, expect its primary key + {{.Name}}DataColumns = []string{ + {{- range .Fields}} + {{- if not .IsPKey}} + {{$dbstruct.Name}}{{.Name}}Column, + {{- end}} + {{- end}} + } + + // {{.Name}}Columns is the list of the columns for the {{.Name}} structure + {{- if .PKey.Name}} + {{.Name}}Columns = append( + []string{ {{$dbstruct.Name}}{{.PKey.Name}}Column }, + {{.Name}}DataColumns..., + ) + {{- else }} + {{.Name}}Columns = {{.Name}}DataColumns + {{- end}} + {{- else}} + // {{.Name}}Columns is the list of the columns for the {{.Name}} structure + {{.Name}}Columns = []string{ + {{- range .Fields}} + {{$dbstruct.Name}}{{.Name}}Column, + {{- end}} + } + {{- end}} +{{- end}} +) + +{{- range $dbstruct := .DBStructs}} + +{{- if .HasTable}} + +// Table returns the database table name +func (s {{.Name}}) Table() string { + return {{.Name}}Table +} +{{- end}} + +{{- if .PKey.Name}} + +// PKeyColumn returns the database table primary key column name +func (s {{.Name}}) PKeyColumn() string { + return {{.Name}}PKeyColumn +} +{{- end}} + +{{- if .HasTable}} + +// Columns returns the database table column names +func (s {{.Name}}) Columns(withPKey bool) []string { + if withPKey { + return {{.Name}}Columns + } + return {{.Name}}DataColumns +} +{{- else}} + +// Columns returns the database table column names +func (s {{.Name}}) Columns() []string { + return {{.Name}}Columns +} +{{- end}} + +// Values returns the values for a list of columns. If a column does not exits, +// the corresponding value is left empty +func (s {{.Name}}) Values(columns ...string) []interface{} { + values := make([]interface{}, len(columns)) + for i, column := range columns { + switch column { + {{- range .Fields}} + case "{{.Column}}": + values[i] = s.{{.Name}} + {{- end}} + } + } + return values +} +{{- end}} +`)) +)