-
Christophe de Vienne authored
generate_db_helpers.go 8.25 KiB
package main
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"log"
"os"
"sort"
"strings"
"text/template"
"github.com/fatih/structtag"
)
// Package describes the Go package the helpers are generated for.
type Package struct {
	// Name is the package name; it becomes the package clause of the
	// generated file.
	Name string
	// AllStructs maps the name of each struct type reported by
	// walkStructTypes to its definition. It is used to expand embedded
	// structs.
	AllStructs map[string]*ast.StructType
	// DBStructs lists the structs that have at least one db-mapped field.
	DBStructs []DBStruct
}

// getAllFields flattens structType into its named fields. An embedded field
// whose type is a plain identifier naming a struct in p.AllStructs is
// replaced, recursively, by that struct's fields; any other embedded field
// (pointer, package-qualified or unknown type) is dropped.
func (p Package) getAllFields(structType *ast.StructType) []*ast.Field {
	var flat []*ast.Field
	for _, f := range structType.Fields.List {
		if len(f.Names) != 0 {
			flat = append(flat, f)
			continue
		}
		ident, isIdent := f.Type.(*ast.Ident)
		if !isIdent {
			continue
		}
		if embedded, known := p.AllStructs[ident.String()]; known {
			flat = append(flat, p.getAllFields(embedded)...)
		}
	}
	return flat
}

// DBStruct describes a struct with db-mapped fields. A struct with a table
// name is mapped to that table; one without is only meant to be embedded in
// other db structs.
type DBStruct struct {
	Name      string    // Go type name
	Tablename string    // table name from the dbtable doc tag; "" when embedded only
	PKey      DBField   // primary key field; zero value when there is none
	Fields    []DBField // all db-mapped fields, primary key included
}

// HasTable reports whether s is mapped to a database table of its own.
func (s DBStruct) HasTable() bool {
	return len(s.Tablename) > 0
}

// IsEmbedded reports whether s has no table of its own and is therefore only
// used embedded in other db structs.
func (s DBStruct) IsEmbedded() bool {
	return !s.HasTable()
}

// DBField describes a struct field mapped to a database column.
type DBField struct {
	Name   string // Go field name
	Column string // column name, taken from the field's db tag
	IsPKey bool   // true when Column is the struct's primary key
}
// main writes the db helpers for the Go package in the current directory.
// The output goes to "db_helpers.go" unless "-o <file>" is given as the first
// two arguments. Any failure is fatal.
func main() {
	outputname := "db_helpers.go"
	if len(os.Args) > 2 && os.Args[1] == "-o" {
		outputname = os.Args[2]
	}

	pkg, err := loadSinglePackage(".", outputname)
	if err != nil {
		log.Fatal(err)
	}
	topLevel, err := collectDBStructs(pkg)
	if err != nil {
		log.Fatal(err)
	}

	// Render in memory first: a template error must not leave a truncated
	// output file behind, and os.WriteFile reports write and close errors.
	var out strings.Builder
	if err := headTmpl.Execute(&out, topLevel); err != nil {
		log.Fatal(err)
	}
	if err := os.WriteFile(outputname, []byte(out.String()), 0666); err != nil {
		log.Fatal(err)
	}
}

// loadSinglePackage parses the Go files of dir and returns the one package
// they declare. Test files are skipped because they may declare a separate
// "_test" package and their types must not leak into non-test generated code.
// The output file is skipped so a previous output never feeds the next one.
func loadSinglePackage(dir, outputname string) (*ast.Package, error) {
	// Stat the output, when it already exists, so it is also recognized when
	// named through an equivalent path such as "./db_helpers.go".
	outInfo, statErr := os.Stat(outputname)
	filter := func(info os.FileInfo) bool {
		name := info.Name()
		if strings.HasSuffix(name, "_test.go") || name == outputname {
			return false
		}
		return statErr != nil || !os.SameFile(outInfo, info)
	}

	fset := token.NewFileSet()
	packages, err := parser.ParseDir(fset, dir, filter, parser.ParseComments)
	if err != nil {
		return nil, err
	}
	if len(packages) != 1 {
		names := make([]string, 0, len(packages))
		for name := range packages {
			names = append(names, name)
		}
		sort.Strings(names)
		return nil, fmt.Errorf("expected to find 1 go package, got %d: %v", len(packages), names)
	}
	var pkg *ast.Package
	for _, p := range packages {
		pkg = p
	}
	return pkg, nil
}

