# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# and warranty status of this software.

"""Merkle trees, branches, proofs and roots."""

import typing
from asyncio import Event
from math import ceil, log

from scribe.common import double_sha256


class Merkle:
    """Perform merkle tree calculations on binary hashes using a given hash
    function.

    If the hash count is not even, the final hash is repeated when
    calculating the next merkle layer up the tree.
    """

    def __init__(self, hash_func=double_sha256):
        self.hash_func = hash_func

    @staticmethod
    def tree_depth(hash_count):
        """Return the depth of the merkle tree for the given number of
        hashes: one more than the branch length."""
        return Merkle.branch_length(hash_count) + 1

    @staticmethod
    def branch_length(hash_count):
        """Return the length of a merkle branch given the number of hashes.

        Raises TypeError if hash_count is not an integer and ValueError if
        it is less than 1.
        """
        if not isinstance(hash_count, int):
            raise TypeError('hash_count must be an integer')
        if hash_count < 1:
            raise ValueError('hash_count must be at least 1')
        # Exact integer ceil(log2(hash_count)).  A floating point
        # ceil(log(n, 2)) can overshoot by one when n is a power of two.
        return (hash_count - 1).bit_length()

    @staticmethod
    def branch_and_root(hashes, index, length=None, hash_func=double_sha256):
        """Return a (merkle branch, merkle_root) pair given hashes, and the
        index of one of those hashes.

        length, if given, is the number of layers to hash through; it must
        be an integer no smaller than the natural branch length for the
        number of hashes.  hash_func combines two concatenated hashes.

        Raises TypeError for a non-integer index or length, and ValueError
        if index or length is out of range.
        """
        hashes = list(hashes)
        if not isinstance(index, int):
            raise TypeError('index must be an integer')
        # This also asserts hashes is not empty
        if not 0 <= index < len(hashes):
            raise ValueError(f"index '{index}/{len(hashes)}' out of range")
        natural_length = Merkle.branch_length(len(hashes))
        if length is None:
            length = natural_length
        else:
            if not isinstance(length, int):
                raise TypeError('length must be an integer')
            if length < natural_length:
                raise ValueError('length out of range')

        branch = []
        for _ in range(length):
            # Duplicate the final hash so every hash has a partner.
            if len(hashes) & 1:
                hashes.append(hashes[-1])
            # The sibling of a node differs from it only in the lowest bit.
            branch.append(hashes[index ^ 1])
            index >>= 1
            hashes = [hash_func(hashes[n] + hashes[n + 1])
                      for n in range(0, len(hashes), 2)]

        return branch, hashes[0]

    @staticmethod
    def branches_and_root(block_tx_hashes: typing.List[bytes], tx_positions: typing.List[int]):
        """Return a (branches, merkle_root) pair for several positions at once.

        branches is a dict mapping each entry of tx_positions to its merkle
        branch (a list of sibling hashes, deepest first).  Layers are
        combined with double_sha256.

        Neither argument is modified.  Raises ValueError if block_tx_hashes
        is empty.
        """
        hashes = list(block_tx_hashes)
        if not hashes:
            raise ValueError('block_tx_hashes must not be empty')
        positions = list(tx_positions)
        # Per-position index into the current layer.  Kept separate from
        # positions (the dict keys) and from the caller's list, which must
        # not be shifted in place.
        cursors = list(positions)
        branches = [[] for _ in positions]

        for _ in range(Merkle.branch_length(len(hashes))):
            # Duplicate the final hash so every hash has a partner.
            if len(hashes) & 1:
                hashes.append(hashes[-1])
            for idx, cursor in enumerate(cursors):
                branches[idx].append(hashes[cursor ^ 1])
                cursors[idx] = cursor >> 1
            hashes = [
                double_sha256(hashes[n] + hashes[n + 1])
                for n in range(0, len(hashes), 2)
            ]

        return dict(zip(positions, branches)), hashes[0]

    @staticmethod
    def root(hashes, length=None):
        """Return the merkle root of a non-empty iterable of binary hashes."""
        return Merkle.branch_and_root(hashes, 0, length)[1]

    # @staticmethod
    # def root_from_proof(hash, branch, index, hash_func=double_sha256):
    #     """Return the merkle root given a hash, a merkle branch to it, and
    #     its index in the hashes array.
    #
    #     branch is an iterable sorted deepest to shallowest.  If the
    #     returned root is the expected value then the merkle proof is
    #     verified.
    #
    #     The caller should have confirmed the length of the branch with
    #     branch_length().  Unfortunately this is not easily done for
    #     bitcoin transactions as the number of transactions in a block
    #     is unknown to an SPV client.
    #     """
    #     for elt in branch:
    #         if index & 1:
    #             hash = hash_func(elt + hash)
    #         else:
    #             hash = hash_func(hash + elt)
    #         index >>= 1
    #     if index:
    #         raise ValueError('index out of range for branch')
    #     return hash

    @staticmethod
    def level(hashes, depth_higher):
        """Return a level of the merkle tree of hashes the given depth
        higher than the bottom row of the original tree."""
        size = 1 << depth_higher
        root = Merkle.root
        # Each entry is the root of one segment of `size` leaves, hashed
        # through exactly depth_higher layers.
        return [root(hashes[n: n + size], depth_higher)
                for n in range(0, len(hashes), size)]

    @staticmethod
    def branch_and_root_from_level(level, leaf_hashes, index, depth_higher):
        """Return a (merkle branch, merkle_root) pair when a merkle-tree has a
        level cached.

        To maximally reduce the amount of data hashed in computing a
        merkle branch, cache a tree of depth N at level N // 2.

        level is a list of hashes in the middle of the tree (returned
        by level())

        leaf_hashes are the leaves needed to calculate a partial branch
        up to level.

        depth_higher is how much higher level is than the leaves of
        the tree

        index is the index in the full list of hashes of the hash whose
        merkle branch we want.

        Raises TypeError if level or leaf_hashes is not a list, and
        ValueError if the leaf hashes do not hash to the cached level entry.
        """
        if not isinstance(level, list):
            raise TypeError("level must be a list")
        if not isinstance(leaf_hashes, list):
            raise TypeError("leaf_hashes must be a list")
        # Index of the first leaf in the segment containing index.
        leaf_index = (index >> depth_higher) << depth_higher
        leaf_branch, leaf_root = Merkle.branch_and_root(
            leaf_hashes, index - leaf_index, depth_higher)
        index >>= depth_higher
        level_branch, root = Merkle.branch_and_root(level, index)
        # Check last so that we know index is in-range
        if leaf_root != level[index]:
            raise ValueError('leaf hashes inconsistent with level')
        return leaf_branch + level_branch, root


