From 100f65f83af6bef79b86b60b9a618af73fcb0ae5 Mon Sep 17 00:00:00 2001 From: Greg Date: Sat, 4 Aug 2018 15:41:34 -0400 Subject: [PATCH] pgen: discover types for composite literals. --- cmd/pgen/main.go | 16 ++++++++++++---- test/basic/main.go | 2 ++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/cmd/pgen/main.go b/cmd/pgen/main.go index 6a3e983..6777c7f 100644 --- a/cmd/pgen/main.go +++ b/cmd/pgen/main.go @@ -24,7 +24,7 @@ func (v *visitor) Visit(node ast.Node) ast.Visitor { if !reg.MatchString(id.Name) { return v } if len(x.Args) != 2 { - log.Fatal("Wrong number of arguments in persist_T call.") + log.Fatal("Wrong number of arguments in persistT call.") } var tp string switch arg := x.Args[1].(type) { @@ -33,14 +33,16 @@ func (v *visitor) Visit(node ast.Node) ast.Visitor { if tp == "float" { tp = "float64" } + case *ast.CompositeLit: + tp = types.TypeString(info.TypeOf(arg),types.RelativeTo(pkg)) case *ast.Ident: if arg.Obj.Kind != ast.Var { return v } inner := pkg.Scope().Innermost(arg.Pos()) _,obj := inner.LookupParent(arg.Obj.Name,arg.Pos()) tp = types.TypeString(obj.Type(),types.RelativeTo(pkg)) trimType(&tp) - for _, submatches := range impreg.FindAllStringSubmatchIndex(tp, -1) { - pkgname := impreg.ExpandString(make([]byte,0), "$1", tp, submatches) + for _, s := range impreg.FindAllStringSubmatchIndex(tp, -1) { + pkgname := impreg.ExpandString(make([]byte,0), "$1", tp, s) needImps[string(pkgname)] = true } default: @@ -63,6 +65,7 @@ var pkgreg *regexp.Regexp var impreg *regexp.Regexp var imps map[string]string var needImps map[string]bool +var info *types.Info func chkpkg(p *ast.Package) { fs := make([]*ast.File,0) @@ -72,7 +75,12 @@ func chkpkg(p *ast.Package) { conf := types.Config{Error: func(error) {}, Importer: importer.Default()} var err error - pkg, err = conf.Check("main", fset, fs, nil) + info = &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object)} + + pkg, err = conf.Check("main", fset, fs, info) if err != nil { log.Print("Check:",err) } diff --git a/test/basic/main.go b/test/basic/main.go index 054edc4..4550890 100644 --- a/test/basic/main.go +++ b/test/basic/main.go @@ -34,6 +34,8 @@ func main() { y6 := make(map[persist.Var]gtoken.FileSet) y6p := persistM6("y6",y6) y6 = y6p.Get() + y7p := persistIntArray("y7",[]int{1,2,3}) + _ = y7p z := func(interface{}) { _ = persistFloat("y7",1.0)