diff --git a/claimtrie/claimtrie.go b/claimtrie/claimtrie.go index abce243a..1e36a184 100644 --- a/claimtrie/claimtrie.go +++ b/claimtrie/claimtrie.go @@ -46,7 +46,7 @@ type ClaimTrie struct { nodeManager node.Manager // Prefix tree (trie) that manages merkle hash of each node. - merkleTrie *merkletrie.MerkleTrie + merkleTrie merkletrie.MerkleTrie // Current block height, which is increased by one when AppendBlock() is called. height int32 @@ -99,6 +99,11 @@ func New(cfg config.Config) (*ClaimTrie, error) { trie := merkletrie.New(nodeManager, trieRepo) cleanups = append(cleanups, trie.Close) + persistentTrie := merkletrie.NewPersistentTrie(nodeManager, trieRepo) + cleanups = append(cleanups, persistentTrie.Close) + trie = persistentTrie + } + // Restore the last height. previousHeight, err := blockRepo.Load() if err != nil { @@ -110,12 +115,11 @@ func New(cfg config.Config) (*ClaimTrie, error) { if err != nil { return nil, fmt.Errorf("get hash: %w", err) } - trie.SetRoot(hash) - _, err = nodeManager.IncrementHeightTo(previousHeight) if err != nil { return nil, fmt.Errorf("node manager init: %w", err) } + trie.SetRoot(hash, nil) // keep this after IncrementHeightTo } ct := &ClaimTrie{ @@ -275,7 +279,7 @@ func (ct *ClaimTrie) AppendBlock() error { ct.blockRepo.Set(ct.height, h) if hitFork { - ct.merkleTrie.SetRoot(h) // for clearing the memory entirely + ct.merkleTrie.SetRoot(h, names) // for clearing the memory entirely runtime.GC() } @@ -337,12 +341,21 @@ func (ct *ClaimTrie) ResetHeight(height int32) error { return err } + passedHashFork := ct.height >= param.AllClaimsInMerkleForkHeight && height < param.AllClaimsInMerkleForkHeight ct.height = height hash, err := ct.blockRepo.Get(height) if err != nil { return err } - ct.merkleTrie.SetRoot(hash) + + if passedHashFork { + names = nil // force them to reconsider all names + } + ct.merkleTrie.SetRoot(hash, names) + + if !ct.MerkleHash().IsEqual(hash) { + return fmt.Errorf("unable to restore the hash at height %d", height) + } return nil } diff --git a/claimtrie/claimtrie_test.go b/claimtrie/claimtrie_test.go index 5945a181..1a2f869f 100644 --- a/claimtrie/claimtrie_test.go +++ b/claimtrie/claimtrie_test.go @@ -231,3 +231,39 @@ func verifyBestIndex(t *testing.T, ct *ClaimTrie, name string, idx uint32, claim r.Equal(idx, n.BestClaim.OutPoint.Index) } } + +func TestRebuild(t *testing.T) { + r := require.New(t) + setup(t) + ct, err := New(true) + r.NoError(err) + r.NotNil(ct) + defer func() { + err := ct.Close() + r.NoError(err) + }() + + hash := chainhash.HashH([]byte{1, 2, 3}) + + o1 := wire.OutPoint{Hash: hash, Index: 1} + err = ct.AddClaim([]byte("test1"), o1, node.NewClaimID(o1), 1, nil) + r.NoError(err) + + o2 := wire.OutPoint{Hash: hash, Index: 2} + err = ct.AddClaim([]byte("test2"), o2, node.NewClaimID(o2), 2, nil) + r.NoError(err) + + err = ct.AppendBlock() + r.NoError(err) + + m := ct.MerkleHash() + r.NotNil(m) + r.NotEqual(*merkletrie.EmptyTrieHash, *m) + + ct.merkleTrie = merkletrie.NewRamTrie(ct.nodeManager) + ct.merkleTrie.SetRoot(m, nil) + + m2 := ct.MerkleHash() + r.NotNil(m2) + r.Equal(*m, *m2) +} diff --git a/claimtrie/cmd/cmd/block.go b/claimtrie/cmd/cmd/block.go index 4f4a8825..ced2f68a 100644 --- a/claimtrie/cmd/cmd/block.go +++ b/claimtrie/cmd/cmd/block.go @@ -135,9 +135,9 @@ var blockNameCmd = &cobra.Command{ return fmt.Errorf("can't open merkle trie repo: %w", err) } - trie := merkletrie.New(nil, trieRepo) + trie := merkletrie.NewPersistentTrie(nil, trieRepo) defer trie.Close() - trie.SetRoot(hash) + trie.SetRoot(hash, nil) if len(args) > 1 { trie.Dump(args[1], param.AllClaimsInMerkleForkHeight >= int32(height)) } else { diff --git a/claimtrie/merkletrie/prefix_trie.go b/claimtrie/merkletrie/collapsedtrie.go similarity index 55% rename from claimtrie/merkletrie/prefix_trie.go rename to claimtrie/merkletrie/collapsedtrie.go index ceabb180..7544aa25 100644 --- a/claimtrie/merkletrie/prefix_trie.go +++ b/claimtrie/merkletrie/collapsedtrie.go @@ -1,21 +1,21 @@ package merkletrie import ( - "github.com/lbryio/chain/chaincfg/chainhash" + "github.com/btcsuite/btcd/chaincfg/chainhash" ) type KeyType []byte -type PrefixTrieNode struct { // implements sort.Interface - children []*PrefixTrieNode - key KeyType - hash *chainhash.Hash - hasClaims bool +type collapsedVertex struct { // implements sort.Interface + children []*collapsedVertex + key KeyType + merkleHash *chainhash.Hash + claimHash *chainhash.Hash } // insertAt inserts v into s at index i and returns the new slice. // https://stackoverflow.com/questions/42746972/golang-insert-to-a-sorted-slice -func insertAt(data []*PrefixTrieNode, i int, v *PrefixTrieNode) []*PrefixTrieNode { +func insertAt(data []*collapsedVertex, i int, v *collapsedVertex) []*collapsedVertex { if i == len(data) { // Insert at end is the easy case. return append(data, v) @@ -30,7 +30,7 @@ func insertAt(data []*PrefixTrieNode, i int, v *PrefixTrieNode) []*PrefixTrieNod return data } -func (ptn *PrefixTrieNode) Insert(value *PrefixTrieNode) *PrefixTrieNode { +func (ptn *collapsedVertex) Insert(value *collapsedVertex) *collapsedVertex { // keep it sorted (and sort.Sort is too slow) index := sortSearch(ptn.children, value.key[0]) ptn.children = insertAt(ptn.children, index, value) @@ -40,7 +40,7 @@ func (ptn *PrefixTrieNode) Insert(value *PrefixTrieNode) *PrefixTrieNode { // this sort.Search is stolen shamelessly from search.go, // and modified for performance to not need a closure -func sortSearch(nodes []*PrefixTrieNode, b byte) int { +func sortSearch(nodes []*collapsedVertex, b byte) int { i, j := 0, len(nodes) for i < j { h := int(uint(i+j) >> 1) // avoid overflow when computing h @@ -55,9 +55,9 @@ func sortSearch(nodes []*PrefixTrieNode, b byte) int { return i } -func (ptn *PrefixTrieNode) FindNearest(start KeyType) (int, *PrefixTrieNode) { +func (ptn *collapsedVertex) findNearest(key KeyType) (int, *collapsedVertex) { // none of the children overlap on the first char or we would have a parent node with that char - index := sortSearch(ptn.children, start[0]) + index := sortSearch(ptn.children, key[0]) hits := ptn.children[index:] if len(hits) > 0 { return index, hits[0] @@ -65,26 +65,17 @@ func (ptn *PrefixTrieNode) FindNearest(start KeyType) (int, *PrefixTrieNode) { return -1, nil } -type PrefixTrie interface { - InsertOrFind(value KeyType) (bool, *PrefixTrieNode) - Find(value KeyType) *PrefixTrieNode - FindPath(value KeyType) ([]int, []*PrefixTrieNode) - IterateFrom(start KeyType, handler func(value *PrefixTrieNode) bool) - Erase(value KeyType) bool - NodeCount() int -} - -type prefixTrie struct { - root *PrefixTrieNode +type collapsedTrie struct { + Root *collapsedVertex Nodes int } -func NewPrefixTrie() PrefixTrie { - // we never delete the root node - return &prefixTrie{root: &PrefixTrieNode{key: make(KeyType, 0)}, Nodes: 1} +func NewCollapsedTrie() *collapsedTrie { + // we never delete the Root node + return &collapsedTrie{Root: &collapsedVertex{key: make(KeyType, 0)}, Nodes: 1} } -func (pt *prefixTrie) NodeCount() int { +func (pt *collapsedTrie) NodeCount() int { return pt.Nodes } @@ -101,10 +92,11 @@ func matchLength(a, b KeyType) int { return minLen } -func (pt *prefixTrie) insert(value KeyType, node *PrefixTrieNode) (bool, *PrefixTrieNode) { - index, child := node.FindNearest(value) +func (pt *collapsedTrie) insert(value KeyType, node *collapsedVertex) (bool, *collapsedVertex) { + index, child := node.findNearest(value) match := 0 if index >= 0 { // if we found a child + child.merkleHash = nil match = matchLength(value, child.key) if len(value) == match && len(child.key) == match { return false, child @@ -112,12 +104,12 @@ func (pt *prefixTrie) insert(value KeyType, node *PrefixTrieNode) (bool, *Prefix } if match <= 0 { pt.Nodes++ - return true, node.Insert(&PrefixTrieNode{key: value}) + return true, node.Insert(&collapsedVertex{key: value}) } if match < len(child.key) { - grandChild := PrefixTrieNode{key: child.key[match:], children: child.children, - hasClaims: child.hasClaims, hash: child.hash} - newChild := PrefixTrieNode{key: child.key[0:match], children: []*PrefixTrieNode{&grandChild}} + grandChild := collapsedVertex{key: child.key[match:], children: child.children, + claimHash: child.claimHash, merkleHash: child.merkleHash} + newChild := collapsedVertex{key: child.key[0:match], children: []*collapsedVertex{&grandChild}} child = &newChild node.children[index] = child pt.Nodes++ @@ -128,15 +120,21 @@ func (pt *prefixTrie) insert(value KeyType, node *PrefixTrieNode) (bool, *Prefix return pt.insert(value[match:], child) } -func (pt *prefixTrie) InsertOrFind(value KeyType) (bool, *PrefixTrieNode) { +func (pt *collapsedTrie) InsertOrFind(value KeyType) (bool, *collapsedVertex) { + pt.Root.merkleHash = nil if len(value) <= 0 { - return false, pt.root + return false, pt.Root } - return pt.insert(value, pt.root) + + // we store the name so we need to make our own copy of it + // this avoids errors where this function is called via the DB iterator + v2 := make([]byte, len(value)) + copy(v2, value) + return pt.insert(v2, pt.Root) } -func find(value KeyType, node *PrefixTrieNode, pathIndexes *[]int, path *[]*PrefixTrieNode) *PrefixTrieNode { - index, child := node.FindNearest(value) +func find(value KeyType, node *collapsedVertex, pathIndexes *[]int, path *[]*collapsedVertex) *collapsedVertex { + index, child := node.findNearest(value) if index < 0 { return nil } @@ -162,34 +160,36 @@ func find(value KeyType, node *PrefixTrieNode, pathIndexes *[]int, path *[]*Pref return find(value[match:], child, pathIndexes, path) } -func (pt *prefixTrie) Find(value KeyType) *PrefixTrieNode { +func (pt *collapsedTrie) Find(value KeyType) *collapsedVertex { if len(value) <= 0 { - return pt.root + return pt.Root } - return find(value, pt.root, nil, nil) + return find(value, pt.Root, nil, nil) } -func (pt *prefixTrie) FindPath(value KeyType) ([]int, []*PrefixTrieNode) { +func (pt *collapsedTrie) FindPath(value KeyType) ([]int, []*collapsedVertex) { pathIndexes := []int{-1} - path := []*PrefixTrieNode{pt.root} - result := find(value, pt.root, &pathIndexes, &path) - if result == nil { - return nil, nil - } // not sure I want this line + path := []*collapsedVertex{pt.Root} + if len(value) > 0 { + result := find(value, pt.Root, &pathIndexes, &path) + if result == nil { // not sure I want this line + return nil, nil + } + } return pathIndexes, path } // IterateFrom can be used to find a value and run a function on that value. // If the handler returns true it continues to iterate through the children of value. -func (pt *prefixTrie) IterateFrom(start KeyType, handler func(value *PrefixTrieNode) bool) { - node := find(start, pt.root, nil, nil) +func (pt *collapsedTrie) IterateFrom(start KeyType, handler func(value *collapsedVertex) bool) { + node := find(start, pt.Root, nil, nil) if node == nil { return } iterateFrom(node, handler) } -func iterateFrom(node *PrefixTrieNode, handler func(value *PrefixTrieNode) bool) { +func iterateFrom(node *collapsedVertex, handler func(value *collapsedVertex) bool) { for handler(node) { for _, child := range node.children { iterateFrom(child, handler) @@ -197,19 +197,25 @@ func iterateFrom(node *PrefixTrieNode, handler func(value *PrefixTrieNode) bool) } } -func (pt *prefixTrie) Erase(value KeyType) bool { +func (pt *collapsedTrie) Erase(value KeyType) bool { indexes, path := pt.FindPath(value) if path == nil || len(path) <= 1 { + if len(path) == 1 { + path[0].merkleHash = nil + path[0].claimHash = nil + } return false } nodes := pt.Nodes - for i := len(path) - 1; i > 0; i-- { + i := len(path) - 1 + path[i].claimHash = nil // this is the thing we are erasing; the rest is book-keeping + for ; i > 0; i-- { childCount := len(path[i].children) - noClaimData := !path[i].hasClaims + noClaimData := path[i].claimHash == nil + path[i].merkleHash = nil if childCount == 1 && noClaimData { path[i].key = append(path[i].key, path[i].children[0].key...) - path[i].hash = nil - path[i].hasClaims = path[i].children[0].hasClaims + path[i].claimHash = path[i].children[0].claimHash path[i].children = path[i].children[0].children pt.Nodes-- continue @@ -222,5 +228,8 @@ func (pt *prefixTrie) Erase(value KeyType) bool { } break } + for ; i >= 0; i-- { + path[i].merkleHash = nil + } return nodes > pt.Nodes } diff --git a/claimtrie/merkletrie/prefix_trie_test.go b/claimtrie/merkletrie/collapsedtrie_test.go similarity index 66% rename from claimtrie/merkletrie/prefix_trie_test.go rename to claimtrie/merkletrie/collapsedtrie_test.go index ad277c36..8efc3375 100644 --- a/claimtrie/merkletrie/prefix_trie_test.go +++ b/claimtrie/merkletrie/collapsedtrie_test.go @@ -9,10 +9,10 @@ import ( ) func b(value string) []byte { return []byte(value) } -func eq(x []byte, y string) bool { return bytes.Compare(x, b(y)) == 0 } +func eq(x []byte, y string) bool { return bytes.Equal(x, b(y)) } func TestInsertAndErase(t *testing.T) { - trie := NewPrefixTrie() + trie := NewCollapsedTrie() assert.True(t, trie.NodeCount() == 1) inserted, node := trie.InsertOrFind(b("abc")) assert.True(t, inserted) @@ -45,8 +45,28 @@ func TestInsertAndErase(t *testing.T) { assert.Equal(t, 1, trie.NodeCount()) } -func TestPrefixTrie(t *testing.T) { - inserts := 1000000 +func TestNilNameHandling(t *testing.T) { + trie := NewCollapsedTrie() + inserted, n := trie.InsertOrFind([]byte("test")) + assert.True(t, inserted) + n.claimHash = EmptyTrieHash + inserted, n = trie.InsertOrFind(nil) + assert.False(t, inserted) + n.claimHash = EmptyTrieHash + n.merkleHash = EmptyTrieHash + inserted, n = trie.InsertOrFind(nil) + assert.False(t, inserted) + assert.NotNil(t, n.claimHash) + assert.Nil(t, n.merkleHash) + nodeRemoved := trie.Erase(nil) + assert.False(t, nodeRemoved) + inserted, n = trie.InsertOrFind(nil) + assert.False(t, inserted) + assert.Nil(t, n.claimHash) +} + +func TestCollapsedTriePerformance(t *testing.T) { + inserts := 10000 // increase this to 1M for more interesting results data := make([][]byte, inserts) rand.Seed(42) for i := 0; i < inserts; i++ { @@ -58,21 +78,21 @@ func TestPrefixTrie(t *testing.T) { } } - trie := NewPrefixTrie() + trie := NewCollapsedTrie() // doing my own timing because I couldn't get the B.Run method to work: start := time.Now() for i := 0; i < inserts; i++ { _, node := trie.InsertOrFind(data[i]) assert.NotNil(t, node, "Failure at %d of %d", i, inserts) } - t.Logf("Insertion in %f sec.", time.Now().Sub(start).Seconds()) + t.Logf("Insertion in %f sec.", time.Since(start).Seconds()) start = time.Now() for i := 0; i < inserts; i++ { node := trie.Find(data[i]) assert.True(t, bytes.HasSuffix(data[i], node.key), "Failure on %d of %d", i, inserts) } - t.Logf("Lookup in %f sec. on %d nodes.", time.Now().Sub(start).Seconds(), trie.NodeCount()) + t.Logf("Lookup in %f sec. on %d nodes.", time.Since(start).Seconds(), trie.NodeCount()) start = time.Now() for i := 0; i < inserts; i++ { @@ -81,12 +101,12 @@ func TestPrefixTrie(t *testing.T) { assert.True(t, len(path) > 1) assert.True(t, bytes.HasSuffix(data[i], path[len(path)-1].key)) } - t.Logf("Parents in %f sec.", time.Now().Sub(start).Seconds()) + t.Logf("Parents in %f sec.", time.Since(start).Seconds()) start = time.Now() for i := 0; i < inserts; i++ { trie.Erase(data[i]) } - t.Logf("Deletion in %f sec.", time.Now().Sub(start).Seconds()) + t.Logf("Deletion in %f sec.", time.Since(start).Seconds()) assert.Equal(t, 1, trie.NodeCount()) } diff --git a/claimtrie/merkletrie/merkletrie.go b/claimtrie/merkletrie/merkletrie.go index 9d244757..3c279652 100644 --- a/claimtrie/merkletrie/merkletrie.go +++ b/claimtrie/merkletrie/merkletrie.go @@ -7,25 +7,27 @@ import ( "sync" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/claimtrie/node" + "github.com/cockroachdb/pebble" ) var ( - // EmptyTrieHash represents the Merkle Hash of an empty MerkleTrie. + // EmptyTrieHash represents the Merkle Hash of an empty PersistentTrie. // "0000000000000000000000000000000000000000000000000000000000000001" EmptyTrieHash = &chainhash.Hash{1} NoChildrenHash = &chainhash.Hash{2} NoClaimsHash = &chainhash.Hash{3} ) -// ValueStore enables MerkleTrie to query node values from different implementations. +// ValueStore enables PersistentTrie to query node values from different implementations. type ValueStore interface { - ClaimHashes(name []byte) []*chainhash.Hash Hash(name []byte) *chainhash.Hash + IterateNames(predicate func(name []byte) bool) } -// MerkleTrie implements a 256-way prefix tree. -type MerkleTrie struct { +// PersistentTrie implements a 256-way prefix tree. +type PersistentTrie struct { store ValueStore repo Repo @@ -33,10 +35,10 @@ type MerkleTrie struct { bufs *sync.Pool } -// New returns a MerkleTrie. -func New(store ValueStore, repo Repo) *MerkleTrie { +// NewPersistentTrie returns a PersistentTrie. +func NewPersistentTrie(store ValueStore, repo Repo) *PersistentTrie { - tr := &MerkleTrie{ + tr := &PersistentTrie{ store: store, repo: repo, bufs: &sync.Pool{ @@ -50,14 +52,14 @@ func New(store ValueStore, repo Repo) *MerkleTrie { return tr } -// SetRoot drops all resolved nodes in the MerkleTrie, and set the root with specified hash. -func (t *MerkleTrie) SetRoot(h *chainhash.Hash) { +// SetRoot drops all resolved nodes in the PersistentTrie, and set the Root with specified hash. +func (t *PersistentTrie) SetRoot(h *chainhash.Hash, names [][]byte) { t.root = newVertex(h) } // Update updates the nodes along the path to the key. // Each node is resolved or created with their Hash cleared. -func (t *MerkleTrie) Update(name []byte, restoreChildren bool) { +func (t *PersistentTrie) Update(name []byte, restoreChildren bool) { n := t.root for i, ch := range name { @@ -80,7 +82,7 @@ func (t *MerkleTrie) Update(name []byte, restoreChildren bool) { } // resolveChildLinks updates the links on n -func (t *MerkleTrie) resolveChildLinks(n *vertex, key []byte) { +func (t *PersistentTrie) resolveChildLinks(n *vertex, key []byte) { if n.merkleHash == nil { return @@ -108,9 +110,9 @@ func (t *MerkleTrie) resolveChildLinks(n *vertex, key []byte) { } } -// MerkleHash returns the Merkle Hash of the MerkleTrie. +// MerkleHash returns the Merkle Hash of the PersistentTrie. // All nodes must have been resolved before calling this function. -func (t *MerkleTrie) MerkleHash() *chainhash.Hash { +func (t *PersistentTrie) MerkleHash() *chainhash.Hash { buf := make([]byte, 0, 256) if h := t.merkle(buf, t.root); h == nil { return EmptyTrieHash @@ -120,7 +122,7 @@ func (t *MerkleTrie) MerkleHash() *chainhash.Hash { // merkle recursively resolves the hashes of the node. // All nodes must have been resolved before calling this function. -func (t *MerkleTrie) merkle(prefix []byte, v *vertex) *chainhash.Hash { +func (t *PersistentTrie) merkle(prefix []byte, v *vertex) *chainhash.Hash { if v.merkleHash != nil { return v.merkleHash } @@ -178,7 +180,7 @@ func keysInOrder(v *vertex) []byte { return keys } -func (t *MerkleTrie) MerkleHashAllClaims() *chainhash.Hash { +func (t *PersistentTrie) MerkleHashAllClaims() *chainhash.Hash { buf := make([]byte, 0, 256) if h := t.merkleAllClaims(buf, t.root); h == nil { return EmptyTrieHash @@ -186,7 +188,7 @@ func (t *MerkleTrie) MerkleHashAllClaims() *chainhash.Hash { return t.root.merkleHash } -func (t *MerkleTrie) merkleAllClaims(prefix []byte, v *vertex) *chainhash.Hash { +func (t *PersistentTrie) merkleAllClaims(prefix []byte, v *vertex) *chainhash.Hash { if v.merkleHash != nil { return v.merkleHash } @@ -213,32 +215,25 @@ func (t *MerkleTrie) merkleAllClaims(prefix []byte, v *vertex) *chainhash.Hash { } } - var claimsHash *chainhash.Hash if v.hasValue { - claimsHash = v.claimsHash - if claimsHash == nil { - claimHashes := t.store.ClaimHashes(prefix) - if len(claimHashes) > 0 { - claimsHash = computeMerkleRoot(claimHashes) - v.claimsHash = claimsHash - } else { - v.hasValue = false - } + if v.claimsHash == nil { + v.claimsHash = t.store.Hash(prefix) + v.hasValue = v.claimsHash != nil } } - if len(childHashes) > 1 || claimsHash != nil { // yeah, about that 1 there -- old code used the condensed trie + if len(childHashes) > 1 || v.claimsHash != nil { // yeah, about that 1 there -- old code used the condensed trie left := NoChildrenHash if len(childHashes) > 0 { - left = computeMerkleRoot(childHashes) + left = node.ComputeMerkleRoot(childHashes) } right := NoClaimsHash - if claimsHash != nil { - b.Write(claimsHash[:]) // for Has Value, nolint : errchk - right = claimsHash + if v.claimsHash != nil { + b.Write(v.claimsHash[:]) // for Has Value, nolint : errchk + right = v.claimsHash } - h := hashMerkleBranches(left, right) + h := node.HashMerkleBranches(left, right) v.merkleHash = h t.repo.Set(append(prefix, h[:]...), b.Bytes()) } else if len(childHashes) == 1 { @@ -249,11 +244,11 @@ func (t *MerkleTrie) merkleAllClaims(prefix []byte, v *vertex) *chainhash.Hash { return v.merkleHash } -func (t *MerkleTrie) Close() error { +func (t *PersistentTrie) Close() error { return t.repo.Close() } -func (t *MerkleTrie) Dump(s string, allClaims bool) { +func (t *PersistentTrie) Dump(s string, allClaims bool) { v := t.root for i := 0; i < len(s); i++ { diff --git a/claimtrie/merkletrie/merkletrie_test.go b/claimtrie/merkletrie/merkletrie_test.go index e5bf89f3..1fdb6b04 100644 --- a/claimtrie/merkletrie/merkletrie_test.go +++ b/claimtrie/merkletrie/merkletrie_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/claimtrie/node" "github.com/stretchr/testify/require" ) @@ -15,10 +16,10 @@ func TestName(t *testing.T) { target, _ := chainhash.NewHashFromStr("e9ffb584c62449f157c8be88257bd1eebb2d8ef824f5c86b43c4f8fd9e800d6a") data := []*chainhash.Hash{EmptyTrieHash} - root := computeMerkleRoot(data) + root := node.ComputeMerkleRoot(data) r.True(EmptyTrieHash.IsEqual(root)) data = append(data, NoChildrenHash, NoClaimsHash) - root = computeMerkleRoot(data) + root = node.ComputeMerkleRoot(data) r.True(target.IsEqual(root)) } diff --git a/claimtrie/merkletrie/ramtrie.go b/claimtrie/merkletrie/ramtrie.go new file mode 100644 index 00000000..77620f06 --- /dev/null +++ b/claimtrie/merkletrie/ramtrie.go @@ -0,0 +1,139 @@ +package merkletrie + +import ( + "bytes" + "fmt" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/claimtrie/node" + "sync" +) + +type MerkleTrie interface { + SetRoot(h *chainhash.Hash, names [][]byte) + Update(name []byte, restoreChildren bool) + MerkleHash() *chainhash.Hash + MerkleHashAllClaims() *chainhash.Hash +} + +type RamTrie struct { + collapsedTrie + store ValueStore + bufs *sync.Pool +} + +func NewRamTrie(s ValueStore) *RamTrie { + return &RamTrie{ + store: s, + bufs: &sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, + }, + collapsedTrie: collapsedTrie{Root: &collapsedVertex{}}, + } +} + +func (rt *RamTrie) SetRoot(h *chainhash.Hash, names [][]byte) { + if rt.Root.merkleHash.IsEqual(h) { + return + } + + // if names is nil then we need to query all names + if names == nil { + fmt.Printf("Building the entire claim trie in RAM...\n") + // TODO: should technically clear the old trie first + rt.store.IterateNames(func(name []byte) bool { + rt.Update(name, false) + return true + }) + } else { + for _, name := range names { + rt.Update(name, false) + } + } +} + +func (rt *RamTrie) Update(name []byte, _ bool) { + h := rt.store.Hash(name) + if h == nil { + rt.Erase(name) + } else { + _, n := rt.InsertOrFind(name) + n.claimHash = h + } +} + +func (rt *RamTrie) MerkleHash() *chainhash.Hash { + if h := rt.merkleHash(rt.Root); h == nil { + return EmptyTrieHash + } + return rt.Root.merkleHash +} + +func (rt *RamTrie) merkleHash(v *collapsedVertex) *chainhash.Hash { + if v.merkleHash != nil { + return v.merkleHash + } + + b := rt.bufs.Get().(*bytes.Buffer) + defer rt.bufs.Put(b) + b.Reset() + + for _, ch := range v.children { + h := rt.merkleHash(ch) // h is a pointer; don't destroy its data + b.WriteByte(ch.key[0]) // nolint : errchk + b.Write(rt.completeHash(h, ch.key)) // nolint : errchk + } + + if v.claimHash != nil { + b.Write(v.claimHash[:]) + } + + if b.Len() > 0 { + h := chainhash.DoubleHashH(b.Bytes()) + v.merkleHash = &h + } + + return v.merkleHash +} + +func (rt *RamTrie) completeHash(h *chainhash.Hash, childKey KeyType) []byte { + var data [chainhash.HashSize + 1]byte + copy(data[1:], h[:]) + for i := len(childKey) - 1; i > 0; i-- { + data[0] = childKey[i] + copy(data[1:], chainhash.DoubleHashB(data[:])) + } + return data[1:] +} + +func (rt *RamTrie) MerkleHashAllClaims() *chainhash.Hash { + return rt.merkleHashAllClaims(rt.Root) +} + +func (rt *RamTrie) merkleHashAllClaims(v *collapsedVertex) *chainhash.Hash { + if v.merkleHash != nil { + return v.merkleHash + } + + childHashes := make([]*chainhash.Hash, 0, len(v.children)) + for _, ch := range v.children { + h := rt.merkleHashAllClaims(ch) + childHashes = append(childHashes, h) + } + + claimHash := NoClaimsHash + if v.claimHash != nil { + claimHash = v.claimHash + } else if len(childHashes) == 0 { + return v.merkleHash + } + + childHash := NoChildrenHash + if len(childHashes) > 0 { + childHash = node.ComputeMerkleRoot(childHashes) + } + + v.merkleHash = node.HashMerkleBranches(childHash, claimHash) + return v.merkleHash +} diff --git a/claimtrie/merkletrie/repo.go b/claimtrie/merkletrie/repo.go index 5f26baa8..75c57261 100644 --- a/claimtrie/merkletrie/repo.go +++ b/claimtrie/merkletrie/repo.go @@ -4,7 +4,7 @@ import ( "io" ) -// Repo defines APIs for MerkleTrie to access persistence layer. +// Repo defines APIs for PersistentTrie to access persistence layer. type Repo interface { Get(key []byte) ([]byte, io.Closer, error) Set(key, value []byte) error diff --git a/claimtrie/merkletrie/hashfunc.go b/claimtrie/node/hashfunc.go similarity index 75% rename from claimtrie/merkletrie/hashfunc.go rename to claimtrie/node/hashfunc.go index 9d300434..7c401e5d 100644 --- a/claimtrie/merkletrie/hashfunc.go +++ b/claimtrie/node/hashfunc.go @@ -1,8 +1,8 @@ -package merkletrie +package node import "github.com/btcsuite/btcd/chaincfg/chainhash" -func hashMerkleBranches(left *chainhash.Hash, right *chainhash.Hash) *chainhash.Hash { +func HashMerkleBranches(left *chainhash.Hash, right *chainhash.Hash) *chainhash.Hash { // Concatenate the left and right nodes. var hash [chainhash.HashSize * 2]byte copy(hash[:chainhash.HashSize], left[:]) @@ -12,7 +12,7 @@ func hashMerkleBranches(left *chainhash.Hash, right *chainhash.Hash) *chainhash. return &newHash } -func computeMerkleRoot(hashes []*chainhash.Hash) *chainhash.Hash { +func ComputeMerkleRoot(hashes []*chainhash.Hash) *chainhash.Hash { if len(hashes) <= 0 { return nil } @@ -21,7 +21,7 @@ func computeMerkleRoot(hashes []*chainhash.Hash) *chainhash.Hash { hashes = append(hashes, hashes[len(hashes)-1]) } for i := 0; i < len(hashes); i += 2 { // TODO: parallelize this loop (or use a lib that does it) - hashes[i>>1] = hashMerkleBranches(hashes[i], hashes[i+1]) + hashes[i>>1] = HashMerkleBranches(hashes[i], hashes[i+1]) } hashes = hashes[:len(hashes)>>1] } diff --git a/claimtrie/node/manager.go b/claimtrie/node/manager.go index 987fdef2..acbf4bf7 100644 --- a/claimtrie/node/manager.go +++ b/claimtrie/node/manager.go @@ -21,7 +21,6 @@ type Manager interface { Node(name []byte) (*Node, error) NextUpdateHeightOfNode(name []byte) ([]byte, int32) IterateNames(predicate func(name []byte) bool) - ClaimHashes(name []byte) []*chainhash.Hash Hash(name []byte) *chainhash.Hash } @@ -295,7 +294,7 @@ func (nm *BaseManager) IterateNames(predicate func(name []byte) bool) { nm.repo.IterateAll(predicate) } -func (nm *BaseManager) ClaimHashes(name []byte) []*chainhash.Hash { +func (nm *BaseManager) claimHashes(name []byte) *chainhash.Hash { n, err := nm.Node(name) if err != nil || n == nil { @@ -308,11 +307,18 @@ func (nm *BaseManager) ClaimHashes(name []byte) []*chainhash.Hash { claimHashes = append(claimHashes, calculateNodeHash(c.OutPoint, n.TakenOverAt)) } } - return claimHashes + if len(claimHashes) > 0 { + return ComputeMerkleRoot(claimHashes) + } + return nil } func (nm *BaseManager) Hash(name []byte) *chainhash.Hash { + if nm.height >= param.AllClaimsInMerkleForkHeight { + return nm.claimHashes(name) + } + n, err := nm.Node(name) if err != nil { return nil