[qa] Add a test for merkle proof malleation
This commit is contained in:
parent
ed82f17000
commit
d280617bf5
2 changed files with 69 additions and 0 deletions
|
@ -6,6 +6,8 @@
|
|||
|
||||
from test_framework.test_framework import BitcoinTestFramework
|
||||
from test_framework.util import *
|
||||
from test_framework.mininode import FromHex, ToHex
|
||||
from test_framework.messages import CMerkleBlock
|
||||
|
||||
class MerkleBlockTest(BitcoinTestFramework):
|
||||
def set_test_params(self):
|
||||
|
@ -78,6 +80,27 @@ class MerkleBlockTest(BitcoinTestFramework):
|
|||
# We can't get a proof if we specify transactions from different blocks
|
||||
assert_raises_rpc_error(-5, "Not all transactions found in specified or retrieved block", self.nodes[2].gettxoutproof, [txid1, txid3])
|
||||
|
||||
# Now we'll try tweaking a proof.
|
||||
proof = self.nodes[3].gettxoutproof([txid1, txid2])
|
||||
assert txid1 in self.nodes[0].verifytxoutproof(proof)
|
||||
assert txid2 in self.nodes[1].verifytxoutproof(proof)
|
||||
|
||||
tweaked_proof = FromHex(CMerkleBlock(), proof)
|
||||
|
||||
# Make sure that our serialization/deserialization is working
|
||||
assert txid1 in self.nodes[2].verifytxoutproof(ToHex(tweaked_proof))
|
||||
|
||||
# Check to see if we can go up the merkle tree and pass this off as a
|
||||
# single-transaction block
|
||||
tweaked_proof.txn.nTransactions = 1
|
||||
tweaked_proof.txn.vHash = [tweaked_proof.header.hashMerkleRoot]
|
||||
tweaked_proof.txn.vBits = [True] + [False]*7
|
||||
|
||||
for n in self.nodes:
|
||||
assert not n.verifytxoutproof(ToHex(tweaked_proof))
|
||||
|
||||
# TODO: try more variants, eg transactions at different depths, and
|
||||
# verify that the proofs are invalid
|
||||
|
||||
if __name__ == '__main__':
|
||||
MerkleBlockTest().main()
|
||||
|
|
|
@ -841,6 +841,52 @@ class BlockTransactions():
|
|||
def __repr__(self):
|
||||
return "BlockTransactions(hash=%064x transactions=%s)" % (self.blockhash, repr(self.transactions))
|
||||
|
||||
class CPartialMerkleTree():
|
||||
def __init__(self):
|
||||
self.nTransactions = 0
|
||||
self.vHash = []
|
||||
self.vBits = []
|
||||
self.fBad = False
|
||||
|
||||
def deserialize(self, f):
|
||||
self.nTransactions = struct.unpack("<i", f.read(4))[0]
|
||||
self.vHash = deser_uint256_vector(f)
|
||||
vBytes = deser_string(f)
|
||||
self.vBits = []
|
||||
for i in range(len(vBytes) * 8):
|
||||
self.vBits.append(vBytes[i//8] & (1 << (i % 8)) != 0)
|
||||
|
||||
def serialize(self):
|
||||
r = b""
|
||||
r += struct.pack("<i", self.nTransactions)
|
||||
r += ser_uint256_vector(self.vHash)
|
||||
vBytesArray = bytearray([0x00] * ((len(self.vBits) + 7)//8))
|
||||
for i in range(len(self.vBits)):
|
||||
vBytesArray[i // 8] |= self.vBits[i] << (i % 8)
|
||||
r += ser_string(bytes(vBytesArray))
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
return "CPartialMerkleTree(nTransactions=%d, vHash=%s, vBits=%s)" % (self.nTransactions, repr(self.vHash), repr(self.vBits))
|
||||
|
||||
class CMerkleBlock():
|
||||
def __init__(self):
|
||||
self.header = CBlockHeader()
|
||||
self.txn = CPartialMerkleTree()
|
||||
|
||||
def deserialize(self, f):
|
||||
self.header.deserialize(f)
|
||||
self.txn.deserialize(f)
|
||||
|
||||
def serialize(self):
|
||||
r = b""
|
||||
r += self.header.serialize()
|
||||
r += self.txn.serialize()
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
return "CMerkleBlock(header=%s, txn=%s)" % (repr(self.header), repr(self.txn))
|
||||
|
||||
|
||||
# Objects that correspond to messages on the wire
|
||||
class msg_version():
|
||||
|
|
Loading…
Reference in a new issue