From 8e1ae77aad3caaf6c16bf2ad1fd5b51ddff89b9b Mon Sep 17 00:00:00 2001 From: Tzu-Jung Lee Date: Mon, 9 Jul 2018 10:17:33 -0700 Subject: [PATCH] cleanup: bring the merkletrie in as trie. --- claimtrie.go | 30 ++--- cmd/claimtrie/main.go | 9 +- trie/README.md | 47 ++++++++ trie/cmd/triesh/README.md | 171 +++++++++++++++++++++++++++ trie/cmd/triesh/main.go | 242 ++++++++++++++++++++++++++++++++++++++ trie/commit.go | 21 ++++ trie/errors.go | 8 ++ trie/node.go | 96 +++++++++++++++ trie/node_test.go | 139 ++++++++++++++++++++++ trie/stage.go | 61 ++++++++++ trie/stage_test.go | 34 ++++++ trie/test.go | 106 +++++++++++++++++ trie/trie.go | 128 ++++++++++++++++++++ trie/trie_test.go | 75 ++++++++++++ 14 files changed, 1147 insertions(+), 20 deletions(-) create mode 100644 trie/README.md create mode 100644 trie/cmd/triesh/README.md create mode 100644 trie/cmd/triesh/main.go create mode 100644 trie/commit.go create mode 100644 trie/errors.go create mode 100644 trie/node.go create mode 100644 trie/node_test.go create mode 100644 trie/stage.go create mode 100644 trie/stage_test.go create mode 100644 trie/test.go create mode 100644 trie/trie.go create mode 100644 trie/trie_test.go diff --git a/claimtrie.go b/claimtrie.go index b807496..62ad35b 100644 --- a/claimtrie.go +++ b/claimtrie.go @@ -9,7 +9,7 @@ import ( "github.com/lbryio/claimtrie/claim" "github.com/lbryio/claimtrie/claimnode" - "github.com/lbryio/merkletrie" + "github.com/lbryio/claimtrie/trie" ) // ClaimTrie implements a Merkle Trie supporting linear history of commits. @@ -19,33 +19,33 @@ type ClaimTrie struct { bestBlock claim.Height // Immutable linear history. - head *merkletrie.Commit + head *trie.Commit // An overlay supporting Copy-on-Write to the current tip commit. - stg *merkletrie.Stage + stg *trie.Stage // pending keeps track update for future block height. pending map[claim.Height][]string } -// CommitMeta implements merkletrie.CommitMeta with commit-specific metadata. +// CommitMeta implements trie.CommitMeta with commit-specific metadata. type CommitMeta struct { Height claim.Height } // New returns a ClaimTrie. func New() *ClaimTrie { - mt := merkletrie.New() + mt := trie.New() return &ClaimTrie{ - head: merkletrie.NewCommit(nil, CommitMeta{0}, mt), - stg: merkletrie.NewStage(mt), + head: trie.NewCommit(nil, CommitMeta{0}, mt), + stg: trie.NewStage(mt), pending: map[claim.Height][]string{}, } } -func updateStageNode(stg *merkletrie.Stage, name string, modifier func(n *claimnode.Node) error) error { - v, err := stg.Get(merkletrie.Key(name)) - if err != nil && err != merkletrie.ErrKeyNotFound { +func updateStageNode(stg *trie.Stage, name string, modifier func(n *claimnode.Node) error) error { + v, err := stg.Get(trie.Key(name)) + if err != nil && err != trie.ErrKeyNotFound { return err } var n *claimnode.Node @@ -57,7 +57,7 @@ func updateStageNode(stg *merkletrie.Stage, name string, modifier func(n *claimn if err = modifier(n); err != nil { return err } - return stg.Update(merkletrie.Key(name), n) + return stg.Update(trie.Key(name), n) } // AddClaim adds a Claim to the Stage of ClaimTrie. @@ -102,7 +102,7 @@ func (ct *ClaimTrie) SpendSupport(name string, op wire.OutPoint) error { } // Traverse visits Nodes in the Stage of the ClaimTrie. -func (ct *ClaimTrie) Traverse(visit merkletrie.Visit, update, valueOnly bool) error { +func (ct *ClaimTrie) Traverse(visit trie.Visit, update, valueOnly bool) error { return ct.stg.Traverse(visit, update, valueOnly) } @@ -157,7 +157,7 @@ func (ct *ClaimTrie) Commit(h claim.Height) error { // Reset reverts the Stage to a specified commit by height. func (ct *ClaimTrie) Reset(h claim.Height) error { - visit := func(prefix merkletrie.Key, value merkletrie.Value) error { + visit := func(prefix trie.Key, value trie.Value) error { n := value.(*claimnode.Node) return n.DecrementBlock(n.Height() - claim.Height(h)) } @@ -169,7 +169,7 @@ func (ct *ClaimTrie) Reset(h claim.Height) error { if meta.Height <= h { ct.head = commit ct.bestBlock = h - ct.stg = merkletrie.NewStage(commit.MerkleTrie) + ct.stg = trie.NewStage(commit.MerkleTrie) return nil } } @@ -177,6 +177,6 @@ func (ct *ClaimTrie) Reset(h claim.Height) error { } // Head returns the current tip commit in the commit database. -func (ct *ClaimTrie) Head() *merkletrie.Commit { +func (ct *ClaimTrie) Head() *trie.Commit { return ct.head } diff --git a/cmd/claimtrie/main.go b/cmd/claimtrie/main.go index 4fbe8d2..6a9b8cc 100644 --- a/cmd/claimtrie/main.go +++ b/cmd/claimtrie/main.go @@ -17,8 +17,7 @@ import ( "github.com/lbryio/claimtrie" "github.com/lbryio/claimtrie/claim" - - "github.com/lbryio/merkletrie" + "github.com/lbryio/claimtrie/trie" ) var ( @@ -211,7 +210,7 @@ func cmdSpendSupport(c *cli.Context) error { } func cmdShow(c *cli.Context) error { - dump := func(prefix merkletrie.Key, val merkletrie.Value) error { + dump := func(prefix trie.Key, val trie.Value) error { if val == nil { fmt.Printf("%-8s:\n", prefix) return nil @@ -255,13 +254,13 @@ func cmdReset(c *cli.Context) error { } func cmdLog(c *cli.Context) error { - commitVisit := func(c *merkletrie.Commit) { + commitVisit := func(c *trie.Commit) { meta := c.Meta.(claimtrie.CommitMeta) fmt.Printf("height: %d, commit %s\n", meta.Height, c.MerkleTrie.MerkleHash()) } fmt.Printf("\n") - merkletrie.Log(ct.Head(), commitVisit) + trie.Log(ct.Head(), commitVisit) return nil } diff --git a/trie/README.md b/trie/README.md new file mode 100644 index 0000000..710fb50 --- /dev/null +++ b/trie/README.md @@ -0,0 +1,47 @@ +# MerkleTrie + +coming soon + +## Installation + +coming soon + +## Usage + +coming soon + +## Running from Source + +This project requires [Go v1.10](https://golang.org/doc/install) or higher. + +``` bash +go get -v github.com/lbryio/trie +``` + +## Examples + +Refer to [triesh](https://github.com/lbryio/trie/blob/master/cmd/triesh) + +## Testing + +``` bash +go test -v github.com/lbryio/trie +gocov test -v github.com/lbryio/trie 1>/dev/null +``` + +## Contributing + +coming soon + +## License + +This project is MIT licensed. + +## Security + +We take security seriously. Please contact security@lbry.io regarding any security issues. +Our PGP key is [here](https://keybase.io/lbry/key.asc) if you need it. + +## Contact + +The primary contact for this project is [@lyoshenka](https://github.com/lyoshenka) (grin@lbry.io) \ No newline at end of file diff --git a/trie/cmd/triesh/README.md b/trie/cmd/triesh/README.md new file mode 100644 index 0000000..3354d7b --- /dev/null +++ b/trie/cmd/triesh/README.md @@ -0,0 +1,171 @@ +# Triesh + +An example Key-Value store to excercise the merkletree package + +Currently, it's only in-memory. + +## Installation + +This project requires [Go v1.10](https://golang.org/doc/install) or higher. + +``` bash +go get -v github.com/lbryio/trie +``` + +## Usage + +Adding values. + +``` bloocks +triesh > u -k alex -v lion +alex=lion +triesh > u -k al -v tiger +al=tiger +triesh > u -k tess -v dolphin +tess=dolphin +triesh > u -k bob -v pig +bob=pig +triesh > u -k ted -v do +ted=do +triesh > u -k ted -v dog +ted=dog +``` + +Showing Merkle Hash. + +``` blocks +triesh > merkle +bfa2927b147161146411b7f6187e1ed0c08c3dc19b200550c3458d44c0032285 + +triesh > u -k teddy -v bear +teddy=bear + +triesh > merkle +94831650b8bf76d579ca4eda1cb35861c6f5c88eb4f5b089f60fe687defe8f3d +``` + +Showing all values. + +``` blocks +triesh > s +[al ] tiger +[alex ] lion +[bob ] pig +[ted ] dog +[teddy ] bear +[tess ] dolphin +``` + +Showing all values and link nodes. + +``` bloocks +triesh > s -a +[a ] +[al ] tiger +[ale ] +[alex ] lion +[b ] +[bo ] +[bob ] pig +[t ] +[te ] +[ted ] dog +[tedd ] +[teddy ] bear +[tes ] +[tess ] dolphin +``` + +Deleting values (setting key to nil / ""). + +``` blocks +triesh > u -k al +al= +triesh > u -k alex +alex= +``` + +Updating Values. + +``` blocks +triesh > u -k bob -v cat +bob=cat +``` + +Showing all nodes, include non-pruned link nodes" + +``` blocks +triesh > s -a +[a ] +[al ] +[ale ] +[alex ] +[b ] +[bo ] +[bob ] cat +[t ] +[te ] +[ted ] dog +[tedd ] +[teddy ] bear +[tes ] +[tess ] dolphin + +``` + +Calculate Merkle Hash. + +``` blocks +triesh > merkle +c2fdce68a30e3cabf6efb3b7ebfd32afdaf09f9ebd062743fe91e181f682252b +``` + +Prune link nodes that do not reach to any values. + +``` blocks +triesh > p +pruned +``` + +Show pruned Trie and caculate the Merkle Hash again. + +``` blocks +triesh > s -a +[b ] +[bo ] +[bob ] cat +[t ] +[te ] +[ted ] dog +[tedd ] +[teddy ] bear +[tes ] +[tess ] dolphin + +triesh > merkle +c2fdce68a30e3cabf6efb3b7ebfd32afdaf09f9ebd062743fe91e181f682252b +``` + +## Running from Source + +``` bash +cd $(go env GOPATH)/src/github.com/lbryio/trie +go run cmd/triesh/*.go sh +``` + +## Contributing + +coming soon + +## License + +This project is MIT licensed. + +## Security + +We take security seriously. Please contact security@lbry.io regarding any security issues. +Our PGP key is [here](https://keybase.io/lbry/key.asc) if you need it. + +## Contact + +The primary contact for this project is [@lyoshenka](https://github.com/lyoshenka) (grin@lbry.io) \ No newline at end of file diff --git a/trie/cmd/triesh/main.go b/trie/cmd/triesh/main.go new file mode 100644 index 0000000..b71503b --- /dev/null +++ b/trie/cmd/triesh/main.go @@ -0,0 +1,242 @@ +package main + +import ( + "bufio" + "fmt" + "os" + "os/signal" + "strings" + "syscall" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lbryio/claimtrie/trie" + "github.com/urfave/cli" +) + +var ( + flgKey = cli.StringFlag{Name: "key, k", Usage: "Key"} + flgValue = cli.StringFlag{Name: "value, v", Usage: "Value"} + flgAll = cli.BoolFlag{Name: "all, a", Usage: "Apply to non-value nodes"} + flgMessage = cli.StringFlag{Name: "message, m", Usage: "Commit Message"} + flgID = cli.StringFlag{Name: "id", Usage: "Commit ID"} +) + +var ( + // ErrNotImplemented is returned when a function is not implemented yet. + ErrNotImplemented = fmt.Errorf("not implemented") +) + +func main() { + app := cli.NewApp() + + app.Name = "triesh" + app.Usage = "A CLI tool for Merkle MerkleTrie" + app.Version = "0.0.1" + app.Action = cli.ShowAppHelp + + app.Commands = []cli.Command{ + { + Name: "update", + Aliases: []string{"u"}, + Usage: "Update Value for Key", + Action: cmdUpdate, + Flags: []cli.Flag{flgKey, flgValue}, + }, + { + Name: "get", + Aliases: []string{"g"}, + Usage: "Get Value for specified Key", + Action: cmdGet, + Flags: []cli.Flag{flgKey, flgID}, + }, + { + Name: "show", + Aliases: []string{"s"}, + Usage: "Show Key-Value pairs of specified commit", + Action: cmdShow, + Flags: []cli.Flag{flgAll, flgID}, + }, + { + Name: "merkle", + Aliases: []string{"m"}, + Usage: "Show Merkle Hash of stage", + Action: cmdMerkle, + }, + { + Name: "prune", + Aliases: []string{"p"}, + Usage: "Prune link nodes that doesn't reach to any value", + Action: cmdPrune, + }, + { + Name: "commit", + Aliases: []string{"c"}, + Usage: "Commit current stage to database", + Action: cmdCommit, + Flags: []cli.Flag{flgMessage}, + }, + { + Name: "reset", + Aliases: []string{"r"}, + Usage: "Reset HEAD & Stage to specified commit", + Action: cmdReset, + Flags: []cli.Flag{flgAll, flgID}, + }, + { + Name: "log", + Aliases: []string{"l"}, + Usage: "Show commit logs", + Action: cmdLog, + }, + { + Name: "shell", + Aliases: []string{"sh"}, + Usage: "Enter interactive mode", + Action: func(c *cli.Context) { cmdShell(app) }, + }, + } + + if err := app.Run(os.Args); err != nil { + fmt.Printf("error: %s\n", err) + } +} + +type strValue string + +func (s strValue) Hash() chainhash.Hash { + return chainhash.DoubleHashH([]byte(s)) +} + +var ( + mt = trie.New() + head = trie.NewCommit(nil, "initial", mt) + stg = trie.NewStage(mt) +) + +func commitVisit(c *trie.Commit) { + fmt.Printf("commit %s\n\n", c.MerkleTrie.MerkleHash()) + fmt.Printf("\t%s\n\n", c.Meta.(string)) +} + +func cmdUpdate(c *cli.Context) error { + key, value := c.String("key"), c.String("value") + fmt.Printf("%s=%s\n", key, value) + if len(value) == 0 { + return stg.Update(trie.Key(key), nil) + } + return stg.Update(trie.Key(key), strValue(value)) +} + +func cmdGet(c *cli.Context) error { + key := c.String("key") + value, err := stg.Get(trie.Key(key)) + if err != nil { + return err + } + if str, ok := value.(strValue); ok { + fmt.Printf("[%s]\n", str) + } + return nil +} + +func cmdShow(c *cli.Context) error { + dump := func(prefix trie.Key, val trie.Value) error { + if val == nil { + fmt.Printf("[%-8s]\n", prefix) + return nil + } + fmt.Printf("[%-8s] %v\n", prefix, val) + return nil + } + id := c.String("id") + if len(id) == 0 { + return stg.Traverse(dump, false, !c.Bool("all")) + } + for commit := head; commit != nil; commit = commit.Prev { + if commit.MerkleTrie.MerkleHash().String() == id { + return commit.MerkleTrie.Traverse(dump, false, true) + } + + } + return fmt.Errorf("commit noot found") +} + +func cmdMerkle(c *cli.Context) error { + fmt.Printf("%s\n", stg.MerkleHash()) + return nil +} + +func cmdPrune(c *cli.Context) error { + stg.Prune() + fmt.Printf("pruned\n") + return nil +} + +func cmdCommit(c *cli.Context) error { + msg := c.String("message") + if len(msg) == 0 { + return fmt.Errorf("no message specified") + } + h, err := stg.Commit(head, msg) + if err != nil { + return err + } + head = h + return nil +} + +func cmdReset(c *cli.Context) error { + id := c.String("id") + for commit := head; commit != nil; commit = commit.Prev { + if commit.MerkleTrie.MerkleHash().String() != id { + continue + } + head = commit + stg = trie.NewStage(head.MerkleTrie) + return nil + } + return fmt.Errorf("commit noot found") +} + +func cmdLog(c *cli.Context) error { + commitVisit := func(c *trie.Commit) { + fmt.Printf("commit %s\n\n", c.MerkleTrie.MerkleHash()) + fmt.Printf("\t%s\n\n", c.Meta.(string)) + } + + trie.Log(head, commitVisit) + return nil +} + +func cmdShell(app *cli.App) { + cli.OsExiter = func(c int) {} + reader := bufio.NewReader(os.Stdin) + sigs := make(chan os.Signal, 1) + go func() { + for range sigs { + fmt.Printf("\n(type quit or q to exit)\n\n") + fmt.Printf("%s > ", app.Name) + } + }() + defer close(sigs) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + for { + fmt.Printf("%s > ", app.Name) + text, err := reader.ReadString('\n') + if err != nil { + fmt.Printf("error: %s\n", err) + } + text = strings.TrimSpace(text) + if text == "" { + continue + } + if text == "quit" || text == "q" { + break + } + if err := app.Run(append(os.Args[1:], strings.Split(text, " ")...)); err != nil { + fmt.Printf("errot: %s\n", err) + } + + } + signal.Stop(sigs) +} diff --git a/trie/commit.go b/trie/commit.go new file mode 100644 index 0000000..64bfa2d --- /dev/null +++ b/trie/commit.go @@ -0,0 +1,21 @@ +package trie + +// CommitMeta ... +type CommitMeta interface{} + +// NewCommit ... +func NewCommit(head *Commit, meta CommitMeta, mt *MerkleTrie) *Commit { + commit := &Commit{ + Prev: head, + MerkleTrie: mt, + Meta: meta, + } + return commit +} + +// Commit ... +type Commit struct { + Prev *Commit + MerkleTrie *MerkleTrie + Meta CommitMeta +} diff --git a/trie/errors.go b/trie/errors.go new file mode 100644 index 0000000..d25d49a --- /dev/null +++ b/trie/errors.go @@ -0,0 +1,8 @@ +package trie + +import "errors" + +var ( + // ErrKeyNotFound is returned when the key doesn't exist in the MerkleTrie. + ErrKeyNotFound = errors.New("key not found") +) diff --git a/trie/node.go b/trie/node.go new file mode 100644 index 0000000..ee489d4 --- /dev/null +++ b/trie/node.go @@ -0,0 +1,96 @@ +package trie + +import ( + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +type node struct { + hash *chainhash.Hash + links [256]*node + value Value +} + +func newNode(val Value) *node { + return &node{links: [256]*node{}, value: val} +} + +// We clear the Merkle Hash for every node along the path, including the root. +// Calculation of the hash happens much less frequently then updating to the MerkleTrie. +func update(n *node, key Key, val Value) { + // Follow the path to reach the node. + for _, k := range key { + if n.links[k] == nil { + // The path didn't exist yet. Build it. + n.links[k] = newNode(nil) + } + n.hash = nil + n = n.links[k] + } + + n.value = val + n.hash = nil +} + +func prune(n *node) *node { + if n == nil { + return nil + } + var ret *node + for i, v := range n.links { + if n.links[i] = prune(v); n.links[i] != nil { + ret = n + } + } + if n.value != nil { + ret = n + } + return ret +} + +func traverse(n *node, prefix Key, visit Visit) error { + if n == nil { + return nil + } + for i, v := range n.links { + if v == nil { + continue + } + p := append(prefix, byte(i)) + if err := visit(p, v.value); err != nil { + return err + } + if err := traverse(v, p, visit); err != nil { + return err + } + } + return nil +} + +// merkle recursively caculates the Merkle Hash of a given node +// It works with both pruned or unpruned nodes. +func merkle(n *node) *chainhash.Hash { + if n.hash != nil { + return n.hash + } + buf := Key{} + for i, v := range n.links { + if v == nil { + continue + } + if h := merkle(v); h != nil { + buf = append(buf, byte(i)) + buf = append(buf, h[:]...) + } + } + if n.value != nil { + h := n.value.Hash() + buf = append(buf, h[:]...) + } + + if len(buf) != 0 { + // At least one of the sub nodes has contributed a value hash. + h := chainhash.DoubleHashH(buf) + n.hash = &h + } + return n.hash +} diff --git a/trie/node_test.go b/trie/node_test.go new file mode 100644 index 0000000..abc58d0 --- /dev/null +++ b/trie/node_test.go @@ -0,0 +1,139 @@ +package trie + +import ( + "fmt" + "reflect" + "testing" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +func Test_update(t *testing.T) { + res1 := buildNode(newNode(nil), pairs1()) + tests := []struct { + name string + res *node + exp *node + }{ + {"test1", res1, unprunedNode()}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !reflect.DeepEqual(tt.res, tt.exp) { + traverse(tt.res, Key{}, dump) + fmt.Println("") + traverse(tt.exp, Key{}, dump) + t.Errorf("update() = %v, want %v", tt.res, tt.exp) + } + }) + } +} + +func Test_nullify(t *testing.T) { + tests := []struct { + name string + res *node + exp *node + }{ + {"test1", prune(unprunedNode()), prunedNode()}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !reflect.DeepEqual(tt.res, tt.exp) { + t.Errorf("traverse() = %v, want %v", tt.res, tt.exp) + } + }) + } +} + +func Test_traverse(t *testing.T) { + res1 := []pair{} + fn := func(prefix Key, value Value) error { + res1 = append(res1, pair{string(prefix), value}) + return nil + } + traverse(unprunedNode(), Key{}, fn) + exp1 := []pair{ + {"a", nil}, + {"al", nil}, + {"ale", nil}, + {"alex", nil}, + {"b", nil}, + {"bo", nil}, + {"bob", strValue("cat")}, + {"t", nil}, + {"te", nil}, + {"ted", strValue("dog")}, + {"tedd", nil}, + {"teddy", strValue("bear")}, + {"tes", nil}, + {"tess", strValue("dolphin")}, + } + + res2 := []pair{} + fn2 := func(prefix Key, value Value) error { + res2 = append(res2, pair{string(prefix), value}) + return nil + } + traverse(prunedNode(), Key{}, fn2) + exp2 := []pair{ + {"b", nil}, + {"bo", nil}, + {"bob", strValue("cat")}, + {"t", nil}, + {"te", nil}, + {"ted", strValue("dog")}, + {"tedd", nil}, + {"teddy", strValue("bear")}, + {"tes", nil}, + {"tess", strValue("dolphin")}, + } + + tests := []struct { + name string + res []pair + exp []pair + }{ + {"test1", res1, exp1}, + {"test2", res2, exp2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !reflect.DeepEqual(tt.res, tt.exp) { + t.Errorf("traverse() = %v, want %v", tt.res, tt.exp) + } + }) + } +} + +func Test_merkle(t *testing.T) { + n1 := buildNode(newNode(nil), pairs1()) + // n2 := func() *node { + // p1 := wire.OutPoint{Hash: *newHashFromStr("627ecfee2110b28fbc4b012944cadf66a72f394ad9fa9bb18fec30789e26c9ac"), Index: 0} + // p2 := wire.OutPoint{Hash: *newHashFromStr("c31bd469112abf04930879c6b6007d2b23224e042785d404bbeff1932dd94880"), Index: 0} + + // n1 := claim.NewNode(&claim.Claim{OutPoint: p1, ClaimID: nil, Amount: 50, Height: 100, ValidAtHeight: 200}) + // n2 := claim.NewNode(&claim.Claim{OutPoint: p2, ClaimID: nil, Amount: 50, Height: 100, ValidAtHeight: 200}) + + // pairs := []pair{ + // {"test", n1}, + // {"test2", n2}, + // } + // return buildNode(newNode(nil), pairs) + // }() + tests := []struct { + name string + n *node + want *chainhash.Hash + }{ + {"test1", n1, newHashFromStr("c2fdce68a30e3cabf6efb3b7ebfd32afdaf09f9ebd062743fe91e181f682252b")}, + // {"test2", n2, newHashFromStr("71c7b8d35b9a3d7ad9a1272b68972979bbd18589f1efe6f27b0bf260a6ba78fa")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := merkle(tt.n); !reflect.DeepEqual(got, tt.want) { + t.Errorf("merkle() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/trie/stage.go b/trie/stage.go new file mode 100644 index 0000000..665777f --- /dev/null +++ b/trie/stage.go @@ -0,0 +1,61 @@ +package trie + +// Stage implements Copy-on-Write staging area on top of a MerkleTrie. +type Stage struct { + *MerkleTrie +} + +// NewStage returns a Stage initialized with a specified MerkleTrie. +func NewStage(t *MerkleTrie) *Stage { + s := &Stage{ + MerkleTrie: New(), + } + s.mu = t.mu + s.root = newNode(nil) + *s.root = *t.root + return s +} + +// Update updates the internal MerkleTrie in a Copy-on-Write manner. +func (s *Stage) Update(key Key, val Value) error { + s.mu.Lock() + defer s.mu.Unlock() + + n := s.root + for _, k := range key { + org := n.links[k] + n.links[k] = newNode(nil) + if org != nil { + *n.links[k] = *org + } + n.hash = nil + n = n.links[k] + } + if n.value != val { + n.value = val + n.hash = nil + } + return nil +} + +// Commit ... +func (s *Stage) Commit(head *Commit, meta CommitMeta) (*Commit, error) { + c := NewCommit(head, meta, s.MerkleTrie) + + s.MerkleTrie = New() + s.mu = c.MerkleTrie.mu + s.root = newNode(nil) + *s.root = *c.MerkleTrie.root + return c, nil +} + +// CommitVisit ... +type CommitVisit func(c *Commit) + +// Log ... +func Log(commit *Commit, visit CommitVisit) { + for commit != nil { + visit(commit) + commit = commit.Prev + } +} diff --git a/trie/stage_test.go b/trie/stage_test.go new file mode 100644 index 0000000..2f768d0 --- /dev/null +++ b/trie/stage_test.go @@ -0,0 +1,34 @@ +package trie + +import ( + "fmt" + "reflect" + "testing" +) + +func TestStage_Update(t *testing.T) { + tr1 := buildTrie(New(), pairs1()) + + s1 := NewStage(tr1) + s1.Update(Key("cook"), strValue("hello")) + s1.Update(Key("ted"), nil) + + tr1Exp := buildTrie(New(), pairs1()) + + s1Exp := buildTrie(New(), pairs1()) + s1Exp.Update(Key("cook"), strValue("hello")) + s1Exp.Update(Key("ted"), nil) + + if !reflect.DeepEqual(tr1, tr1Exp) { + t.Errorf("Stage.Update() tr1 != tr1Exp") + traverse(tr1.root, Key{}, dump) + fmt.Println("") + traverse(tr1Exp.root, Key{}, dump) + } + if !reflect.DeepEqual(s1.MerkleTrie, s1Exp) { + t.Errorf("Stage.Update() s1 != s1Exp") + traverse(s1.root, Key{}, dump) + fmt.Println("") + traverse(s1Exp.root, Key{}, dump) + } +} diff --git a/trie/test.go b/trie/test.go new file mode 100644 index 0000000..60a3ae0 --- /dev/null +++ b/trie/test.go @@ -0,0 +1,106 @@ +package trie + +import ( + "fmt" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +// Internal utility functions to facilitate the tests. + +type strValue string + +func (s strValue) Hash() chainhash.Hash { + return chainhash.DoubleHashH([]byte(s)) +} + +func dump(prefix Key, value Value) error { + if value == nil { + fmt.Printf("[%-8s]\n", prefix) + return nil + } + fmt.Printf("[%-8s] %v\n", prefix, value) + return nil +} + +func buildNode(n *node, pairs []pair) *node { + for _, val := range pairs { + update(n, Key(val.k), val.v) + } + return n +} + +func buildTrie(mt *MerkleTrie, pairs []pair) *MerkleTrie { + for _, val := range pairs { + mt.Update(Key(val.k), val.v) + } + return mt +} + +func buildMap(m map[string]Value, pairs []pair) map[string]Value { + for _, p := range pairs { + if p.v == nil { + delete(m, p.k) + } else { + m[p.k] = p.v + } + } + return m +} + +func newMap() map[string]Value { + return map[string]Value{} +} + +type pair struct { + k string + v Value +} + +func pairs1() []pair { + return []pair{ + {"alex", strValue("lion")}, + {"al", strValue("tiger")}, + {"tess", strValue("dolphin")}, + {"bob", strValue("pig")}, + {"ted", strValue("dog")}, + {"teddy", strValue("bear")}, + {"al", nil}, + {"alex", nil}, + {"bob", strValue("cat")}, + } +} + +func prunedNode() *node { + n := newNode(nil) + n.links['b'] = newNode(nil) + n.links['b'].links['o'] = newNode(nil) + n.links['b'].links['o'].links['b'] = newNode(strValue("cat")) + n.links['t'] = newNode(nil) + n.links['t'].links['e'] = newNode(nil) + n.links['t'].links['e'].links['d'] = newNode(strValue("dog")) + n.links['t'].links['e'].links['d'].links['d'] = newNode(nil) + n.links['t'].links['e'].links['d'].links['d'].links['y'] = newNode(strValue("bear")) + n.links['t'].links['e'].links['s'] = newNode(nil) + n.links['t'].links['e'].links['s'].links['s'] = newNode(strValue("dolphin")) + return n +} + +func unprunedNode() *node { + n := newNode(nil) + n.links['a'] = newNode(nil) + n.links['a'].links['l'] = newNode(nil) + n.links['a'].links['l'].links['e'] = newNode(nil) + n.links['a'].links['l'].links['e'].links['x'] = newNode(nil) + n.links['b'] = newNode(nil) + n.links['b'].links['o'] = newNode(nil) + n.links['b'].links['o'].links['b'] = newNode(strValue("cat")) + n.links['t'] = newNode(nil) + n.links['t'].links['e'] = newNode(nil) + n.links['t'].links['e'].links['d'] = newNode(strValue("dog")) + n.links['t'].links['e'].links['d'].links['d'] = newNode(nil) + n.links['t'].links['e'].links['d'].links['d'].links['y'] = newNode(strValue("bear")) + n.links['t'].links['e'].links['s'] = newNode(nil) + n.links['t'].links['e'].links['s'].links['s'] = newNode(strValue("dolphin")) + return n +} diff --git a/trie/trie.go b/trie/trie.go new file mode 100644 index 0000000..355325f --- /dev/null +++ b/trie/trie.go @@ -0,0 +1,128 @@ +package trie + +import ( + "sync" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +var ( + // EmptyTrieHash represent the Merkle Hash of an empty MerkleTrie. + EmptyTrieHash = *newHashFromStr("0000000000000000000000000000000000000000000000000000000000000001") +) + +// Key defines the key type of the MerkleTrie. +type Key []byte + +// Value implements value for the MerkleTrie. +type Value interface { + Hash() chainhash.Hash +} + +// MerkleTrie implements a 256-way prefix tree, which takes Key as key and any value that implements the Value interface. +type MerkleTrie struct { + mu *sync.RWMutex + root *node +} + +// New returns a MerkleTrie. +func New() *MerkleTrie { + return &MerkleTrie{ + mu: &sync.RWMutex{}, + root: newNode(nil), + } +} + +// Get returns the Value associated with the key, or nil with error. +// Most common error is ErrMissing, which indicates no Value is associated with the key. +// However, there could be other errors propagated from I/O layer (TBD). +func (t *MerkleTrie) Get(key Key) (Value, error) { + t.mu.RLock() + defer t.mu.RUnlock() + + n := t.root + for _, k := range key { + if n.links[k] == nil { + // Path does not exist. + return nil, ErrKeyNotFound + } + n = n.links[k] + } + if n.value == nil { + // Path exists, but no Value is associated. + // This happens when the key had been deleted, but the MerkleTrie has not nullified yet. + return nil, ErrKeyNotFound + } + return n.value, nil +} + +// Update updates the MerkleTrie with specified key-value pair. +// Setting Value to nil deletes the Value, if exists, associated to the key. +func (t *MerkleTrie) Update(key Key, val Value) error { + t.mu.Lock() + defer t.mu.Unlock() + + update(t.root, key, val) + return nil +} + +// Prune removes nodes that do not reach to any value node. +func (t *MerkleTrie) Prune() { + t.mu.Lock() + defer t.mu.Unlock() + + prune(t.root) +} + +// Size returns the number of values. +func (t *MerkleTrie) Size() int { + t.mu.RLock() + defer t.mu.RUnlock() + + size := 0 // captured in the closure. + fn := func(prefix Key, v Value) error { + if v != nil { + size++ + } + return nil + } + traverse(t.root, Key{}, fn) + return size +} + +// Visit implements callback function invoked when the Value is visited. +// During the traversal, if a non-nil error is returned, the traversal ends early. +type Visit func(prefix Key, val Value) error + +// Traverse visits every Value in the MerkleTrie and returns error defined by specified Visit function. +// update indicates if the visit function modify the state of MerkleTrie. +func (t *MerkleTrie) Traverse(visit Visit, update, valueOnly bool) error { + if update { + t.mu.Lock() + defer t.mu.Unlock() + } else { + t.mu.RLock() + defer t.mu.RUnlock() + } + fn := func(prefix Key, value Value) error { + if !valueOnly || value != nil { + return visit(prefix, value) + } + return nil + } + return traverse(t.root, Key{}, fn) +} + +// MerkleHash calculates the Merkle Hash of the MerkleTrie. +// If the MerkleTrie is empty, EmptyTrieHash is returned. +func (t *MerkleTrie) MerkleHash() chainhash.Hash { + if merkle(t.root) == nil { + return EmptyTrieHash + } + return *t.root.hash +} + +func newHashFromStr(s string) *chainhash.Hash { + h, _ := chainhash.NewHashFromStr(s) + return h +} diff --git a/trie/trie_test.go b/trie/trie_test.go new file mode 100644 index 0000000..0caf285 --- /dev/null +++ b/trie/trie_test.go @@ -0,0 +1,75 @@ +package trie + +import ( + "reflect" + "testing" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +func TestTrie_Update(t *testing.T) { + mt := buildTrie(New(), pairs1()) + m := buildMap(newMap(), pairs1()) + + for k := range m { + v, _ := mt.Get(Key(k)) + if m[k] != v { + t.Errorf("exp %s got %s", m[k], v) + } + } +} + +func TestTrie_Hash(t *testing.T) { + tr1 := buildTrie(New(), pairs1()) + // tr2 := func() *MerkleTrie { + // p1 := wire.OutPoint{Hash: *newHashFromStr("627ecfee2110b28fbc4b012944cadf66a72f394ad9fa9bb18fec30789e26c9ac"), Index: 0} + // p2 := wire.OutPoint{Hash: *newHashFromStr("c31bd469112abf04930879c6b6007d2b23224e042785d404bbeff1932dd94880"), Index: 0} + + // n1 := claim.NewNode(&claim.Claim{OutPoint: p1, ClaimID: nil, Amount: 50, Height: 100, ValidAtHeight: 200}) + // n2 := claim.NewNode(&claim.Claim{OutPoint: p2, ClaimID: nil, Amount: 50, Height: 100, ValidAtHeight: 200}) + + // pairs := []pair{ + // {"test", n1}, + // {"test2", n2}, + // } + // return buildTrie(New(), pairs) + // }() + tests := []struct { + name string + mt *MerkleTrie + want chainhash.Hash + }{ + {"empty", New(), *newHashFromStr("0000000000000000000000000000000000000000000000000000000000000001")}, + {"test1", tr1, *newHashFromStr("c2fdce68a30e3cabf6efb3b7ebfd32afdaf09f9ebd062743fe91e181f682252b")}, + // {"test2", tr2, *newHashFromStr("71c7b8d35b9a3d7ad9a1272b68972979bbd18589f1efe6f27b0bf260a6ba78fa")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mt := tt.mt + if got := mt.MerkleHash(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("trie.MerkleHash() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTrie_Size(t *testing.T) { + mt1 := buildTrie(New(), pairs1()) + map1 := buildMap(newMap(), pairs1()) + + tests := []struct { + name string + mt *MerkleTrie + want int + }{ + {"test1", mt1, len(map1)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mt := tt.mt + if got := mt.Size(); got != tt.want { + t.Errorf("trie.Size() = %v, want %v", got, tt.want) + } + }) + } +}