diff --git a/tools/generate_db_helpers/generate_db_helpers.go b/tools/generate_db_helpers/generate_db_helpers.go index 62b58f8ea6091f565498058062ad0ddeb961804a_dG9vbHMvZ2VuZXJhdGVfZGJfaGVscGVycy9nZW5lcmF0ZV9kYl9oZWxwZXJzLmdv..cdd7f820857cacb64a69ccb10c5fd86c8b3e7511_dG9vbHMvZ2VuZXJhdGVfZGJfaGVscGVycy9nZW5lcmF0ZV9kYl9oZWxwZXJzLmdv 100644 --- a/tools/generate_db_helpers/generate_db_helpers.go +++ b/tools/generate_db_helpers/generate_db_helpers.go @@ -30,7 +30,9 @@ for _, field := range structType.Fields.List { if len(field.Names) == 0 { - if ident, ok := field.Type.(*ast.Ident); ok { - if sub, ok := p.AllStructs[ident.String()]; ok { - fields = append(fields, p.getAllFields(sub)...) + if name, _, _ := getFieldTag(field, "db"); name == "" { + if ident, ok := field.Type.(*ast.Ident); ok { + if sub, ok := p.AllStructs[ident.String()]; ok { + fields = append(fields, p.getAllFields(sub)...) + } } @@ -36,4 +38,6 @@ } + } else { + fields = append(fields, field) } } else { fields = append(fields, field) @@ -64,9 +68,10 @@ // DBField is a field of a DBStruct. type DBField struct { - Name string - Column string - IsPKey bool + Name string + Column string + IsPKey bool + IsStruct bool } func main() { @@ -243,8 +248,26 @@ return &dbstruct, nil } +func getFieldTag(field *ast.Field, name string) (string, []string, error) { + if field.Tag == nil || len(field.Tag.Value) < 2 { + return "", nil, os.ErrNotExist + } + tags, err := structtag.Parse( + field.Tag.Value[1 : len(field.Tag.Value)-1]) + if err != nil { + return "", nil, err + } + + tag, err := tags.Get(name) + if err != nil { + return "", nil, err + } + + return tag.Name, tag.Options, nil +} + func getDBFields(topLevel *Package, structType *ast.StructType) ([]DBField, error) { allFields := topLevel.getAllFields(structType) dbfields := make([]DBField, 0, len(allFields)) for _, field := range allFields { @@ -246,11 +269,7 @@ func getDBFields(topLevel *Package, structType *ast.StructType) ([]DBField, error) { allFields := topLevel.getAllFields(structType) dbfields := make([]DBField, 0, len(allFields)) for _, field := range allFields { - if field.Tag == nil || len(field.Tag.Value) < 2 { - continue - } - tags, err := structtag.Parse( - field.Tag.Value[1 : len(field.Tag.Value)-1]) + tagname, _, err := getFieldTag(field, "db") if err != nil { @@ -256,4 +275,8 @@ if err != nil { + if errors.Is(err, os.ErrNotExist) { + continue + } + return nil, err } @@ -257,7 +280,25 @@ return nil, err } - dbtag, err := tags.Get("db") - if err != nil { - continue + var ( + name string + isStruct bool + ) + if len(field.Names) == 0 { + var ident *ast.Ident + switch v := field.Type.(type) { + case *ast.SelectorExpr: + ident = v.Sel + isStruct = true + case *ast.Ident: + ident = v + isStruct = true + } + if ident != nil { + name = ident.Name + } else { + continue + } + } else { + name = field.Names[0].String() } @@ -263,3 +304,2 @@ } - dbfields = append(dbfields, DBField{ @@ -265,6 +305,7 @@ dbfields = append(dbfields, DBField{ - Name: field.Names[0].String(), - Column: dbtag.Name, + Name: name, + Column: tagname, + IsStruct: isStruct, }) } @@ -495,7 +536,7 @@ switch column { {{- range .Fields}} case "{{.Column}}": - values[i] = s.{{.Name}} + values[i] = {{if .IsStruct}}&{{ end }}s.{{.Name}} {{- end}} } } @@ -510,7 +551,7 @@ switch column { {{- range .Fields}} case "{{.Column}}": - values["{{.Column}}"] = s.{{.Name}} + values["{{.Column}}"] = {{if .IsStruct}}&{{ end }}s.{{.Name}} {{- end}} } }