// collectDBStructs builds the Package description of pkg. Every struct type
// is indexed first, because a db struct may embed a struct declared in another
// file. The db structs are then collected and sorted by name, since pkg.Files
// is a map and its iteration order is random.
func collectDBStructs(pkg *ast.Package) (Package, error) {
	topLevel := Package{Name: pkg.Name, AllStructs: make(map[string]*ast.StructType)}
	for filename, file := range pkg.Files {
		err := walkStructTypes(file, func(name, _ string, structType *ast.StructType) error {
			topLevel.AllStructs[name] = structType
			return nil
		})
		if err != nil {
			return Package{}, fmt.Errorf("%s: %w", filename, err)
		}
	}
	for filename, file := range pkg.Files {
		err := walkDBStructType(&topLevel, file, func(dbstruct DBStruct) error {
			topLevel.DBStructs = append(topLevel.DBStructs, dbstruct)
			return nil
		})
		if err != nil {
			return Package{}, fmt.Errorf("%s: %w", filename, err)
		}
	}
	sort.Slice(topLevel.DBStructs, func(i, j int) bool {
		return topLevel.DBStructs[i].Name < topLevel.DBStructs[j].Name
	})
	return topLevel, nil
}
// walkStructTypes calls visit for every top-level struct type declared in f,
// in source order, with the type name and the text of its doc comment ("" when
// there is none). Both standalone ("type T struct{...}") and grouped
// ("type ( ... )") declarations are visited. The first error returned by
// visit stops the walk and is returned.
func walkStructTypes(
	f *ast.File, visit func(name, doc string, structType *ast.StructType) error,
) error {
	for _, decl := range f.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok {
			continue
		}
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			structType, ok := typeSpec.Type.(*ast.StructType)
			if !ok {
				continue
			}
			// Inside a grouped declaration each spec carries its own doc
			// comment; for a standalone declaration the parser attaches the
			// doc comment to the GenDecl instead.
			doc := typeSpec.Doc
			if doc == nil && len(genDecl.Specs) == 1 {
				doc = genDecl.Doc
			}
			if err := visit(typeSpec.Name.String(), doc.Text(), structType); err != nil {
				return err
			}
		}
	}
	return nil
}
// walkDBStructType calls visit with the DBStruct built for every struct type
// of f that has at least one db-mapped field. Table metadata comes from the
// struct's "dbtable:" doc-comment line, when present. The first error, from
// building a DBStruct or from visit, stops the walk and is returned.
func walkDBStructType(
	topLevel *Package, f *ast.File, visit func(DBStruct) error,
) error {
	onStruct := func(name, doc string, structType *ast.StructType) error {
		tagLine := findStructDocLineWithPrefix(doc, "dbtable:")
		built, err := newDBStruct(topLevel, name, tagLine, structType)
		switch {
		case err != nil:
			return err
		case built == nil:
			// no db-mapped field: not a db struct
			return nil
		default:
			return visit(*built)
		}
	}
	return walkStructTypes(f, onStruct)
}
// findStructDocLineWithPrefix returns the first line of doc that starts with
// prefix once surrounding whitespace is removed, returning that trimmed line.
// It returns "" when no line matches. Leading whitespace is tolerated so an
// indented tag line (e.g. a tab-indented block in a gofmt'd doc comment) is
// still found.
func findStructDocLineWithPrefix(doc, prefix string) string {
	for _, line := range strings.Split(doc, "\n") {
		if trimmed := strings.TrimSpace(line); strings.HasPrefix(trimmed, prefix) {
			return trimmed
		}
	}
	return ""
}
// newDBStruct builds the DBStruct describing the struct type called name.
//
// tag is the struct's "dbtable:" doc-comment line ("" when there is none),
// parsed with struct-tag syntax. The dbtable key gives the table name. The
// dbpkey key names the primary key column and is honored only together with
// dbtable. The function returns (nil, nil) when structType has no db-mapped
// field.
func newDBStruct(topLevel *Package, name string, tag string, structType *ast.StructType) (*DBStruct, error) {
	tags, err := structtag.Parse(tag)
	if err != nil {
		return nil, fmt.Errorf("struct %s: error parsing tag `%s`: %w", name, tag, err)
	}
	// optionalTag returns the tag for key, or nil when the key is absent. A
	// missing key is matched by its error message ("tag does not exist");
	// any other error is returned.
	optionalTag := func(key string) (*structtag.Tag, error) {
		t, err := tags.Get(key)
		if err != nil {
			if err.Error() == "tag does not exist" {
				return nil, nil
			}
			return nil, err
		}
		return t, nil
	}
	tableTag, err := optionalTag("dbtable")
	if err != nil {
		return nil, err
	}
	pkeyTag, err := optionalTag("dbpkey")
	if err != nil {
		return nil, err
	}

	dbstruct := DBStruct{Name: name}
	var pkey string
	if tableTag != nil {
		dbstruct.Tablename = tableTag.Name
		if pkeyTag != nil {
			pkey = pkeyTag.Name
		}
	}

	if dbstruct.Fields, err = getDBFields(topLevel, structType); err != nil {
		return nil, err
	}
	if len(dbstruct.Fields) == 0 {
		return nil, nil
	}

	// Without a dbpkey there is no primary key. An empty pkey must not match
	// fields whose db tag has an empty name (e.g. `db:",omitempty"`).
	if pkey != "" {
		for i := range dbstruct.Fields {
			if dbstruct.Fields[i].Column == pkey {
				dbstruct.Fields[i].IsPKey = true
				dbstruct.PKey = dbstruct.Fields[i]
			}
		}
	}
	return &dbstruct, nil
}
// getDBFields returns the db-mapped fields of structType, in declaration
// order, including fields promoted from embedded structs expanded by
// topLevel.getAllFields. A field is mapped when its struct tag has a "db"
// key, and the key's name part is the column. Fields tagged `db:"-"` are
// skipped, following the common `json:"-"` convention for "not mapped".
func getDBFields(topLevel *Package, structType *ast.StructType) ([]DBField, error) {
	var dbfields []DBField
	for _, field := range topLevel.getAllFields(structType) {
		if field.Tag == nil {
			continue
		}
		// getAllFields only returns named fields.
		fieldName := field.Names[0].Name

		// field.Tag.Value is the tag exactly as written in the source: a raw
		// (`...`) or an interpreted ("...") string literal. fmt's %q scan verb
		// accepts both and resolves escapes, which slicing off the quotes
		// would not.
		var rawTag string
		if _, err := fmt.Sscanf(field.Tag.Value, "%q", &rawTag); err != nil {
			return nil, fmt.Errorf("field %s: reading tag %s: %w", fieldName, field.Tag.Value, err)
		}
		tags, err := structtag.Parse(rawTag)
		if err != nil {
			return nil, fmt.Errorf("field %s: parsing tag %s: %w", fieldName, field.Tag.Value, err)
		}
		// Defensive: treat a nil *Tags like a tag without a "db" key.
		// NOTE(review): structtag appears to return (nil, nil) for a blank
		// tag; confirm.
		if tags == nil {
			continue
		}
		dbtag, err := tags.Get("db")
		if err != nil {
			continue // no "db" key: the field is not mapped
		}
		if dbtag.Name == "-" {
			continue
		}
		dbfields = append(dbfields, DBField{
			Name:   fieldName,
			Column: dbtag.Name,
		})
	}
	return dbfields, nil
}
// headTmpl renders the whole generated file from a Package value.
//
// The first line follows the standard "// Code generated ... DO NOT EDIT."
// form, and it is separated from the package clause by a blank line so it is
// not taken as the package doc comment. Table and column names are emitted
// with %q so they are always valid Go string literals.
var (
	headTmpl = template.Must(template.New("head").Parse(`// Code generated by gendbfiles; DO NOT EDIT.

package {{.Name}}

// This file contains constants for all db names involved in a mapped struct.
// It also adds some accessors on the struct types so they implement a 'Mapped' interface.

// Mapped is the common interface of all structs that are mapped in the database
type Mapped interface {
	Table() string
	PKeyColumn() string
	Columns(withPKey bool) []string
	Values(columns ...string) []interface{}
}

const (
	// table and column names
{{- range $dbstruct := .DBStructs}}
{{- if $dbstruct.HasTable}}
	// {{.Name}}Table is the name of the table where {{.Name}} are stored
	{{.Name}}Table = {{printf "%q" .Tablename}}
{{- end}}
{{- if .PKey.Name}}
	// {{.Name}}PKeyColumn is the name of the primary key
	{{.Name}}PKeyColumn = {{.Name}}{{.PKey.Name}}Column
{{- end}}
{{- range .Fields}}
	// {{$dbstruct.Name}}{{.Name}}Column is the name of the column containing field "{{.Name}}" data
	{{$dbstruct.Name}}{{.Name}}Column = {{printf "%q" .Column}}
{{- end}}
{{- end}}
)

var (
	// DBAllTables is the list of all the database table names
	DBAllTables = []string{
{{- range .DBStructs}}
{{- if .HasTable}}
		{{.Name}}Table,
{{- end}}
{{- end}}
	}
{{- range $dbstruct := .DBStructs}}
{{- if .HasTable}}

	// {{.Name}}DataColumns is the list of the columns for the {{.Name}} structure, except its primary key
	{{.Name}}DataColumns = []string{
{{- range .Fields}}
{{- if not .IsPKey}}
		{{$dbstruct.Name}}{{.Name}}Column,
{{- end}}
{{- end}}
	}
	// {{.Name}}Columns is the list of the columns for the {{.Name}} structure
{{- if .PKey.Name}}
	{{.Name}}Columns = append(
		[]string{ {{$dbstruct.Name}}{{.PKey.Name}}Column },
		{{.Name}}DataColumns...,
	)
{{- else }}
	{{.Name}}Columns = {{.Name}}DataColumns
{{- end}}
{{- else}}

	// {{.Name}}Columns is the list of the columns for the {{.Name}} structure
	{{.Name}}Columns = []string{
{{- range .Fields}}
		{{$dbstruct.Name}}{{.Name}}Column,
{{- end}}
	}
{{- end}}
{{- end}}
)
{{- range $dbstruct := .DBStructs}}
{{- if .HasTable}}

// Table returns the database table name
func (s {{.Name}}) Table() string {
	return {{.Name}}Table
}
{{- end}}
{{- if .PKey.Name}}

// PKeyColumn returns the database table primary key column name
func (s {{.Name}}) PKeyColumn() string {
	return {{.Name}}PKeyColumn
}
{{- end}}
{{- if .HasTable}}

// Columns returns the database table column names
func (s {{.Name}}) Columns(withPKey bool) []string {
	if withPKey {
		return {{.Name}}Columns
	}
	return {{.Name}}DataColumns
}
{{- else}}

// Columns returns the database table column names
func (s {{.Name}}) Columns() []string {
	return {{.Name}}Columns
}
{{- end}}

// Values returns the values for a list of columns. If a column does not exist,
// the corresponding value is left nil
func (s {{.Name}}) Values(columns ...string) []interface{} {
	values := make([]interface{}, len(columns))
	for i, column := range columns {
		switch column {
{{- range .Fields}}
		case {{printf "%q" .Column}}:
			values[i] = s.{{.Name}}
{{- end}}
		}
	}
	return values
}
{{- end}}
`))
)