package main import ( "go/ast" "go/importer" "go/parser" "go/token" "go/types" "log" "os" "path/filepath" "regexp" "strings" "gitlab.wow.st/gmp/persist/generate" ) func trimType(n *string) { *n = pkgreg.ReplaceAllString(*n,"") *n = treg.ReplaceAllString(*n,"") } func (v *visitor) Visit(node ast.Node) ast.Visitor { x, ok := node.(*ast.CallExpr); if !ok { return v } id,ok := x.Fun.(*ast.Ident); if !ok { return v } if !reg.MatchString(id.Name) { return v } if len(x.Args) < 2 { log.Fatal("Wrong number of arguments in persistT call.") } var tp string switch arg := x.Args[1].(type) { case *ast.BasicLit: tp = strings.ToLower(arg.Kind.String()) if tp == "float" { tp = "float64" } case *ast.CompositeLit: tp = types.TypeString(v.info.TypeOf(arg),types.RelativeTo(v.pkg)) case *ast.Ident: tp = types.TypeString(v.info.TypeOf(arg),types.RelativeTo(v.pkg)) trimType(&tp) for _, s := range impreg.FindAllStringSubmatchIndex(tp, -1) { pkgname := impreg.ExpandString(make([]byte,0), "$1", tp, s) v.needImps[string(pkgname)] = true } default: } name := reg.ReplaceAllString(id.Name,"$1") v.ts[name] = tps{fun: id.Name, name: name, typ: tp} return nil } type tps struct { fun,name,typ string } type visitor struct { pkg *types.Package ts map[string]tps imps map[string]string needImps map[string]bool info *types.Info } var ( fset *token.FileSet reg, pkgreg, impreg, treg *regexp.Regexp ) func newVisitor() *visitor { var v = &visitor{} v.imps = make(map[string]string) v.needImps = make(map[string]bool) v.info = &types.Info{ Types: make(map[ast.Expr]types.TypeAndValue), Defs: make(map[*ast.Ident]types.Object), Uses: make(map[*ast.Ident]types.Object)} v.ts = make(map[string]tps) v.needImps = make(map[string]bool) return v } func chkpkg(p *ast.Package) (map[string]tps, []string) { fs := make([]*ast.File,0) for _,f := range p.Files { fs = append(fs,f) } conf := types.Config{Error: func(error) {}, Importer: importer.Default()} v := newVisitor() var err error v.pkg, err = conf.Check("main", fset, fs, v.info) if err != nil { log.Print("Check:",err) } for _,f := range fs { for _,is := range f.Imports { name := is.Path.Value shortName := pkgreg.ReplaceAllString(name,"") v.imps[shortName] = name } ast.Walk(v,f) } imps := make([]string,0) for shortName := range v.needImps { imps = append(imps,v.imps[shortName]) } return v.ts, imps } func init() { reg = regexp.MustCompile("^persist([a-zA-Z_]+[a-zA-Z_0-9]*)") pkgreg = regexp.MustCompile(`[a-zA-Z_]+[a-zA-Z_0-9\.]*/|"`) impreg = regexp.MustCompile(`([a-zA-Z_]+[a-zA-Z_0-9]+)\.[a-zA-Z_]+[a-zA-Z_0-9]+`) treg = regexp.MustCompile(`untyped `) fset = token.NewFileSet() // clear out old generated files fs,err := filepath.Glob("pgen*.go") if err != nil { log.Fatal("Glob error") } for _,f := range(fs) { err := os.Remove(f) if err != nil && os.IsExist(err) { log.Fatal("Removing ",f,err) } } } func main() { pkgs, err := parser.ParseDir(fset, ".", nil, 0) if err != nil { log.Fatal("Parse:",err) } // run type checker on each package for _,pkg := range pkgs { log.Print("Processing package ",pkg.Name) ts,imps := chkpkg(pkg) if len(ts) == 0 { continue } g := generate.NewGenerator(pkg.Name) for _,v := range ts { g.Add(v.fun, v.name, v.typ) } g.Import(imps...) var of *os.File var err error if pkg.Name == "main" { of, err = os.Create("pgen.go") } else { oname := strings.Join([]string{"pgen_",pkg.Name,".go"},"") of, err = os.Create(oname) } if err != nil { log.Fatal("Cannot open output file: ",err) } // process template for each type identified and // generate output g.Save(of) } }