Skip to content
Snippets Groups Projects
Commit cdd7f820857c authored by Christophe de Vienne's avatar Christophe de Vienne
Browse files

Fix handling of unnamed fields with 'db' tags

parent 62b58f8ea609
No related branches found
No related tags found
No related merge requests found
......@@ -30,7 +30,9 @@
for _, field := range structType.Fields.List {
if len(field.Names) == 0 {
if ident, ok := field.Type.(*ast.Ident); ok {
if sub, ok := p.AllStructs[ident.String()]; ok {
fields = append(fields, p.getAllFields(sub)...)
if name, _, _ := getFieldTag(field, "db"); name == "" {
if ident, ok := field.Type.(*ast.Ident); ok {
if sub, ok := p.AllStructs[ident.String()]; ok {
fields = append(fields, p.getAllFields(sub)...)
}
}
......@@ -36,4 +38,6 @@
}
} else {
fields = append(fields, field)
}
} else {
fields = append(fields, field)
......@@ -64,9 +68,10 @@
// DBField is a field of a DBStruct.
type DBField struct {
Name string
Column string
IsPKey bool
Name string
Column string
IsPKey bool
IsStruct bool
}
func main() {
......@@ -243,8 +248,26 @@
return &dbstruct, nil
}
func getFieldTag(field *ast.Field, name string) (string, []string, error) {
if field.Tag == nil || len(field.Tag.Value) < 2 {
return "", nil, os.ErrNotExist
}
tags, err := structtag.Parse(
field.Tag.Value[1 : len(field.Tag.Value)-1])
if err != nil {
return "", nil, err
}
tag, err := tags.Get(name)
if err != nil {
return "", nil, err
}
return tag.Name, tag.Options, nil
}
func getDBFields(topLevel *Package, structType *ast.StructType) ([]DBField, error) {
allFields := topLevel.getAllFields(structType)
dbfields := make([]DBField, 0, len(allFields))
for _, field := range allFields {
......@@ -246,11 +269,7 @@
func getDBFields(topLevel *Package, structType *ast.StructType) ([]DBField, error) {
allFields := topLevel.getAllFields(structType)
dbfields := make([]DBField, 0, len(allFields))
for _, field := range allFields {
if field.Tag == nil || len(field.Tag.Value) < 2 {
continue
}
tags, err := structtag.Parse(
field.Tag.Value[1 : len(field.Tag.Value)-1])
tagname, _, err := getFieldTag(field, "db")
if err != nil {
......@@ -256,4 +275,8 @@
if err != nil {
if errors.Is(err, os.ErrNotExist) {
continue
}
return nil, err
}
......@@ -257,7 +280,25 @@
return nil, err
}
dbtag, err := tags.Get("db")
if err != nil {
continue
var (
name string
isStruct bool
)
if len(field.Names) == 0 {
var ident *ast.Ident
switch v := field.Type.(type) {
case *ast.SelectorExpr:
ident = v.Sel
isStruct = true
case *ast.Ident:
ident = v
isStruct = true
}
if ident != nil {
name = ident.Name
} else {
continue
}
} else {
name = field.Names[0].String()
}
......@@ -263,3 +304,2 @@
}
dbfields = append(dbfields, DBField{
......@@ -265,6 +305,7 @@
dbfields = append(dbfields, DBField{
Name: field.Names[0].String(),
Column: dbtag.Name,
Name: name,
Column: tagname,
IsStruct: isStruct,
})
}
......@@ -495,7 +536,7 @@
switch column {
{{- range .Fields}}
case "{{.Column}}":
values[i] = s.{{.Name}}
values[i] = {{if .IsStruct}}&{{ end }}s.{{.Name}}
{{- end}}
}
}
......@@ -510,7 +551,7 @@
switch column {
{{- range .Fields}}
case "{{.Column}}":
values["{{.Column}}"] = s.{{.Name}}
values["{{.Column}}"] = {{if .IsStruct}}&{{ end }}s.{{.Name}}
{{- end}}
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment