diff --git a/cmd/pgen/main.go b/cmd/pgen/main.go index 876e316..4ed92a1 100644 --- a/cmd/pgen/main.go +++ b/cmd/pgen/main.go @@ -15,8 +15,6 @@ import ( "gitlab.wow.st/gmp/persist/generate" ) -type visitor struct{} - func trimType(n *string) { *n = pkgreg.ReplaceAllString(*n,"") *n = treg.ReplaceAllString(*n,"") @@ -39,19 +37,19 @@ func (v *visitor) Visit(node ast.Node) ast.Visitor { tp = "float64" } case *ast.CompositeLit: - tp = types.TypeString(info.TypeOf(arg),types.RelativeTo(pkg)) + tp = types.TypeString(v.info.TypeOf(arg),types.RelativeTo(v.pkg)) case *ast.Ident: - tp = types.TypeString(info.TypeOf(arg),types.RelativeTo(pkg)) + tp = types.TypeString(v.info.TypeOf(arg),types.RelativeTo(v.pkg)) trimType(&tp) for _, s := range impreg.FindAllStringSubmatchIndex(tp, -1) { pkgname := impreg.ExpandString(make([]byte,0), "$1", tp, s) - needImps[string(pkgname)] = true + v.needImps[string(pkgname)] = true } default: } name := reg.ReplaceAllString(id.Name,"$1") - ts[name] = tps{fun: id.Name, name: name, typ: tp} + v.ts[name] = tps{fun: id.Name, name: name, typ: tp} return nil } @@ -59,43 +57,58 @@ type tps struct { fun,name,typ string } -var fset *token.FileSet -var pkg *types.Package -var ts map[string]tps -var reg *regexp.Regexp -var pkgreg *regexp.Regexp -var impreg *regexp.Regexp -var treg *regexp.Regexp -var imps map[string]string -var needImps map[string]bool -var info *types.Info +type visitor struct { + pkg *types.Package + ts map[string]tps + imps map[string]string + needImps map[string]bool + info *types.Info +} -func chkpkg(p *ast.Package) { +var ( + fset *token.FileSet + reg, pkgreg, impreg, treg *regexp.Regexp +) + +func newVisitor() *visitor { + var v = &visitor{} + v.imps = make(map[string]string) + v.needImps = make(map[string]bool) + v.info = &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object)} + v.ts = make(map[string]tps) + v.needImps = make(map[string]bool) + return v +} + +func chkpkg(p *ast.Package) (map[string]tps, []string) { fs := make([]*ast.File,0) for _,f := range p.Files { fs = append(fs,f) } conf := types.Config{Error: func(error) {}, Importer: importer.Default()} + v := newVisitor() var err error - info = &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object)} - - pkg, err = conf.Check("main", fset, fs, info) + v.pkg, err = conf.Check("main", fset, fs, v.info) if err != nil { log.Print("Check:",err) } - var v = &visitor{} for _,f := range fs { for _,is := range f.Imports { name := is.Path.Value shortName := pkgreg.ReplaceAllString(name,"") - imps[shortName] = name + v.imps[shortName] = name } ast.Walk(v,f) } + imps := make([]string,0) + for shortName := range v.needImps { + imps = append(imps,v.imps[shortName]) + } + return v.ts, imps } func init() { @@ -105,8 +118,6 @@ func init() { impreg = regexp.MustCompile(`([a-zA-Z_]+[a-zA-Z_0-9]+)\.[a-zA-Z_]+[a-zA-Z_0-9]+`) treg = regexp.MustCompile(`untyped `) - imps = make(map[string]string) - needImps = make(map[string]bool) fset = token.NewFileSet() // clear out old generated files @@ -130,8 +141,7 @@ func main() { // run type checker on each package for _,pkg := range pkgs { log.Print("Processing package ",pkg.Name) - ts = make(map[string]tps) - chkpkg(pkg) + ts,imps := chkpkg(pkg) if len(ts) == 0 { continue } @@ -139,9 +149,7 @@ func main() { for _,v := range ts { g.Add(v.fun, v.name, v.typ) } - for name := range needImps { - g.Import(imps[name]) - } + g.Import(imps...) var of *os.File var err error if pkg.Name == "main" { diff --git a/generate/generate.go b/generate/generate.go index 537fd2f..5cf188b 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -131,14 +131,16 @@ func (g *Generator) Add(newName, typName, typ string) { g.firstone = false } -func (g *Generator) Import(name string) { +func (g *Generator) Import(ns ...string) { if g.nodes == nil { return } - // strip quotation marks from imports - name = qexp.ReplaceAllString(name,"") - astutil.AddImport(g.fset, g.nodes, name) + for _,n := range ns { + // strip quotation marks from imports + n = qexp.ReplaceAllString(n,"") + astutil.AddImport(g.fset, g.nodes, n) + } } func (g *Generator) Save(of io.Writer) {