class MerkleCache:
    """A cache to calculate merkle branches efficiently."""

    def __init__(self, merkle, source_func):
        """Initialise a cache of hashes taken from source_func:

           async def source_func(index, count):
              ...
        """
        self.merkle = merkle
        self.source_func = source_func
        self.length = 0
        self.depth_higher = 0
        # Cached tree level; populated by initialize().
        self.level = []
        self.initialized = Event()

    def _segment_length(self):
        """Return the number of leaf hashes covered by one cached level
        entry."""
        return 1 << self.depth_higher

    def _leaf_start(self, index):
        """Given a hash index, return the index of the first leaf of the
        segment containing it."""
        depth_higher = self.depth_higher
        return (index >> depth_higher) << depth_higher

    def _level(self, hashes):
        """Return the cached-depth level computed from the given leaf
        hashes."""
        return self.merkle.level(hashes, self.depth_higher)

    async def _extend_to(self, length):
        """Extend the length of the cache if necessary."""
        if length <= self.length:
            return
        # Start from the beginning of any final partial segment.
        # Retain the value of depth_higher; in practice this is fine
        start = self._leaf_start(self.length)
        hashes = await self.source_func(start, length - start)
        self.level[start >> self.depth_higher:] = self._level(hashes)
        self.length = length

    async def _level_for(self, length):
        """Return the cached level as it would be for a truncation of the
        hashes to the given length."""
        if length == self.length:
            return self.level
        # Whole segments are reused from the cache; the trailing partial
        # segment is re-hashed from source.
        level = self.level[:length >> self.depth_higher]
        leaf_start = self._leaf_start(length)
        count = min(self._segment_length(), length - leaf_start)
        hashes = await self.source_func(leaf_start, count)
        level += self._level(hashes)
        return level

    async def initialize(self, length):
        """Call to initialize the cache to a source of given length."""
        self.length = length
        self.depth_higher = self.merkle.tree_depth(length) // 2
        self.level = self._level(await self.source_func(0, length))
        self.initialized.set()

    def truncate(self, length):
        """Truncate the cache so it covers no more than length underlying
        hashes."""
        if not isinstance(length, int):
            raise TypeError('length must be an integer')
        if length <= 0:
            raise ValueError('length must be positive')
        if length >= self.length:
            return
        # Round down to a segment boundary so only whole segments remain.
        length = self._leaf_start(length)
        self.length = length
        self.level[length >> self.depth_higher:] = []

    async def branch_and_root(self, length, index):
        """Return a merkle branch and root.  Length is the number of
        hashes used to calculate the merkle root, index is the position
        of the hash to calculate the branch of.

        index must be non-negative and less than length, which must be at
        least 1.

        Raises TypeError for non-integer arguments and ValueError for
        out-of-range ones.
        """
        if not isinstance(length, int):
            raise TypeError('length must be an integer')
        if not isinstance(index, int):
            raise TypeError('index must be an integer')
        if length <= 0:
            raise ValueError('length must be positive')
        if index < 0:
            # Reject before a negative start reaches source_func.
            raise ValueError('index must be non-negative')
        if index >= length:
            raise ValueError('index must be less than length')
        await self.initialized.wait()
        await self._extend_to(length)
        leaf_start = self._leaf_start(index)
        count = min(self._segment_length(), length - leaf_start)
        leaf_hashes = await self.source_func(leaf_start, count)
        if length < self._segment_length():
            # Everything fits in one segment; the cached level is not needed.
            return self.merkle.branch_and_root(leaf_hashes, index)
        level = await self._level_for(length)
        return self.merkle.branch_and_root_from_level(
            level, leaf_hashes, index, self.depth_higher)


