package main import ( "fmt" "go/ast" "go/parser" "go/token" "log" "os" "sort" "strings" "text/template" "github.com/fatih/structtag" ) // Package is the top-level package type Package struct { Name string AllStructs map[string]*ast.StructType DBStructs []DBStruct } func (p Package) getAllFields(structType *ast.StructType) []*ast.Field { var fields []*ast.Field for _, field := range structType.Fields.List { if len(field.Names) == 0 { if ident, ok := field.Type.(*ast.Ident); ok { if sub, ok := p.AllStructs[ident.String()]; ok { fields = append(fields, p.getAllFields(sub)...) } } } else { fields = append(fields, field) } } return fields } // DBStruct is a struct mapped to a table type DBStruct struct { Name string Tablename string PKey DBField Fields []DBField } // HasTable returns true if this structure is only a top level db struct that // has is associated to a table func (s DBStruct) HasTable() bool { return s.Tablename != "" } // IsEmbedded returns true if this structure is only embedded in other db structs func (s DBStruct) IsEmbedded() bool { return s.Tablename == "" } // DBField is a field of a DBStruct type DBField struct { Name string Column string IsPKey bool } func main() { var ( outputname = "db_helpers.go" ) if len(os.Args) > 2 && os.Args[1] == "-o" { outputname = os.Args[2] } fset := token.NewFileSet() packages, err := parser.ParseDir(fset, ".", func(info os.FileInfo) bool { return info.Name() != outputname }, parser.ParseComments) if err != nil { log.Fatal(err) } if len(packages) == 2 { for k := range packages { if strings.HasSuffix(k, "_test") { delete(packages, k) } } } if len(packages) != 1 { log.Fatal("Expected to find 1 go package, got: ", len(packages), packages) } var pkg *ast.Package for _, p := range packages { pkg = p } if pkg == nil { log.Fatal("pkg cannot be nil here...") } topLevel := Package{Name: pkg.Name, AllStructs: make(map[string]*ast.StructType)} for _, file := range pkg.Files { if err := walkStructTypes( file, func(name, doc string, structType *ast.StructType) error { topLevel.AllStructs[name] = structType return nil }); err != nil { panic(err) } } for _, file := range pkg.Files { if err := walkDBStructType( &topLevel, file, func(dbstruct DBStruct) error { topLevel.DBStructs = append(topLevel.DBStructs, dbstruct) return nil }); err != nil { panic(err) } } sort.Slice(topLevel.DBStructs, func(i, j int) bool { return strings.Compare(topLevel.DBStructs[i].Name, topLevel.DBStructs[j].Name) < 0 }) out, err := os.OpenFile(outputname, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0666) if err != nil { panic(err) } if err := headTmpl.Execute(out, topLevel); err != nil { panic(err) } } func walkStructTypes( f *ast.File, visit func(name, doc string, structType *ast.StructType) error, ) error { for _, decl := range f.Decls { if genDecl, ok := decl.(*ast.GenDecl); ok { if len(genDecl.Specs) == 1 { spec := genDecl.Specs[0] if typeSpec, ok := spec.(*ast.TypeSpec); ok { if structType, ok := typeSpec.Type.(*ast.StructType); ok { if err := visit( typeSpec.Name.String(), genDecl.Doc.Text(), structType, ); err != nil { return err } } } } } } return nil } func walkDBStructType( topLevel *Package, f *ast.File, visit func(DBStruct) error, ) error { return walkStructTypes( f, func(name, doc string, structType *ast.StructType) error { dbtag := findStructDocLineWithPrefix(doc, "dbtable:") dbstruct, err := newDBStruct(topLevel, name, dbtag, structType) if err != nil { return err } if dbstruct != nil { return visit(*dbstruct) } return nil }) } func findStructDocLineWithPrefix(doc, prefix string) string { for _, line := range strings.Split(doc, "\n") { if strings.HasPrefix(line, prefix) { return line } } return "" } func newDBStruct(topLevel *Package, name string, tag string, structType *ast.StructType) (*DBStruct, error) { dbstruct := DBStruct{ Name: name, } tags, err := structtag.Parse(tag) if err != nil { return nil, fmt.Errorf("error parsing tag `%s`: %s", tag, err) } tableTag, err := tags.Get("dbtable") if err != nil && err.Error() != "tag does not exist" { return nil, err } pkeyTag, err := tags.Get("dbpkey") if err != nil && err.Error() != "tag does not exist" { return nil, err } pkey := "" if tableTag != nil { dbstruct.Tablename = tableTag.Name if pkeyTag != nil { pkey = pkeyTag.Name } } if dbstruct.Fields, err = getDBFields(topLevel, structType); err != nil { return nil, err } if len(dbstruct.Fields) == 0 { return nil, nil } for i, field := range dbstruct.Fields { if field.Column == pkey { dbstruct.Fields[i].IsPKey = true dbstruct.PKey = dbstruct.Fields[i] } } return &dbstruct, nil } func getDBFields(topLevel *Package, structType *ast.StructType) ([]DBField, error) { var dbfields []DBField for _, field := range topLevel.getAllFields(structType) { if field.Tag == nil || len(field.Tag.Value) < 2 { continue } tags, err := structtag.Parse( field.Tag.Value[1 : len(field.Tag.Value)-1]) if err != nil { return nil, err } dbtag, err := tags.Get("db") if err != nil { continue } dbfields = append(dbfields, DBField{ Name: field.Names[0].String(), Column: dbtag.Name, }) } return dbfields, nil } var ( headTmpl = template.Must(template.New("head").Parse(`// This file is generated by 'gendbfiles' - DO NOT EDIT package {{.Name}} // This file contains constants for all db names involved in a mapped struct // It also add some accessors on the struct types so they implement a 'Mapped' interface import "github.com/Masterminds/squirrel" // Mapped is the common interface of all structs that are mapped in the database type Mapped interface { Table() string PKeyColumn() string Columns(withPKey bool) []string Values(columns ...string) []interface{} } type TableSchema interface { GetName() string Sql() string } func NewColumn(table TableSchema, name string) Column { return Column{table, name, ""} } type Column struct { table TableSchema name string alias string } func (c Column) Sql() string { sql := c.table.GetName()+"."+c.name if c.alias != "" { sql += " AS " + c.alias } return sql } func (c Column) ToSql() (string, []interface{}, error ) { return c.Sql(), nil, nil } func (c Column) Join(col Column) Join { return Join{ from: c, to: col, } } func (c Column) Eq(value interface{}) squirrel.Eq{ return squirrel.Eq{c.Sql(): value} } func (c Column) Ne(value interface{}) squirrel.NotEq{ return squirrel.NotEq{c.Sql(): value} } func (c Column) Gt(value interface{}) squirrel.Gt{ return squirrel.Gt{c.Sql(): value} } func (c Column) Gte(value interface{}) squirrel.GtOrEq{ return squirrel.GtOrEq{c.Sql(): value} } func (c Column) Lt(value interface{}) squirrel.Lt{ return squirrel.Lt{c.Sql(): value} } func (c Column) Lte(value interface{}) squirrel.LtOrEq{ return squirrel.LtOrEq{c.Sql(): value} } func (c Column) Like(value interface{}) squirrel.Like{ return squirrel.Like{c.Sql(): value} } func (c Column) NotLike(value interface{}) squirrel.NotLike{ return squirrel.NotLike{c.Sql(): value} } func (c Column) As(name string) Column { c.alias = name return c } func (c Column) Name() string { if c.alias != "" { return c.alias } return c.table.GetName() + "." + c.name } func (c Column) ColName() string { return c.name } type Join struct { from Column to Column } func (j Join) Sql() string { return j.to.table.Sql() + " ON " + j.to.Sql() + " = " + j.from.Sql() } const ( // table and column names {{- range $i, $dbstruct := .DBStructs}} {{- if $dbstruct.HasTable}} // {{.Name}}Table is the name of the table where {{.Name}} are stored {{.Name}}Table = "{{.Tablename}}" {{- end}} {{- if .PKey.Name}} // {{.Name}}PKeyColumn is the name of the primary key {{.Name}}PKeyColumn = {{.Name}}{{.PKey.Name}}Column {{- end}} {{- range .Fields}} // {{$dbstruct.Name}}{{.Name}}Column is the name of the column containing field "{{.Name}}" data {{$dbstruct.Name}}{{.Name}}Column = "{{.Column}}" {{- end}} {{- end}} ) var ( // DBAllTables is the list of all the database table names DBAllTables = []string{ {{- range .DBStructs}} {{- if .HasTable}} {{.Name}}Table, {{- end}} {{- end}} } {{- range $dbstruct := .DBStructs}} {{- if .HasTable}} // {{.Name}}DataColumns is the list of the columns for the {{.Name}} structure, expect its primary key {{.Name}}DataColumns = []string{ {{- range .Fields}} {{- if not .IsPKey}} {{$dbstruct.Name}}{{.Name}}Column, {{- end}} {{- end}} } // {{.Name}}Columns is the list of the columns for the {{.Name}} structure {{- if .PKey.Name}} {{.Name}}Columns = append( []string{ {{$dbstruct.Name}}{{.PKey.Name}}Column }, {{.Name}}DataColumns..., ) {{- else }} {{.Name}}Columns = {{.Name}}DataColumns {{- end}} {{- else}} // {{.Name}}Columns is the list of the columns for the {{.Name}} structure {{.Name}}Columns = []string{ {{- range .Fields}} {{$dbstruct.Name}}{{.Name}}Column, {{- end}} } {{- end}} {{- end}} ) {{- range $dbstruct := .DBStructs}} {{- if .HasTable}} // Table returns the database table name func (s {{.Name}}) Table() string { return {{.Name}}Table } {{- end}} {{- if .PKey.Name}} // PKeyColumn returns the database table primary key column name func (s {{.Name}}) PKeyColumn() string { return {{.Name}}PKeyColumn } {{- end}} {{- if .HasTable}} // Columns returns the database table column names func (s {{.Name}}) Columns(withPKey bool) []string { if withPKey { return {{.Name}}Columns } return {{.Name}}DataColumns } {{- else}} // Columns returns the database table column names func (s {{.Name}}) Columns() []string { return {{.Name}}Columns } {{- end}} // Values returns the values for a list of columns. If a column does not exits, // the corresponding value is left empty func (s {{.Name}}) Values(columns ...string) []interface{} { values := make([]interface{}, len(columns)) for i, column := range columns { switch column { {{- range .Fields}} case "{{.Column}}": values[i] = s.{{.Name}} {{- end}} } } return values } {{- if .HasTable}} func New{{.Name}}TableSchema() *{{.Name}}TableSchema { t := {{.Name}}TableSchema{} {{- range .Fields}} t.{{.Name}} = NewColumn(&t, "{{.Column}}") {{- end}} return &t } type {{.Name}}TableSchema struct { alias string {{- range .Fields}} {{.Name}} Column {{- end}} } // Columns returns the database table column names func (t {{.Name}}TableSchema) Columns(withPKey bool) []string { if withPKey { return {{.Name}}Columns } return {{.Name}}DataColumns } // FQColumns returns the database table column names prefixed with the table alias func (t {{.Name}}TableSchema) FQColumns(withPKey bool) []string { var colList []string colList = {{.Name}}DataColumns if withPKey { colList = {{.Name}}Columns } var cols []string for _, col := range colList { cols = append(cols, t.GetName()+"."+col) } return cols } func (t {{.Name}}TableSchema) As(name string) *{{.Name}}TableSchema { t.alias = name {{- range .Fields}} t.{{.Name}} = NewColumn(&t, "{{.Column}}") {{- end}} return &t } func (t {{.Name}}TableSchema) GetName() string { if t.alias == "" { return {{.Name}}Table } return t.alias } func (t {{.Name}}TableSchema) Sql() string { if t.alias == "" { return {{.Name}}Table } return {{.Name}}Table + " AS " + t.alias } func (t {{.Name}}TableSchema) ToSql() (string, []interface{}, error) { return t.Sql(), nil, nil } func (t {{.Name}}TableSchema) Select() squirrel.SelectBuilder { return squirrel.Select(t.Columns(true)...).From(t.Sql()) } {{- end}} {{- end}} func NewDBSchema() *DBSchema { return &DBSchema{ {{- range $dbstruct := .DBStructs}} {{- if .HasTable}} {{.Name}}: New{{.Name}}TableSchema(), {{- end}} {{- end}} } } type DBSchema struct { {{- range $dbstruct := .DBStructs}} {{- if .HasTable}} {{.Name}} *{{.Name}}TableSchema {{- end}} {{- end}} } var Schema = NewDBSchema() `)) )