persist/cmd/pgen/main.go
2018-08-03 15:08:30 -04:00

125 lines
2.8 KiB
Go

package main
import (
	"go/ast"
	"go/importer"
	"go/parser"
	"go/token"
	"go/types"
	"log"
	"regexp"
	"sort"
	"strings"
)
// visitor is the ast.Visitor used to scan parsed files for persist<Name>
// calls. It holds no state of its own; whatever it finds is recorded in
// package-level variables.
type (
	visitor struct{}
)
// trimType rewrites the type string pointed to by n in place, deleting
// every substring matched by pkgreg. This removes double quotes and leading
// import-path directories, so a qualifier such as "a/b/c.T" becomes "c.T".
func trimType(n *string) {
	stripped := pkgreg.ReplaceAllLiteralString(*n, "")
	*n = stripped
}
// Visit implements ast.Visitor.
//
// It looks for call expressions whose callee is a plain identifier matching
// reg (persist<Name>). For each one, it works out a type name for the
// call's second argument and records the call in ts under <Name>. Package
// qualifiers appearing in that type name are added to needImps.
//
// Supported second arguments:
//   - basic literals, mapped to the default type of an untyped constant of
//     that kind (int, float64, complex128, rune, string);
//   - identifiers that resolve, in the type-checked package pkg, to a
//     variable; the variable's type is used.
//
// Any other expression is recorded with an empty type. A matching call with
// other than two arguments terminates the program. After recording a call,
// Visit returns nil, so the call's children are not walked.
func (v *visitor) Visit(node ast.Node) ast.Visitor {
	call, ok := node.(*ast.CallExpr)
	if !ok {
		return v
	}
	id, ok := call.Fun.(*ast.Ident)
	if !ok || !reg.MatchString(id.Name) {
		return v
	}
	if len(call.Args) != 2 {
		log.Fatalf("%s: wrong number of arguments in %s call: got %d, want 2",
			fset.Position(call.Pos()), id.Name, len(call.Args))
	}

	var tp string
	switch arg := call.Args[1].(type) {
	case *ast.BasicLit:
		// INT and STRING lower-case directly to the type names "int" and
		// "string". "float", "imag" and "char" are not Go types, so those
		// kinds are mapped explicitly to their default types.
		switch arg.Kind {
		case token.FLOAT:
			tp = "float64"
		case token.IMAG:
			tp = "complex128"
		case token.CHAR:
			tp = "rune"
		default:
			tp = strings.ToLower(arg.Kind.String())
		}
	case *ast.Ident:
		// Resolve through go/types rather than arg.Obj. The parser only
		// resolves identifiers declared in the same file, so arg.Obj is nil
		// for package-level variables declared in another file and for
		// predeclared names such as true or nil.
		if pkg == nil {
			return v
		}
		scope := pkg.Scope().Innermost(arg.Pos())
		if scope == nil {
			scope = pkg.Scope()
		}
		_, found := scope.LookupParent(arg.Name, arg.Pos())
		vr, isVar := found.(*types.Var)
		if !isVar {
			// Not found, or a constant/function/type: not handled here.
			return v
		}
		tp = types.TypeString(vr.Type(), types.RelativeTo(pkg))
		trimType(&tp)
		for _, m := range impreg.FindAllStringSubmatchIndex(tp, -1) {
			pkgname := impreg.ExpandString(nil, "$1", tp, m)
			needImps[string(pkgname)] = true
		}
	default:
		// NOTE(review): other expressions (composite literals, calls,
		// selectors, ...) are recorded with an empty type; confirm that
		// add/gen cope with typ == "".
	}

	name := reg.ReplaceAllString(id.Name, "$1")
	ts[name] = tps{fun: id.Name, name: name, typ: tp}
	return nil
}
// tps describes one persist<Name> call found in the parsed source.
type tps struct {
	fun  string // callee identifier exactly as written
	name string // callee identifier with the "persist" prefix removed
	typ  string // type of the call's second argument; "" if not determined
}
// Package-level state shared by the parse, type-check, visit and generate
// phases. Everything except pkg is initialised in init; pkg is assigned by
// chkpkg for the package currently being processed.
var (
	fset     *token.FileSet    // position information for all parsed files
	pkg      *types.Package    // type-checked package being walked
	ts       map[string]tps    // persist calls found, keyed by name minus "persist"
	reg      *regexp.Regexp    // matches persist<Name> callee identifiers
	pkgreg   *regexp.Regexp    // matches quotes and import-path directories to strip
	impreg   *regexp.Regexp    // matches pkg.Ident qualifiers in type strings
	imps     map[string]string // short package name -> quoted import path
	needImps map[string]bool   // short package names whose imports go to gen
)
// chkpkg type-checks the parsed package p, stores the result in pkg, and
// walks every file with a visitor to collect persist calls.
//
// Type errors are not fatal. The first one is logged, and the (possibly
// incomplete) package is still used for the walk. Every import of every
// file is recorded in imps. The key is the quoted path with quotes and
// leading directories removed by pkgreg; the value is the quoted path.
//
// Files are processed in sorted filename order. p.Files is a map, so
// without sorting the processing order, and therefore which path wins in
// imps when two imports share a short name, would vary from run to run.
func chkpkg(p *ast.Package) {
	names := make([]string, 0, len(p.Files))
	for name := range p.Files {
		names = append(names, name)
	}
	sort.Strings(names)
	files := make([]*ast.File, 0, len(names))
	for _, name := range names {
		files = append(files, p.Files[name])
	}

	// A non-nil Error handler makes Check report every error instead of
	// stopping at the first, so as much of the package as possible gets
	// type information.
	conf := types.Config{Error: func(error) {}, Importer: importer.Default()}
	var err error
	pkg, err = conf.Check("main", fset, files, nil)
	if err != nil {
		log.Printf("check: %v", err)
	}

	v := &visitor{}
	for _, f := range files {
		for _, spec := range f.Imports {
			path := spec.Path.Value // quoted, e.g. "\"go/ast\""
			imps[pkgreg.ReplaceAllString(path, "")] = path
		}
		ast.Walk(v, f)
	}
}
// init compiles the regular expressions used to recognise persist calls
// and to post-process type strings, and allocates the shared maps and the
// file set.
func init() {
	// reg matches callee names of the form persist<Name>; group 1 is <Name>.
	reg = regexp.MustCompile(`^persist([a-zA-Z_]+[a-zA-Z_0-9]*)`)
	// pkgreg matches either one import-path directory (with its trailing
	// slash) or a double quote. Deleting every match turns
	// "example.com/some-org/pkg" (quotes included) into pkg, and
	// example.com/some-org/pkg.T into pkg.T. Path elements may contain
	// digits (even first), '-', '.', '_' and '~'. The first character may
	// not be '.', so a variadic "..." prefix in a type string is kept.
	pkgreg = regexp.MustCompile(`[a-zA-Z0-9_][a-zA-Z0-9_.~-]*/|"`)
	// impreg matches a qualified identifier pkg.Name; group 1 is pkg.
	// Either part may be a single character (e.g. x.T).
	impreg = regexp.MustCompile(`([a-zA-Z_][a-zA-Z_0-9]*)\.[a-zA-Z_][a-zA-Z_0-9]*`)
	imps = make(map[string]string)
	needImps = make(map[string]bool)
	ts = make(map[string]tps)
	fset = token.NewFileSet()
}
// main parses the package(s) in the current directory, type-checks them,
// and collects every persist<Name> call. It then calls add once per call
// found, and gen with the import paths the generated code needs.
//
// Maps are visited in sorted key order, so repeated runs over the same
// input call add and gen in the same order.
func main() {
	pkgs, err := parser.ParseDir(fset, ".", nil, 0)
	if err != nil {
		log.Fatalf("parse: %v", err)
	}
	pkgNames := make([]string, 0, len(pkgs))
	for name := range pkgs {
		pkgNames = append(pkgNames, name)
	}
	sort.Strings(pkgNames)
	for _, name := range pkgNames {
		chkpkg(pkgs[name])
	}

	// In naming imported identifiers, keep only the last part of the
	// path to the imported package. This does not take named imports
	// into account, but it should let the code generator find the right
	// library to import and generate valid code.
	addImps := make([]string, 0, len(needImps))
	for short := range needImps {
		path, found := imps[short]
		if !found {
			// No parsed file imports a package under this short name (for
			// example, a named import or a package reached only through
			// another type). There is no path to give gen, so warn instead
			// of passing an empty string.
			log.Printf("no import found for package %q; skipping", short)
			continue
		}
		addImps = append(addImps, path)
	}
	sort.Strings(addImps)

	keys := make([]string, 0, len(ts))
	for k := range ts {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	for _, k := range keys {
		c := ts[k]
		add(c.fun, c.name, c.typ)
	}
	gen(addImps)
}