commit b880a5671d6105bada73ad215d33ed61bb84c8a2 Author: Greg Date: Fri Aug 3 15:08:30 2018 -0400 Initial commit. diff --git a/cmd/pgen/gen.go b/cmd/pgen/gen.go new file mode 100644 index 0000000..b974515 --- /dev/null +++ b/cmd/pgen/gen.go @@ -0,0 +1,150 @@ +package main + +import ( + "log" + "go/ast" + "go/format" + "go/parser" + "go/token" + "golang.org/x/tools/go/ast/astutil" + "os" + "regexp" +) + +type subspec struct { + r *regexp.Regexp + s string +} + +var ( + firstone bool + gfset *token.FileSet + fspec, tspec, nspec subspec + ntspecs []subspec + nodes *ast.File +) + +type gvisit struct{} + +func subs(dst *string, specs []subspec) { + for _,spec := range(specs) { + sub(dst,spec) + } +} + +func sub(dst *string,spec subspec) { + if spec.r.MatchString(*dst) { + *dst = spec.r.ReplaceAllString(*dst,spec.s) + } +} + +func (v *gvisit) Visit(node ast.Node) ast.Visitor { + switch n := node.(type) { + case *ast.Ident: + if n.Obj != nil && n.Obj.Kind == ast.Typ { + subs(&n.Name,ntspecs) + subs(&n.Obj.Name,ntspecs) + } + case *ast.TypeSwitchStmt: + for _,s := range(n.Body.List) { + cs, ok := s.(*ast.CaseClause); if !ok { continue } + for _,c := range(cs.List) { + ci, ok := c.(*ast.Ident); if !ok { continue } + subs(&ci.Name,ntspecs) + } + } + case *ast.TypeAssertExpr: + nt, ok := n.Type.(*ast.Ident); if !ok { return v } + subs(&nt.Name,ntspecs) + case *ast.Ellipsis: + nt, ok := n.Elt.(*ast.Ident); if !ok { return v } + subs(&nt.Name,ntspecs) + case *ast.Field: + nt, ok := n.Type.(*ast.Ident); if !ok { return v } + subs(&nt.Name,ntspecs) + case *ast.StarExpr: + nt, ok := n.X.(*ast.Ident); if !ok { return v } + subs(&nt.Name,ntspecs) + case *ast.TypeSpec: + //ast.Print(gfset,n) + case *ast.ValueSpec: + nt, ok := n.Type.(*ast.Ident); if !ok { return v } + subs(&nt.Name,ntspecs) + case *ast.ArrayType: + nt, ok := n.Elt.(*ast.Ident); if !ok { return v } + subs(&nt.Name,ntspecs) + case *ast.MapType: + nt, ok := n.Key.(*ast.Ident); if ok { + subs(&nt.Name,ntspecs) + } + nt, ok = n.Value.(*ast.Ident); if ok { + subs(&nt.Name,ntspecs) + } + case *ast.FuncDecl: + sub(&n.Name.Name,fspec) + //ast.Print(gfset,n) + } + return v +} + +func init() { + firstone = true + err := os.Remove("pgen.go") + if err != nil && os.IsExist(err) { + log.Fatal("Removing pgen.go:",err) + } +} + +func add(newName, typName, typ string) { + fspec = subspec{regexp.MustCompile("New"), newName} + nspec = subspec{regexp.MustCompile("_N"), typName} + tspec = subspec{regexp.MustCompile("_T"), typ} + ntspecs = []subspec{nspec,tspec} + + gfset = token.NewFileSet() + f, err := parser.ParseFile(gfset, "", template, 0) + if err != nil { + log.Fatal("Parsing persist/template.T:",err) + } + var v = &gvisit{} + ast.Walk(v,f) + + if firstone { + nodes = f + } else { + for _,decl := range(f.Decls) { + imp, ok := decl.(*ast.GenDecl); if ok { + if imp.Tok == token.IMPORT { + continue // skip imports + } + } + nodes.Decls = append(nodes.Decls,decl) + } + } + firstone = false +} + +func newImportSpec(name string) *ast.ImportSpec { + path := &ast.BasicLit{ Value: name } + ret := &ast.ImportSpec{ Path: path } + return ret +} + +func gen(wantImps []string) { + of, err := os.Create("pgen.go") + if err != nil { + log.Fatal("Cannot open pgen.go") + } + + qexp := regexp.MustCompile(`"`) + for _,v := range wantImps { + v = qexp.ReplaceAllString(v,"") + astutil.AddImport(gfset, nodes, v) + } + + err = format.Node(of,gfset,nodes) + if err != nil { + log.Fatal("Generate error:",err) + } +} + diff --git a/cmd/pgen/main.go b/cmd/pgen/main.go new file mode 100644 index 0000000..6a3e983 --- /dev/null +++ b/cmd/pgen/main.go @@ -0,0 +1,124 @@ +package main + +import ( + "go/ast" + "go/importer" + "go/parser" + "go/token" + "go/types" + "log" + "regexp" + "strings" +) + +type visitor struct{} + +func trimType(n *string) { + *n = pkgreg.ReplaceAllString(*n,"") +} + +func (v *visitor) Visit(node ast.Node) ast.Visitor { + x, ok := node.(*ast.CallExpr); if !ok { return v } + id,ok := x.Fun.(*ast.Ident); if !ok { return v } + + if !reg.MatchString(id.Name) { return v } + + if len(x.Args) != 2 { + log.Fatal("Wrong number of arguments in persist_T call.") + } + var tp string + switch arg := x.Args[1].(type) { + case *ast.BasicLit: + tp = strings.ToLower(arg.Kind.String()) + if tp == "float" { + tp = "float64" + } + case *ast.Ident: + if arg.Obj.Kind != ast.Var { return v } + inner := pkg.Scope().Innermost(arg.Pos()) + _,obj := inner.LookupParent(arg.Obj.Name,arg.Pos()) + tp = types.TypeString(obj.Type(),types.RelativeTo(pkg)) + trimType(&tp) + for _, submatches := range impreg.FindAllStringSubmatchIndex(tp, -1) { + pkgname := impreg.ExpandString(make([]byte,0), "$1", tp, submatches) + needImps[string(pkgname)] = true + } + default: + } + + name := reg.ReplaceAllString(id.Name,"$1") + ts[name] = tps{fun: id.Name, name: name, typ: tp} + return nil +} + +type tps struct { + fun,name,typ string +} + +var fset *token.FileSet +var pkg *types.Package +var ts map[string]tps +var reg *regexp.Regexp +var pkgreg *regexp.Regexp +var impreg *regexp.Regexp +var imps map[string]string +var needImps map[string]bool + +func chkpkg(p *ast.Package) { + fs := make([]*ast.File,0) + for _,f := range p.Files { + fs = append(fs,f) + } + + conf := types.Config{Error: func(error) {}, Importer: importer.Default()} + var err error + pkg, err = conf.Check("main", fset, fs, nil) + if err != nil { + log.Print("Check:",err) + } + var v = &visitor{} + for _,f := range fs { + for _,is := range f.Imports { + name := is.Path.Value + shortName := pkgreg.ReplaceAllString(name,"") + imps[shortName] = name + } + ast.Walk(v,f) + } +} + +func init() { + reg = regexp.MustCompile("^persist([a-zA-Z_]+[a-zA-Z_0-9]*)") + + pkgreg = regexp.MustCompile(`[a-zA-Z_]+[a-zA-Z_0-9\.]*/|"`) + impreg = regexp.MustCompile(`([a-zA-Z_]+[a-zA-Z_0-9]+)\.[a-zA-Z_]+[a-zA-Z_0-9]+`) + + imps = make(map[string]string) + needImps = make(map[string]bool) + ts = make(map[string]tps) + fset = token.NewFileSet() +} + +func main() { + // In naming imported identifiers, keep only the last part of the + // path to the imported package. This does not take named imports + // into account but should allow the code generator to find the + // right libary to import and generate valid code. + + pkgs, err := parser.ParseDir(fset, ".", nil, 0) + if err != nil { + log.Fatal("Parse:",err) + } + for _,pkg := range pkgs { + chkpkg(pkg) + } + addImps := make([]string,0) + for k := range needImps { + addImps = append(addImps,imps[k]) + } + for _,v := range ts { + add(v.fun, v.name, v.typ) + } + gen(addImps) +} + diff --git a/cmd/pgen/template.go b/cmd/pgen/template.go new file mode 100644 index 0000000..ae414b0 --- /dev/null +++ b/cmd/pgen/template.go @@ -0,0 +1,36 @@ +package main +const template string = `package main + +import ( + "time" + "unsafe" + "gitlab.wow.st/gmp/persist" +) + +type Var_N persist.Var + +func New(name string, xs ..._T) *Var_N { + var x _T + if len(xs) > 0 { + x = xs[0] + } + ptr := persist.New(name, x) + ret := (*Var_N)(unsafe.Pointer(ptr)) + return ret +} + +func (v *Var_N) Set(x _T) error { + ptr := (*persist.Var)(unsafe.Pointer(v)) + return ptr.Set(x) +} + +func (v *Var_N) Get(ts ...time.Time) _T { + ptr := (*persist.Var)(unsafe.Pointer(v)) + return ptr.Get(ts...).(_T) +} + +func (v *Var_N) History() ([]persist.TVar, error) { + ptr := (*persist.Var)(unsafe.Pointer(v)) + return ptr.History() +} +` diff --git a/persist.go b/persist.go new file mode 100644 index 0000000..26d5922 --- /dev/null +++ b/persist.go @@ -0,0 +1,487 @@ +package persist + +import ( + "bytes" + "encoding/binary" + "encoding/gob" + "errors" + "fmt" + "log" + "os" + "path" + "reflect" + "runtime" + "sync" + "time" + + bolt "github.com/coreos/bbolt" +) + +// allow persistant storage and reloading of variables to bolt database. +// everything is time stamped. +// bucket is the variable name + +type Var struct { + X interface{} + name string + opt Option +} + +type Option struct { + Permanent bool +} + +type Config struct { + configured bool + MaxVars int + DBPath string +} + +func Init(x ...Config) { + mux.Lock() + c := Config{MaxVars: 64, DBPath: "db"} // default config + conf = c + if len(x) > 0 { + c = x[0] + } + if c.MaxVars != 0 { + conf.MaxVars = c.MaxVars + conf.configured = true + } + if c.DBPath != "" { + conf.DBPath = c.DBPath + conf.configured = true + } + mux.Unlock() +} + +func New(name string, x interface{},opt ...Option) *Var { + start() + mux.Lock() + if vars == nil { + vars = make(map[string]*Var) + } + ret := vars[name] + if ret == nil { + ret = &Var{X: x, name: name} + vars[name] = ret + } + if len(opt) == 1 { + ret.opt = opt[0] + } + mux.Unlock() + err := ret.Load() + if err != nil { // save default value if loading failed + fmt.Println("Load error:",err) + wchan <- encode(ret) + } + fmt.Println("New(): ",ret.name) + mux.Lock() + if loaded == nil { + loaded = make(chan *Var,conf.MaxVars) // buffered channel + } + mux.Unlock() + loaded<- ret + return ret +} + +func (p *Var) Set(x interface{}) error { + if reflect.TypeOf(x) == reflect.TypeOf(p.X) { + p.X = x + p.Save() + return nil + } else { + return errors.New("Set(): Type mismatch") + } +} + +func (p *Var) SaveSync() { + donech := make(chan struct{}) + wchan <- encode(p,donech) + <-donech +} + +func (p *Var) Save(sync ...bool) { + dirty = true + wchan <- encode(p) +} + +func (p *Var) Load(ts ...time.Time) error { + var err error + var v []byte + if len(ts) == 0 { + v, err = retrieve(p.name,lastCheckpoint) + } else { + v, err = retrieve(p.name,ts[0]) + } + if err == nil { + var p2 Var + dec := gob.NewDecoder(bytes.NewBuffer(v)) + err := dec.Decode(&p2) + if reflect.TypeOf(p2.X) != reflect.TypeOf(p.X) { + err = errors.New("Load(): Type mismatch") + } + if err != nil { + fmt.Println("Load(): ",err) + } else { + p.X = p2.X + } + } else { + fmt.Println("Load(",p.name,"): type =",reflect.TypeOf(p.X)) + } + return err +} + +func (p *Var) Get(ts ...time.Time) interface{} { + if len(ts) == 0 { + return p.X + } else { + p2 := *p + p2.Load(ts[0]) + return p2.X + } +} + +func (p *Var) Delete(t time.Time) { + wchan<- dencode(p,t) +} + +func (p *Var) DeleteSync(t time.Time) { + donech := make(chan struct{}) + wchan<- dencode(p,t,donech) + <-donech +} + +func Commit() { + mux.Lock() + if dirty { + dirty = false + __checkpoints.SaveSync() + } + mux.Unlock() +} + +func Checkpoints() { + fmt.Println(__checkpoints.History()) +} + +func Undo() { + fmt.Println("Undo()") + mux.Lock() + __checkpoints.DeleteSync(lastCheckpoint) + h,_ := __checkpoints.History() + var t time.Time + if len(h) > 0 { + t = h[len(h)-1].Time + } else { + return + } + lastCheckpoint = t + for _,v := range vars { + v.Load(t) + } + dirty = false + mux.Unlock() +} + +func Shutdown() { + close(tchan) + tg.Wait() + close(wchan) + wg.Wait() + mux.Lock() + launched = false + vars = make(map[string]*Var) + mux.Unlock() +} + +var ( + launched bool + wchan chan encoded + tchan chan struct{} + mux sync.Mutex + wmux sync.Mutex + wg sync.WaitGroup + tg sync.WaitGroup + db *bolt.DB + loaded chan *Var + conf Config + __checkpoints *Var + lastCheckpoint time.Time + vars map[string]*Var + dirty bool +) + + +type encoded struct { + name []byte + value interface{} + donech chan struct{} +} + +func encode(p *Var,chs ...chan struct{}) encoded { + ret := encoded{name:[]byte(p.name)} + if len(chs) > 0 { + ret.donech = chs[0] + } + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + err := enc.Encode(*p) + if err != nil { + fmt.Println("encode(): ",err) + } + ret.value = buf.Bytes() + + return ret +} + +func dencode(p *Var,t time.Time,chs ...chan struct{}) encoded { + ret := encoded{name:[]byte(p.name),value:t} + if len(chs) > 0 { + ret.donech = chs[0] + } + return ret +} + +// tencode creates a DB key based on the time provided +func tencode(t time.Time) []byte { + buf := make([]byte, 8) + ns := t.UnixNano() + binary.BigEndian.PutUint64(buf, uint64(ns)) + + return buf +} + +func tdecode(x []byte) time.Time { + ns := int64(binary.BigEndian.Uint64(x)) + return time.Unix(0,ns) +} + +func start() { + mux.Lock() + if launched { + mux.Unlock() + return + } + launched = true + + if conf.configured == false { + Init() + } + + var err error + dbdir := path.Dir(conf.DBPath) + if _, err = os.Stat(dbdir); os.IsNotExist(err) { + log.Print("creating db dir") + err = os.MkdirAll(dbdir,0700) + if err != nil { + log.Fatal(err) + } + } + db, err = bolt.Open(conf.DBPath, 0600, &bolt.Options{Timeout:5 * time.Second}) + + if err != nil { + log.Fatal("persist.start() bolt.Open(): ", err) + } + + tchan = make(chan struct{}) + tg.Add(1) + go tidyThread() + + wchan = make(chan encoded) + wg.Add(1) + go writer() + + __checkpoints = &Var{name:"__checkpoints", X: nil} + h, err := __checkpoints.History() + if len(h) > 0 { + lastCheckpoint = h[len(h)-1].Time + fmt.Println("lastCheckpoint = ",lastCheckpoint) + } + mux.Unlock() +} + +func writer() { + fmt.Println("launching writer thread") + runtime.UnlockOSThread() + for p := range wchan { + wmux.Lock() + err := db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket(p.name) + if b == nil { + _, err := tx.CreateBucket(p.name) + if err != nil { + log.Fatal("Cannot create bucket for ",string(p.name)) + } + fmt.Println("writer(): created bucket for ",string(p.name)) + b = tx.Bucket(p.name) + if b == nil { + log.Fatal("Bucket error") + } + } + var err error + switch v := p.value.(type) { + case time.Time: + vv := b.Get(tencode(v)) + if len(vv) == 0 { + return errors.New("writer() error: can't delete") + } + err = b.Delete(tencode(v)) + case []byte: + t := time.Now() + if string(p.name) == "__checkpoints" { + fmt.Println("saving: ",string(p.name),t) + lastCheckpoint = t + } + err = b.Put(tencode(t),v) + } + if err != nil { + log.Fatal("writer error", err) + } + return nil + }) + wmux.Unlock() + if err != nil { + fmt.Println("writer() error: ",err) + } + if p.donech != nil { // close out synchronous writes + close(p.donech) + } + } + fmt.Println("leaving writer thread") + db.Close() + fmt.Println("DB is closed") + wg.Done() +} + +type TVar struct { + Time time.Time + Value Var +} + +func (p *Var) History() ([]TVar, error) { + ret := make([]TVar,0) + wmux.Lock() + err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(p.name)) + if b == nil { + return errors.New("DB entry not found") + } + c := b.Cursor() + for k, x := c.First(); x != nil; k, x = c.Next() { + t := tdecode(k) + p2 := Var{} + dec := gob.NewDecoder(bytes.NewBuffer(x)) + err := dec.Decode(&p2) + if err != nil { + return err + } + if reflect.TypeOf(p2.X) != reflect.TypeOf(p.X) { + return errors.New("History(): Type mismatch") + } + p3 := *p + p3.X = p2.X + ret = append(ret,TVar{t,p3}) + } + return nil + }) + wmux.Unlock() + return ret, err +} + +// retrieve a variable from the database. If ts is provided, return the first version +// that is not after ts[0]. +func retrieve(name string, ts ...time.Time) ([]byte, error) { + var ret []byte + err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(name)) + if b == nil { + return errors.New("DB entry not found") + } + c := b.Cursor() + if len(ts) == 0 { + _, ret = c.Last() + } else { + var k []byte + for k, ret = c.Last(); ret != nil; k, ret = c.Prev() { + if ts[0].After(tdecode(k)) { + return nil + } + } + return errors.New("Not found") + } + return nil + }) + return ret, err +} + +func tidyThread() { + fmt.Println("tidyThread() starting") + lchan := make(chan string) + go func() { + fmt.Println("starting loaded reading thread") + if loaded == nil { + loaded = make(chan *Var,conf.MaxVars) // buffered channel + } + for { + for p := range loaded { + if p.opt.Permanent == false { + lchan<- p.name + loaded<- p // put it back in the queue + } + } + fmt.Println("closing lchan") + close(lchan) + } + }() + tick := time.NewTicker(time.Second * 10) + for { select { + case <-tchan: + fmt.Println("closing loaded channel") + close(loaded) + tg.Done() + fmt.Println("tidy(): returning") + return + case <-tick.C: + name := <-lchan + fmt.Println("tidyThread(): tidying ", name) + Tidy(name) + } } +} + +const expiration time.Duration = -24 * time.Hour + +// discard entries if they are too old. +func Tidy(name string) { + stime := time.Now() + etime := stime.Add(expiration) + mux.Lock() + i := 0 + db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(name)) + if b == nil { + fmt.Println("Can't open bucket for ",name) + return nil + } + c := b.Cursor() + end,_ := c.Last() // always keep the most recent on + for time.Since(stime) < 10 * time.Millisecond { + k,_ := c.First() + if tdecode(k).After(etime) { + return nil // entry not expired + } + if reflect.DeepEqual(k,end) { + return nil // always keep the most recent entry + } + b.Delete(k) + i++ + } + return nil + }) + mux.Unlock() + if i > 0 { + fmt.Println("tidy(): deleted ",i," entries in ",time.Since(stime)) + } + fmt.Println("tidy(): done tidying",name) +} + diff --git a/test/basic/main.go b/test/basic/main.go new file mode 100644 index 0000000..72bb974 --- /dev/null +++ b/test/basic/main.go @@ -0,0 +1,45 @@ +package main +//go:generate pgen + +import ( + "fmt" + "go/token" + "gitlab.wow.st/gmp/persist" +) + +func main() { + persist.Init() + x := persistInt("x",5) + var y1 mine1 = 6 + y2 := mine2{} + y1p := persistM1("y",y1) + y2p := persistM2("y",y2) + y3 := token.FileSet{} + y3p := persistM3("y",y3) + y4 := make(map[*token.FileSet]*token.File) + y4p := persistM4("y",y4) + y5 := make(map[*token.FileSet]*persist.Var) + y5p := persistM5("y",y5) + y6 := make(map[persist.Var]token.FileSet) + y6p := persistM6("y",y6) + _,_,_,_,_,_ = y1p, y2p, y3p, y4p, y5p, y6p + fmt.Println(x) + fmt.Println(x.Get()) + x.Set(3) + fmt.Println(x) + + z := func(interface{}) { + _ = persistFloat("y",1.0) + } + z(persistString("name","ok")) + + s := persistString("s","ok bye") + fmt.Println(s) + s.Set("one two") + var ts string + + ts = s.Get() // this works + _ = ts + persist.Commit() +} + diff --git a/test/basic/types.go b/test/basic/types.go new file mode 100644 index 0000000..e5c6292 --- /dev/null +++ b/test/basic/types.go @@ -0,0 +1,7 @@ +package main + +type mine1 int +type mine2 struct { + a int + b int +}