# HG changeset patch # User Christophe de Vienne <christophe@cdevienne.info> # Date 1729685455 -7200 # Wed Oct 23 14:10:55 2024 +0200 # Node ID cdd7f820857cacb64a69ccb10c5fd86c8b3e7511 # Parent 62b58f8ea6091f565498058062ad0ddeb961804a Fix handling of unnamed fields with 'db' tags diff --git a/tools/generate_db_helpers/generate_db_helpers.go b/tools/generate_db_helpers/generate_db_helpers.go --- a/tools/generate_db_helpers/generate_db_helpers.go +++ b/tools/generate_db_helpers/generate_db_helpers.go @@ -30,10 +30,14 @@ for _, field := range structType.Fields.List { if len(field.Names) == 0 { - if ident, ok := field.Type.(*ast.Ident); ok { - if sub, ok := p.AllStructs[ident.String()]; ok { - fields = append(fields, p.getAllFields(sub)...) + if name, _, _ := getFieldTag(field, "db"); name == "" { + if ident, ok := field.Type.(*ast.Ident); ok { + if sub, ok := p.AllStructs[ident.String()]; ok { + fields = append(fields, p.getAllFields(sub)...) + } } + } else { + fields = append(fields, field) } } else { fields = append(fields, field) @@ -64,9 +68,10 @@ // DBField is a field of a DBStruct. type DBField struct { - Name string - Column string - IsPKey bool + Name string + Column string + IsPKey bool + IsStruct bool } func main() { @@ -243,28 +248,64 @@ return &dbstruct, nil } +func getFieldTag(field *ast.Field, name string) (string, []string, error) { + if field.Tag == nil || len(field.Tag.Value) < 2 { + return "", nil, os.ErrNotExist + } + tags, err := structtag.Parse( + field.Tag.Value[1 : len(field.Tag.Value)-1]) + if err != nil { + return "", nil, err + } + + tag, err := tags.Get(name) + if err != nil { + return "", nil, err + } + + return tag.Name, tag.Options, nil +} + func getDBFields(topLevel *Package, structType *ast.StructType) ([]DBField, error) { allFields := topLevel.getAllFields(structType) dbfields := make([]DBField, 0, len(allFields)) for _, field := range allFields { - if field.Tag == nil || len(field.Tag.Value) < 2 { - continue - } - tags, err := structtag.Parse( - field.Tag.Value[1 : len(field.Tag.Value)-1]) + tagname, _, err := getFieldTag(field, "db") if err != nil { + if errors.Is(err, os.ErrNotExist) { + continue + } + return nil, err } - dbtag, err := tags.Get("db") - if err != nil { - continue + var ( + name string + isStruct bool + ) + if len(field.Names) == 0 { + var ident *ast.Ident + switch v := field.Type.(type) { + case *ast.SelectorExpr: + ident = v.Sel + isStruct = true + case *ast.Ident: + ident = v + isStruct = true + } + if ident != nil { + name = ident.Name + } else { + continue + } + } else { + name = field.Names[0].String() } - dbfields = append(dbfields, DBField{ - Name: field.Names[0].String(), - Column: dbtag.Name, + Name: name, + Column: tagname, + IsStruct: isStruct, }) } @@ -495,7 +536,7 @@ switch column { {{- range .Fields}} case "{{.Column}}": - values[i] = s.{{.Name}} + values[i] = {{if .IsStruct}}&{{ end }}s.{{.Name}} {{- end}} } } @@ -510,7 +551,7 @@ switch column { {{- range .Fields}} case "{{.Column}}": - values["{{.Column}}"] = s.{{.Name}} + values["{{.Column}}"] = {{if .IsStruct}}&{{ end }}s.{{.Name}} {{- end}} } }