Refactor gen.go into Generator type.

This commit is contained in:
Greg 2018-08-06 10:01:40 -04:00
parent 1c3bab0dcc
commit 3329fe75df
2 changed files with 67 additions and 70 deletions

View File

@ -16,15 +16,13 @@ type subspec struct {
s string s string
} }
var ( type Generator struct{
firstone bool firstone bool
gfset *token.FileSet fset *token.FileSet
fspec, tspec, nspec subspec fspec, tspec, nspec subspec
ntspecs []subspec nspecs []subspec
nodes *ast.File nodes *ast.File
) }
type gvisit struct{}
func subs(dst *string, specs []subspec) { func subs(dst *string, specs []subspec) {
for _,spec := range(specs) { for _,spec := range(specs) {
@ -38,79 +36,81 @@ func sub(dst *string,spec subspec) {
} }
} }
func (v *gvisit) Visit(node ast.Node) ast.Visitor { func (g *Generator) Visit(node ast.Node) ast.Visitor {
switch n := node.(type) { switch n := node.(type) {
case *ast.Ident: case *ast.Ident:
if n.Obj != nil && n.Obj.Kind == ast.Typ { if n.Obj != nil && n.Obj.Kind == ast.Typ {
subs(&n.Name,ntspecs) subs(&n.Name,g.nspecs)
subs(&n.Obj.Name,ntspecs) subs(&n.Obj.Name,g.nspecs)
} }
case *ast.TypeSwitchStmt: case *ast.TypeSwitchStmt:
for _,s := range(n.Body.List) { for _,s := range(n.Body.List) {
cs, ok := s.(*ast.CaseClause); if !ok { continue } cs, ok := s.(*ast.CaseClause); if !ok { continue }
for _,c := range(cs.List) { for _,c := range(cs.List) {
ci, ok := c.(*ast.Ident); if !ok { continue } ci, ok := c.(*ast.Ident); if !ok { continue }
subs(&ci.Name,ntspecs) subs(&ci.Name,g.nspecs)
} }
} }
case *ast.TypeAssertExpr: case *ast.TypeAssertExpr:
nt, ok := n.Type.(*ast.Ident); if !ok { return v } nt, ok := n.Type.(*ast.Ident); if !ok { return g }
subs(&nt.Name,ntspecs) subs(&nt.Name,g.nspecs)
case *ast.Ellipsis: case *ast.Ellipsis:
nt, ok := n.Elt.(*ast.Ident); if !ok { return v } nt, ok := n.Elt.(*ast.Ident); if !ok { return g }
subs(&nt.Name,ntspecs) subs(&nt.Name,g.nspecs)
case *ast.Field: case *ast.Field:
nt, ok := n.Type.(*ast.Ident); if !ok { return v } nt, ok := n.Type.(*ast.Ident); if !ok { return g }
subs(&nt.Name,ntspecs) subs(&nt.Name,g.nspecs)
case *ast.StarExpr: case *ast.StarExpr:
nt, ok := n.X.(*ast.Ident); if !ok { return v } nt, ok := n.X.(*ast.Ident); if !ok { return g }
subs(&nt.Name,ntspecs) subs(&nt.Name,g.nspecs)
case *ast.TypeSpec:
//ast.Print(gfset,n)
case *ast.ValueSpec: case *ast.ValueSpec:
nt, ok := n.Type.(*ast.Ident); if !ok { return v } nt, ok := n.Type.(*ast.Ident); if !ok { return g }
subs(&nt.Name,ntspecs) subs(&nt.Name,g.nspecs)
case *ast.ArrayType: case *ast.ArrayType:
nt, ok := n.Elt.(*ast.Ident); if !ok { return v } nt, ok := n.Elt.(*ast.Ident); if !ok { return g }
subs(&nt.Name,ntspecs) subs(&nt.Name,g.nspecs)
case *ast.MapType: case *ast.MapType:
nt, ok := n.Key.(*ast.Ident); if ok { nt, ok := n.Key.(*ast.Ident); if ok {
subs(&nt.Name,ntspecs) subs(&nt.Name,g.nspecs)
} }
nt, ok = n.Value.(*ast.Ident); if ok { nt, ok = n.Value.(*ast.Ident); if ok {
subs(&nt.Name,ntspecs) subs(&nt.Name,g.nspecs)
} }
case *ast.FuncDecl: case *ast.FuncDecl:
sub(&n.Name.Name,fspec) sub(&n.Name.Name,g.fspec)
//ast.Print(gfset,n) //ast.Print(g.fset,n)
} }
return v return g
} }
func init() { func init() {
firstone = true
err := os.Remove("pgen.go") err := os.Remove("pgen.go")
if err != nil && os.IsExist(err) { if err != nil && os.IsExist(err) {
log.Fatal("Removing pgen.go:",err) log.Fatal("Removing pgen.go:",err)
} }
} }
func add(newName, typName, typ string) { func NewGenerator() *Generator {
fspec = subspec{regexp.MustCompile("New"), newName} g := &Generator{}
nspec = subspec{regexp.MustCompile("_N"), typName} g.firstone = true
tspec = subspec{regexp.MustCompile("_T"), typ} g.fset = token.NewFileSet()
ntspecs = []subspec{nspec,tspec} return g
}
gfset = token.NewFileSet() func (g *Generator) Add(newName, typName, typ string) {
f, err := parser.ParseFile(gfset, "", template, 0) g.fspec = subspec{regexp.MustCompile("New"), newName}
g.nspecs = []subspec{
subspec{regexp.MustCompile("_N"), typName},
subspec{regexp.MustCompile("_T"), typ}}
f, err := parser.ParseFile(g.fset, "", template, 0)
if err != nil { if err != nil {
log.Fatal("Parsing persist/template.T:",err) log.Fatal("Parsing persist/template.T:",err)
} }
var v = &gvisit{} ast.Walk(g,f)
ast.Walk(v,f)
if firstone { if g.firstone {
nodes = f g.nodes = f
} else { } else {
for _,decl := range(f.Decls) { for _,decl := range(f.Decls) {
imp, ok := decl.(*ast.GenDecl); if ok { imp, ok := decl.(*ast.GenDecl); if ok {
@ -118,20 +118,27 @@ func add(newName, typName, typ string) {
continue // skip imports continue // skip imports
} }
} }
nodes.Decls = append(nodes.Decls,decl) g.nodes.Decls = append(g.nodes.Decls,decl)
} }
} }
firstone = false g.firstone = false
} }
func newImportSpec(name string) *ast.ImportSpec { func (g *Generator) Import(imports map[string]bool) {
path := &ast.BasicLit{ Value: name } if g.nodes == nil {
ret := &ast.ImportSpec{ Path: path } return
return ret }
// strip quotation marks from imports
qexp := regexp.MustCompile(`"`)
for v := range imports {
v = qexp.ReplaceAllString(v,"")
astutil.AddImport(g.fset, g.nodes, v)
}
} }
func gen(wantImps []string) { func (g *Generator) Save() {
if nodes == nil { if g.nodes == nil {
return return
} }
of, err := os.Create("pgen.go") of, err := os.Create("pgen.go")
@ -139,13 +146,7 @@ func gen(wantImps []string) {
log.Fatal("Cannot open pgen.go") log.Fatal("Cannot open pgen.go")
} }
qexp := regexp.MustCompile(`"`) err = format.Node(of,g.fset,g.nodes)
for _,v := range wantImps {
v = qexp.ReplaceAllString(v,"")
astutil.AddImport(gfset, nodes, v)
}
err = format.Node(of,gfset,nodes)
if err != nil { if err != nil {
log.Fatal("Generate error:",err) log.Fatal("Generate error:",err)
} }

View File

@ -103,30 +103,26 @@ func init() {
imps = make(map[string]string) imps = make(map[string]string)
needImps = make(map[string]bool) needImps = make(map[string]bool)
ts = make(map[string]tps)
fset = token.NewFileSet() fset = token.NewFileSet()
} }
func main() { func main() {
// In naming imported identifiers, keep only the last part of the
// path to the imported package. This does not take named imports
// into account but should allow the code generator to find the
// right libary to import and generate valid code.
pkgs, err := parser.ParseDir(fset, ".", nil, 0) pkgs, err := parser.ParseDir(fset, ".", nil, 0)
if err != nil { if err != nil {
log.Fatal("Parse:",err) log.Fatal("Parse:",err)
} }
// run type checker on each package
for _,pkg := range pkgs { for _,pkg := range pkgs {
ts = make(map[string]tps)
chkpkg(pkg) chkpkg(pkg)
g := NewGenerator()
for _,v := range ts {
g.Add(v.fun, v.name, v.typ)
}
g.Import(needImps)
// process template for each type identified and
// generate output
g.Save()
} }
addImps := make([]string,0)
for k := range needImps {
addImps = append(addImps,imps[k])
}
for _,v := range ts {
add(v.fun, v.name, v.typ)
}
gen(addImps)
} }