diff --git a/cmd/pgen/gen.go b/cmd/pgen/gen.go index cf83c10..e2e9776 100644 --- a/cmd/pgen/gen.go +++ b/cmd/pgen/gen.go @@ -16,15 +16,13 @@ type subspec struct { s string } -var ( +type Generator struct{ firstone bool - gfset *token.FileSet + fset *token.FileSet fspec, tspec, nspec subspec - ntspecs []subspec + nspecs []subspec nodes *ast.File -) - -type gvisit struct{} +} func subs(dst *string, specs []subspec) { for _,spec := range(specs) { @@ -38,79 +36,81 @@ func sub(dst *string,spec subspec) { } } -func (v *gvisit) Visit(node ast.Node) ast.Visitor { +func (g *Generator) Visit(node ast.Node) ast.Visitor { switch n := node.(type) { case *ast.Ident: if n.Obj != nil && n.Obj.Kind == ast.Typ { - subs(&n.Name,ntspecs) - subs(&n.Obj.Name,ntspecs) + subs(&n.Name,g.nspecs) + subs(&n.Obj.Name,g.nspecs) } case *ast.TypeSwitchStmt: for _,s := range(n.Body.List) { cs, ok := s.(*ast.CaseClause); if !ok { continue } for _,c := range(cs.List) { ci, ok := c.(*ast.Ident); if !ok { continue } - subs(&ci.Name,ntspecs) + subs(&ci.Name,g.nspecs) } } case *ast.TypeAssertExpr: - nt, ok := n.Type.(*ast.Ident); if !ok { return v } - subs(&nt.Name,ntspecs) + nt, ok := n.Type.(*ast.Ident); if !ok { return g } + subs(&nt.Name,g.nspecs) case *ast.Ellipsis: - nt, ok := n.Elt.(*ast.Ident); if !ok { return v } - subs(&nt.Name,ntspecs) + nt, ok := n.Elt.(*ast.Ident); if !ok { return g } + subs(&nt.Name,g.nspecs) case *ast.Field: - nt, ok := n.Type.(*ast.Ident); if !ok { return v } - subs(&nt.Name,ntspecs) + nt, ok := n.Type.(*ast.Ident); if !ok { return g } + subs(&nt.Name,g.nspecs) case *ast.StarExpr: - nt, ok := n.X.(*ast.Ident); if !ok { return v } - subs(&nt.Name,ntspecs) - case *ast.TypeSpec: - //ast.Print(gfset,n) + nt, ok := n.X.(*ast.Ident); if !ok { return g } + subs(&nt.Name,g.nspecs) case *ast.ValueSpec: - nt, ok := n.Type.(*ast.Ident); if !ok { return v } - subs(&nt.Name,ntspecs) + nt, ok := n.Type.(*ast.Ident); if !ok { return g } + subs(&nt.Name,g.nspecs) case *ast.ArrayType: - nt, ok := n.Elt.(*ast.Ident); if !ok { return v } - subs(&nt.Name,ntspecs) + nt, ok := n.Elt.(*ast.Ident); if !ok { return g } + subs(&nt.Name,g.nspecs) case *ast.MapType: nt, ok := n.Key.(*ast.Ident); if ok { - subs(&nt.Name,ntspecs) + subs(&nt.Name,g.nspecs) } nt, ok = n.Value.(*ast.Ident); if ok { - subs(&nt.Name,ntspecs) + subs(&nt.Name,g.nspecs) } case *ast.FuncDecl: - sub(&n.Name.Name,fspec) - //ast.Print(gfset,n) + sub(&n.Name.Name,g.fspec) + //ast.Print(g.fset,n) } - return v + return g } func init() { - firstone = true err := os.Remove("pgen.go") if err != nil && os.IsExist(err) { log.Fatal("Removing pgen.go:",err) } } -func add(newName, typName, typ string) { - fspec = subspec{regexp.MustCompile("New"), newName} - nspec = subspec{regexp.MustCompile("_N"), typName} - tspec = subspec{regexp.MustCompile("_T"), typ} - ntspecs = []subspec{nspec,tspec} +func NewGenerator() *Generator { + g := &Generator{} + g.firstone = true + g.fset = token.NewFileSet() + return g +} - gfset = token.NewFileSet() - f, err := parser.ParseFile(gfset, "", template, 0) +func (g *Generator) Add(newName, typName, typ string) { + g.fspec = subspec{regexp.MustCompile("New"), newName} + g.nspecs = []subspec{ + subspec{regexp.MustCompile("_N"), typName}, + subspec{regexp.MustCompile("_T"), typ}} + + f, err := parser.ParseFile(g.fset, "", template, 0) if err != nil { log.Fatal("Parsing persist/template.T:",err) } - var v = &gvisit{} - ast.Walk(v,f) + ast.Walk(g,f) - if firstone { - nodes = f + if g.firstone { + g.nodes = f } else { for _,decl := range(f.Decls) { imp, ok := decl.(*ast.GenDecl); if ok { @@ -118,20 +118,27 @@ func add(newName, typName, typ string) { continue // skip imports } } - nodes.Decls = append(nodes.Decls,decl) + g.nodes.Decls = append(g.nodes.Decls,decl) } } - firstone = false + g.firstone = false } -func newImportSpec(name string) *ast.ImportSpec { - path := &ast.BasicLit{ Value: name } - ret := &ast.ImportSpec{ Path: path } - return ret +func (g *Generator) Import(imports map[string]bool) { + if g.nodes == nil { + return + } + + // strip quotation marks from imports + qexp := regexp.MustCompile(`"`) + for v := range imports { + v = qexp.ReplaceAllString(v,"") + astutil.AddImport(g.fset, g.nodes, v) + } } -func gen(wantImps []string) { - if nodes == nil { +func (g *Generator) Save() { + if g.nodes == nil { return } of, err := os.Create("pgen.go") @@ -139,13 +146,7 @@ func gen(wantImps []string) { log.Fatal("Cannot open pgen.go") } - qexp := regexp.MustCompile(`"`) - for _,v := range wantImps { - v = qexp.ReplaceAllString(v,"") - astutil.AddImport(gfset, nodes, v) - } - - err = format.Node(of,gfset,nodes) + err = format.Node(of,g.fset,g.nodes) if err != nil { log.Fatal("Generate error:",err) } diff --git a/cmd/pgen/main.go b/cmd/pgen/main.go index 0aa5a47..5561a5e 100644 --- a/cmd/pgen/main.go +++ b/cmd/pgen/main.go @@ -103,30 +103,26 @@ func init() { imps = make(map[string]string) needImps = make(map[string]bool) - ts = make(map[string]tps) fset = token.NewFileSet() } func main() { - // In naming imported identifiers, keep only the last part of the - // path to the imported package. This does not take named imports - // into account but should allow the code generator to find the - // right libary to import and generate valid code. - pkgs, err := parser.ParseDir(fset, ".", nil, 0) if err != nil { log.Fatal("Parse:",err) } + // run type checker on each package for _,pkg := range pkgs { + ts = make(map[string]tps) chkpkg(pkg) + g := NewGenerator() + for _,v := range ts { + g.Add(v.fun, v.name, v.typ) + } + g.Import(needImps) + // process template for each type identified and + // generate output + g.Save() } - addImps := make([]string,0) - for k := range needImps { - addImps = append(addImps,imps[k]) - } - for _,v := range ts { - add(v.fun, v.name, v.typ) - } - gen(addImps) }