diff --git a/.gitignore b/.gitignore index bc6cb7f..d13c37d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ test/basic/basic test/basic/pgen.go +test/basic/pgen_aaa.go test/basic/db diff --git a/cmd/pgen/gen.go b/cmd/pgen/gen.go index e2e9776..62c9ed7 100644 --- a/cmd/pgen/gen.go +++ b/cmd/pgen/gen.go @@ -1,6 +1,7 @@ package main import ( + "path/filepath" "log" "go/ast" "go/format" @@ -9,6 +10,7 @@ import ( "golang.org/x/tools/go/ast/astutil" "os" "regexp" + "strings" ) type subspec struct { @@ -22,6 +24,7 @@ type Generator struct{ fspec, tspec, nspec subspec nspecs []subspec nodes *ast.File + name string } func subs(dst *string, specs []subspec) { @@ -83,17 +86,33 @@ func (g *Generator) Visit(node ast.Node) ast.Visitor { return g } +var ( + qexp *regexp.Regexp +) + func init() { - err := os.Remove("pgen.go") - if err != nil && os.IsExist(err) { - log.Fatal("Removing pgen.go:",err) + fs,err := filepath.Glob("pgen*.go") + if err != nil { + log.Fatal("Glob error") } + for _,f := range(fs) { + err := os.Remove(f) + if err != nil && os.IsExist(err) { + log.Fatal("Removing ",f,err) + } + } + qexp = regexp.MustCompile(`"`) } -func NewGenerator() *Generator { +func NewGenerator(ns ...string) *Generator { g := &Generator{} g.firstone = true g.fset = token.NewFileSet() + if len(ns) > 0 { + g.name = ns[0] + } else { + g.name = "main" + } return g } @@ -124,24 +143,29 @@ func (g *Generator) Add(newName, typName, typ string) { g.firstone = false } -func (g *Generator) Import(imports map[string]bool) { +func (g *Generator) Import(name string) { if g.nodes == nil { return } // strip quotation marks from imports - qexp := regexp.MustCompile(`"`) - for v := range imports { - v = qexp.ReplaceAllString(v,"") - astutil.AddImport(g.fset, g.nodes, v) - } + name = qexp.ReplaceAllString(name,"") + astutil.AddImport(g.fset, g.nodes, name) } func (g *Generator) Save() { if g.nodes == nil { return } - of, err := os.Create("pgen.go") + var of *os.File + var err error + if g.name == "main" { + of, err = os.Create("pgen.go") + } else { + oname := strings.Join([]string{"pgen_",g.name,".go"},"") + of, err = os.Create(oname) + g.nodes.Name = ast.NewIdent(g.name) + } if err != nil { log.Fatal("Cannot open pgen.go") } diff --git a/cmd/pgen/main.go b/cmd/pgen/main.go index 5561a5e..e9b8cd0 100644 --- a/cmd/pgen/main.go +++ b/cmd/pgen/main.go @@ -113,13 +113,16 @@ func main() { } // run type checker on each package for _,pkg := range pkgs { + log.Print("Processing package ",pkg.Name) ts = make(map[string]tps) chkpkg(pkg) - g := NewGenerator() + g := NewGenerator(pkg.Name) for _,v := range ts { g.Add(v.fun, v.name, v.typ) } - g.Import(needImps) + for name := range needImps { + g.Import(imps[name]) + } // process template for each type identified and // generate output g.Save() diff --git a/test/basic/other.go b/test/basic/other.go new file mode 100644 index 0000000..651827e --- /dev/null +++ b/test/basic/other.go @@ -0,0 +1,7 @@ +package aaa + +func Hi() { + y := persistFoo("y",7) + _ = y +} +