class FastMerkleCacheItem:
    """The layers of a merkle tree built with double_sha256, kept so that
    branches can be read off without re-hashing."""

    __slots__ = ['tree', 'root_hash']

    def __init__(self, tx_hashes: typing.List[bytes]):
        # Layers from the leaves upwards; the single-hash root layer is
        # not stored here but in root_hash.
        self.tree: typing.List[typing.List[bytes]] = []
        self.root_hash = self._walk_merkle(tx_hashes, self.tree.append)

    @staticmethod
    def _walk_merkle(items: typing.List[bytes], append_layer) -> bytes:
        """Return the merkle root of items, passing every layer that has
        more than one hash to append_layer, leaves first.

        Raises ValueError if items is empty.
        """
        if not items:
            raise ValueError('items must not be empty')
        while len(items) > 1:
            append_layer(items)
            count = len(items)
            # An unpaired final hash is combined with itself.
            items = [
                double_sha256(items[n] + items[n + 1 if n + 1 < count else n])
                for n in range(0, count, 2)
            ]
        return items[0]

    def branch(self, tx_position: int) -> typing.List[str]:
        """Return the merkle branch for tx_position as a list of hex
        strings of the byte-reversed sibling hashes, deepest first."""
        branch = []
        for shift, layer in enumerate(self.tree):
            sibling = (tx_position >> shift) ^ 1
            # A final unpaired hash is its own sibling.
            node = layer[-1] if sibling == len(layer) else layer[sibling]
            branch.append(node[::-1].hex())
        return branch