blockchain reorg handling and overall header refactor

This commit is contained in:
Lex Berezhny 2018-08-16 00:43:38 -04:00
parent 1a5654d50b
commit 133a86cd89
11 changed files with 634 additions and 309 deletions

View file

@ -0,0 +1,17 @@
from orchstr8.testcase import IntegrationTestCase
class BlockchainReorganizationTests(IntegrationTestCase):
    # Emit verbose harness output while debugging reorg behavior.
    VERBOSE = True

    async def test(self):
        """Exercise a one-block invalidate followed by a two-block reorg.

        Assumes the regtest chain starts at height 200 -- presumably set up
        by the IntegrationTestCase harness; TODO confirm.
        """
        self.assertEqual(self.ledger.headers.height, 200)
        # Mine one block and wait for the ledger to record the new header.
        await self.blockchain.generate(1)
        await self.on_header(201)
        self.assertEqual(self.ledger.headers.height, 201)
        # Invalidate the tip, then mine two replacement blocks; the longer
        # chain must force the ledger to reorganize past height 201.
        await self.blockchain.invalidateblock(self.ledger.headers.hash(201).decode())
        await self.blockchain.generate(2)
        await self.on_header(203)

108
tests/unit/test_headers.py Normal file
View file

@ -0,0 +1,108 @@
import os
from urllib.request import Request, urlopen
from twisted.trial import unittest
from twisted.internet import defer
from torba.coin.bitcoinsegwit import MainHeaders
def block_bytes(blocks):
    """Return the length in bytes of *blocks* consecutive serialized headers."""
    return MainHeaders.header_size * blocks
class BitcoinHeadersTestCase(unittest.TestCase):
    """Fixture backed by a slice of real Bitcoin main-chain headers.

    The header file is fetched once from electrum.org and cached next to
    this module, so the binary data does not have to live in git.
    """

    # Download headers instead of storing them in git.
    HEADER_URL = 'http://headers.electrum.org/blockchain_headers'
    HEADER_FILE = 'bitcoin_headers'
    HEADER_BYTES = block_bytes(32260)  # 2.6MB
    RETARGET_BLOCK = 32256  # difficulty: 1 -> 1.18

    def setUp(self):
        self.maxDiff = None
        self.header_file_name = os.path.join(os.path.dirname(__file__), self.HEADER_FILE)
        if not os.path.exists(self.header_file_name):
            # Fetch only the byte range we need via an HTTP Range request.
            request = Request(self.HEADER_URL)
            request.add_header('Range', 'bytes=0-{}'.format(self.HEADER_BYTES-1))
            with urlopen(request) as response, open(self.header_file_name, 'wb') as cache:
                cache.write(response.read())
        # Sanity-check the cached file every run, not just after download.
        if os.path.getsize(self.header_file_name) != self.HEADER_BYTES:
            os.remove(self.header_file_name)
            raise Exception("Downloaded headers for testing are not the correct number of bytes.")

    def get_bytes(self, upto: int = -1, after: int = 0) -> bytes:
        """Return *upto* bytes of raw header data, skipping *after* bytes."""
        with open(self.header_file_name, 'rb') as stream:
            stream.seek(after, os.SEEK_SET)
            return stream.read(upto)

    def get_headers(self, upto: int = -1):
        """Return an in-memory MainHeaders preloaded with the first *upto* bytes."""
        preloaded = MainHeaders(':memory:')
        preloaded.io.write(self.get_bytes(upto))
        return preloaded
class BasicHeadersTests(BitcoinHeadersTestCase):

    def test_serialization(self):
        """Deserialized fields and serialize() round-trips match the raw file data."""
        h = self.get_headers()
        # Genesis block.
        self.assertEqual(h[0], {
            'bits': 486604799,
            'block_height': 0,
            'merkle_root': b'4a5e1e4baab89f3a32518a88c31bc87f618f76673e2cc77ab2127b7afdeda33b',
            'nonce': 2083236893,
            'prev_block_hash': b'0000000000000000000000000000000000000000000000000000000000000000',
            'timestamp': 1231006505,
            'version': 1
        })
        # Last block before the first difficulty retarget.
        self.assertEqual(h[self.RETARGET_BLOCK-1], {
            'bits': 486604799,
            'block_height': 32255,
            'merkle_root': b'89b4f223789e40b5b475af6483bb05bceda54059e17d2053334b358f6bb310ac',
            'nonce': 312762301,
            'prev_block_hash': b'000000006baebaa74cecde6c6787c26ee0a616a3c333261bff36653babdac149',
            'timestamp': 1262152739,
            'version': 1
        })
        # First block carrying the retargeted difficulty bits.
        self.assertEqual(h[self.RETARGET_BLOCK], {
            'bits': 486594666,
            'block_height': 32256,
            'merkle_root': b'64b5e5f5a262f47af443a0120609206a3305877693edfe03e994f20a024ab627',
            'nonce': 121087187,
            'prev_block_hash': b'00000000984f962134a7291e3693075ae03e521f0ee33378ec30a334d860034b',
            'timestamp': 1262153464,
            'version': 1
        })
        # Block after the retarget keeps the new bits.
        self.assertEqual(h[self.RETARGET_BLOCK+1], {
            'bits': 486594666,
            'block_height': 32257,
            'merkle_root': b'4d1488981f08b3037878193297dbac701a2054e0f803d4424fe6a4d763d62334',
            'nonce': 274675219,
            'prev_block_hash': b'000000004f2886a170adb7204cb0c7a824217dd24d11a74423d564c4e0904967',
            'timestamp': 1262154352,
            'version': 1
        })
        # serialize() must reproduce the exact raw bytes the headers came from.
        self.assertEqual(
            h.serialize(h[0]),
            h.get_raw_header(0)
        )
        self.assertEqual(
            h.serialize(h[self.RETARGET_BLOCK]),
            h.get_raw_header(self.RETARGET_BLOCK)
        )

    @defer.inlineCallbacks
    def test_connect_from_genesis_to_3000_past_first_chunk_at_2016(self):
        """connect() from an empty store validates across the 2016-block chunk boundary."""
        headers = MainHeaders(':memory:')
        self.assertEqual(headers.height, -1)
        yield headers.connect(0, self.get_bytes(block_bytes(3001)))
        self.assertEqual(headers.height, 3000)

    @defer.inlineCallbacks
    def test_connect_9_blocks_passing_a_retarget_at_32256(self):
        """connect() onto an existing store validates across a difficulty retarget."""
        retarget = block_bytes(self.RETARGET_BLOCK-5)
        headers = self.get_headers(upto=retarget)
        remainder = self.get_bytes(after=retarget)
        self.assertEqual(headers.height, 32250)
        yield headers.connect(len(headers), remainder)
        self.assertEqual(headers.height, 32259)

View file

@ -1,11 +1,12 @@
import os
from binascii import hexlify from binascii import hexlify
from twisted.trial import unittest
from twisted.internet import defer from twisted.internet import defer
from torba.coin.bitcoinsegwit import MainNetLedger from torba.coin.bitcoinsegwit import MainNetLedger
from torba.wallet import Wallet from torba.wallet import Wallet
from .test_transaction import get_transaction, get_output from .test_transaction import get_transaction, get_output
from .test_headers import BitcoinHeadersTestCase, block_bytes
class MockNetwork: class MockNetwork:
@ -30,34 +31,40 @@ class MockNetwork:
return defer.succeed(self.transaction[tx_hash]) return defer.succeed(self.transaction[tx_hash])
class MockHeaders: class LedgerTestCase(BitcoinHeadersTestCase):
def __init__(self, ledger):
self.ledger = ledger
self.height = 1
def __len__(self):
return self.height
def __getitem__(self, height):
return {'merkle_root': 'abcd04'}
class MainNetTestLedger(MainNetLedger):
headers_class = MockHeaders
network_name = 'unittest'
def __init__(self):
super().__init__({'db': MainNetLedger.database_class(':memory:')})
class LedgerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.ledger = MainNetTestLedger() super().setUp()
return self.ledger.db.start() self.ledger = MainNetLedger({
'db': MainNetLedger.database_class(':memory:'),
'headers': MainNetLedger.headers_class(':memory:')
})
return self.ledger.db.open()
def tearDown(self): def tearDown(self):
return self.ledger.db.stop() super().tearDown()
return self.ledger.db.close()
def make_header(self, **kwargs):
    """Serialize a synthetic header built from genesis defaults plus overrides.

    merkle_root / prev_block_hash are padded to 64 hex characters so short
    test values (e.g. b'abcd04') still fill the full field width.
    """
    header = {
        'bits': 486604799,
        'block_height': 0,
        'merkle_root': b'4a5e1e4baab89f3a32518a88c31bc87f618f76673e2cc77ab2127b7afdeda33b',
        'nonce': 2083236893,
        'prev_block_hash': b'0000000000000000000000000000000000000000000000000000000000000000',
        'timestamp': 1231006505,
        'version': 1
    }
    header.update(kwargs)
    header['merkle_root'] = header['merkle_root'].ljust(64, b'a')
    header['prev_block_hash'] = header['prev_block_hash'].ljust(64, b'0')
    return self.ledger.headers.serialize(header)
def add_header(self, **kwargs):
    """Append a synthetic header straight to the header store (no validation)."""
    serialized = self.make_header(**kwargs)
    self.ledger.headers.io.seek(0, os.SEEK_END)
    self.ledger.headers.io.write(serialized)
    # Invalidate the cached size so len(headers) is recomputed from the file.
    self.ledger.headers._size = None
class TestSynchronization(LedgerTestCase): class TestSynchronization(LedgerTestCase):
@ -69,11 +76,14 @@ class TestSynchronization(LedgerTestCase):
address_details = yield self.ledger.db.get_address(address) address_details = yield self.ledger.db.get_address(address)
self.assertEqual(address_details['history'], None) self.assertEqual(address_details['history'], None)
self.ledger.headers.height = 3 self.add_header(block_height=0, merkle_root=b'abcd04')
self.add_header(block_height=1, merkle_root=b'abcd04')
self.add_header(block_height=2, merkle_root=b'abcd04')
self.add_header(block_height=3, merkle_root=b'abcd04')
self.ledger.network = MockNetwork([ self.ledger.network = MockNetwork([
{'tx_hash': 'abcd01', 'height': 1}, {'tx_hash': 'abcd01', 'height': 0},
{'tx_hash': 'abcd02', 'height': 2}, {'tx_hash': 'abcd02', 'height': 1},
{'tx_hash': 'abcd03', 'height': 3}, {'tx_hash': 'abcd03', 'height': 2},
], { ], {
'abcd01': hexlify(get_transaction(get_output(1)).raw), 'abcd01': hexlify(get_transaction(get_output(1)).raw),
'abcd02': hexlify(get_transaction(get_output(2)).raw), 'abcd02': hexlify(get_transaction(get_output(2)).raw),
@ -84,7 +94,7 @@ class TestSynchronization(LedgerTestCase):
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03']) self.assertEqual(self.ledger.network.get_transaction_called, ['abcd01', 'abcd02', 'abcd03'])
address_details = yield self.ledger.db.get_address(address) address_details = yield self.ledger.db.get_address(address)
self.assertEqual(address_details['history'], 'abcd01:1:abcd02:2:abcd03:3:') self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:')
self.ledger.network.get_history_called = [] self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = [] self.ledger.network.get_transaction_called = []
@ -92,7 +102,7 @@ class TestSynchronization(LedgerTestCase):
self.assertEqual(self.ledger.network.get_history_called, [address]) self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, []) self.assertEqual(self.ledger.network.get_transaction_called, [])
self.ledger.network.history.append({'tx_hash': 'abcd04', 'height': 4}) self.ledger.network.history.append({'tx_hash': 'abcd04', 'height': 3})
self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw) self.ledger.network.transaction['abcd04'] = hexlify(get_transaction(get_output(4)).raw)
self.ledger.network.get_history_called = [] self.ledger.network.get_history_called = []
self.ledger.network.get_transaction_called = [] self.ledger.network.get_transaction_called = []
@ -100,4 +110,51 @@ class TestSynchronization(LedgerTestCase):
self.assertEqual(self.ledger.network.get_history_called, [address]) self.assertEqual(self.ledger.network.get_history_called, [address])
self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04']) self.assertEqual(self.ledger.network.get_transaction_called, ['abcd04'])
address_details = yield self.ledger.db.get_address(address) address_details = yield self.ledger.db.get_address(address)
self.assertEqual(address_details['history'], 'abcd01:1:abcd02:2:abcd03:3:abcd04:4:') self.assertEqual(address_details['history'], 'abcd01:0:abcd02:1:abcd03:2:abcd04:3:')
class MocHeaderNetwork:
    """Minimal network stand-in that serves canned get_headers responses."""

    def __init__(self, responses):
        # Mapping of starting height -> response dict ({'height', 'count', 'hex'}).
        self.responses = responses

    def get_headers(self, height, blocks):
        # The requested block count is ignored; responses are keyed by height only.
        canned = self.responses[height]
        return canned
class BlockchainReorganizationTests(LedgerTestCase):

    @defer.inlineCallbacks
    def test_1_block_reorganization(self):
        """A header conflicting at the tip rewinds one block and resyncs from 20."""
        self.ledger.network = MocHeaderNetwork({
            # Replacement headers served once the ledger rewinds to height 20.
            20: {'height': 20, 'count': 5, 'hex': hexlify(
                self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
            )},
            # Empty response signals nothing exists past height 25.
            25: {'height': 25, 'count': 0, 'hex': b''}
        })
        headers = self.ledger.headers
        # Seed 20 real headers, then one synthetic (soon-to-be-orphaned) header.
        yield headers.connect(0, self.get_bytes(upto=block_bytes(20)))
        self.add_header(block_height=len(headers))
        self.assertEqual(headers.height, 20)
        yield self.ledger.receive_header([{
            'height': 21, 'hex': hexlify(self.make_header(block_height=21))
        }])

    @defer.inlineCallbacks
    def test_3_block_reorganization(self):
        """A conflict three blocks deep rewinds repeatedly until headers reconnect."""
        self.ledger.network = MocHeaderNetwork({
            20: {'height': 20, 'count': 5, 'hex': hexlify(
                self.get_bytes(after=block_bytes(20), upto=block_bytes(5))
            )},
            21: {'height': 21, 'count': 1, 'hex': hexlify(self.make_header(block_height=21))},
            22: {'height': 22, 'count': 1, 'hex': hexlify(self.make_header(block_height=22))},
            25: {'height': 25, 'count': 0, 'hex': b''}
        })
        headers = self.ledger.headers
        yield headers.connect(0, self.get_bytes(upto=block_bytes(20)))
        # Three synthetic headers that the reorganization must unwind.
        self.add_header(block_height=len(headers))
        self.add_header(block_height=len(headers))
        self.add_header(block_height=len(headers))
        self.assertEqual(headers.height, 22)
        yield self.ledger.receive_header(({
            'height': 23, 'hex': hexlify(self.make_header(block_height=23))
        },))

60
tests/unit/test_utils.py Normal file
View file

@ -0,0 +1,60 @@
import unittest
from torba.util import ArithUint256
class TestArithUint256(unittest.TestCase):
    """Port of Bitcoin Core's arith_uint256 compact-encoding tests."""

    def test(self):
        # https://github.com/bitcoin/bitcoin/blob/master/src/test/arith_uint256_tests.cpp
        from_compact = ArithUint256.from_compact
        eq = self.assertEqual
        # A zero mantissa decodes to zero regardless of exponent or sign bit.
        eq(from_compact(0).value, 0)
        eq(from_compact(0x00123456).value, 0)
        eq(from_compact(0x01003456).value, 0)
        eq(from_compact(0x02000056).value, 0)
        eq(from_compact(0x03000000).value, 0)
        eq(from_compact(0x04000000).value, 0)
        eq(from_compact(0x00923456).value, 0)
        eq(from_compact(0x01803456).value, 0)
        eq(from_compact(0x02800056).value, 0)
        eq(from_compact(0x03800000).value, 0)
        eq(from_compact(0x04800000).value, 0)
        # Make sure that we don't generate compacts with the 0x00800000 bit set
        uint = ArithUint256(0x80)
        eq(uint.compact, 0x02008000)
        # Re-encoding keeps only the top three mantissa bytes.
        uint = from_compact(0x01123456)
        eq(uint.value, 0x12)
        eq(uint.compact, 0x01120000)
        uint = from_compact(0x01fedcba)
        eq(uint.value, 0x7e)
        eq(uint.negative, 0x01fe0000)
        uint = from_compact(0x02123456)
        eq(uint.value, 0x1234)
        eq(uint.compact, 0x02123400)
        uint = from_compact(0x03123456)
        eq(uint.value, 0x123456)
        eq(uint.compact, 0x03123456)
        uint = from_compact(0x04123456)
        eq(uint.value, 0x12345600)
        eq(uint.compact, 0x04123456)
        # Sign bit set in the input: negative encoding is preserved.
        uint = from_compact(0x04923456)
        eq(uint.value, 0x12345600)
        eq(uint.negative, 0x04923456)
        uint = from_compact(0x05009234)
        eq(uint.value, 0x92340000)
        eq(uint.compact, 0x05009234)
        # Maximum-exponent case: mantissa shifted into the top bytes of 256 bits.
        uint = from_compact(0x20123456)
        eq(uint.value, 0x1234560000000000000000000000000000000000000000000000000000000000)
        eq(uint.compact, 0x20123456)

View file

@ -47,7 +47,7 @@ class SQLiteMixin:
self._db_path = path self._db_path = path
self.db: adbapi.ConnectionPool = None self.db: adbapi.ConnectionPool = None
def start(self): def open(self):
log.info("connecting to database: %s", self._db_path) log.info("connecting to database: %s", self._db_path)
self.db = adbapi.ConnectionPool( self.db = adbapi.ConnectionPool(
'sqlite3', self._db_path, cp_min=1, cp_max=1, check_same_thread=False 'sqlite3', self._db_path, cp_min=1, cp_max=1, check_same_thread=False
@ -56,7 +56,7 @@ class SQLiteMixin:
lambda t: t.executescript(self.CREATE_TABLES_QUERY) lambda t: t.executescript(self.CREATE_TABLES_QUERY)
) )
def stop(self): def close(self):
self.db.close() self.db.close()
return defer.succeed(True) return defer.succeed(True)

View file

@ -1,255 +1,195 @@
import os import os
import struct
import logging import logging
import typing from io import BytesIO
from binascii import unhexlify from typing import Optional, Iterator, Tuple
from binascii import hexlify
from twisted.internet import threads, defer from twisted.internet import threads, defer
from torba.stream import StreamController from torba.stream import StreamController
from torba.util import int_to_hex, rev_hex, hash_encode from torba.util import ArithUint256
from torba.hash import double_sha256, pow_hash from torba.hash import double_sha256
if typing.TYPE_CHECKING:
from torba import baseledger
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class InvalidHeader(Exception):
    """Raised when a header fails validation while connecting a chunk.

    Carries the height at which validation failed so callers can trim the
    chunk and keep the valid headers that preceded the bad one.
    """

    def __init__(self, height, message):
        super().__init__(message)
        self.height = height
        self.message = message
class BaseHeaders: class BaseHeaders:
header_size = 80 header_size: int
verify_bits_to_target = True chunk_size: int
def __init__(self, ledger: 'baseledger.BaseLedger') -> None: max_target: int
self.ledger = ledger genesis_hash: bytes
target_timespan: int
validate_difficulty: bool = True
def __init__(self, path) -> None:
if path == ':memory:':
self.io = BytesIO()
self.path = path
self._size = None self._size = None
self._on_change_controller = StreamController() self._on_change_controller = StreamController()
self.on_changed = self._on_change_controller.stream self.on_changed = self._on_change_controller.stream
self._header_connect_lock = defer.DeferredLock()
@property def open(self):
def path(self): if self.path != ':memory:':
return os.path.join(self.ledger.path, 'headers') self.io = open(self.path, 'a+b')
return defer.succeed(True)
def touch(self): def close(self):
if not os.path.exists(self.path): self.io.close()
with open(self.path, 'wb'): return defer.succeed(True)
pass
@property @staticmethod
def height(self): def serialize(header: dict) -> bytes:
return len(self)-1 raise NotImplementedError
def hash(self, height=None): @staticmethod
if height is None: def deserialize(height, header):
height = self.height raise NotImplementedError
header = self[height]
return self._hash_header(header)
def sync_read_length(self): def get_next_chunk_target(self, chunk: int) -> ArithUint256:
return os.path.getsize(self.path) // self.header_size return ArithUint256(self.max_target)
def sync_read_header(self, height): def get_next_block_target(self, chunk_target: ArithUint256, previous: Optional[dict],
if 0 <= height < len(self): current: Optional[dict]) -> ArithUint256:
with open(self.path, 'rb') as f: return chunk_target
f.seek(height * self.header_size)
return f.read(self.header_size)
def __len__(self): def __len__(self) -> int:
if self._size is None: if self._size is None:
self._size = self.sync_read_length() self._size = self.io.seek(0, os.SEEK_END) // self.header_size
return self._size return self._size
def __getitem__(self, height): def __bool__(self):
return True
def __getitem__(self, height) -> dict:
assert not isinstance(height, slice), \ assert not isinstance(height, slice), \
"Slicing of header chain has not been implemented yet." "Slicing of header chain has not been implemented yet."
header = self.sync_read_header(height) return self.deserialize(height, self.get_raw_header(height))
return self._deserialize(height, header)
def get_raw_header(self, height) -> bytes:
    """Read the raw serialized header stored at *height* from the backing stream."""
    offset = height * self.header_size
    self.io.seek(offset, os.SEEK_SET)
    return self.io.read(self.header_size)
@property
def height(self) -> int:
    # Height of the best stored header; -1 when the store is empty (len == 0).
    return len(self)-1
def hash(self, height=None) -> bytes:
    """Return the hex block hash of the header at *height*.

    Defaults to the current best height when *height* is None.  Uses an
    explicit None check so that height 0 (genesis) is honored; the previous
    ``height or self.height`` treated 0 as falsy and wrongly returned the
    tip's hash for the genesis block.
    """
    if height is None:
        height = self.height
    return self.hash_header(self.get_raw_header(height))
@staticmethod
def hash_header(header: bytes) -> bytes:
    """Hex-encoded display hash (byte-reversed double SHA-256) of a raw header.

    Returns 64 ASCII zeros when *header* is None (the "no header" sentinel).
    """
    if header is None:
        return b'0' * 64
    return hexlify(double_sha256(header)[::-1])
@defer.inlineCallbacks @defer.inlineCallbacks
def connect(self, start, headers): def connect(self, start: int, headers: bytes):
yield threads.deferToThread(self._sync_connect, start, headers) added = 0
bail = False
yield self._header_connect_lock.acquire()
try:
for height, chunk in self._iterate_chunks(start, headers):
try:
# validate_chunk() is CPU bound on large chunks
yield threads.deferToThread(self.validate_chunk, height, chunk)
except InvalidHeader as e:
bail = True
chunk = chunk[:(height-e.height)*self.header_size]
written = 0
if chunk:
self.io.seek(height * self.header_size, os.SEEK_SET)
written = self.io.write(chunk) // self.header_size
self.io.truncate()
# .seek()/.write()/.truncate() might also .flush() when needed
# the goal here is mainly to ensure we're definitely flush()'ing
yield threads.deferToThread(self.io.flush)
self._size = None
self._on_change_controller.add(written)
added += written
if bail:
break
finally:
self._header_connect_lock.release()
defer.returnValue(added)
def _sync_connect(self, start, headers): def validate_chunk(self, height, chunk):
previous_header = None previous_hash, previous_header, previous_previous_header = None, None, None
for header in self._iterate_headers(start, headers): if height > 0:
height = header['block_height']
if previous_header is None and height > 0:
previous_header = self[height-1] previous_header = self[height-1]
self._verify_header(height, header, previous_header) previous_hash = self.hash(height-1)
previous_header = header if height > 1:
previous_previous_header = self[height-2]
chunk_target = self.get_next_chunk_target(height // 2016 - 1)
for current_hash, current_header in self._iterate_headers(height, chunk):
block_target = self.get_next_block_target(chunk_target, previous_previous_header, previous_header)
self.validate_header(height, current_hash, current_header, previous_hash, block_target)
previous_previous_header = previous_header
previous_header = current_header
previous_hash = current_hash
with open(self.path, 'r+b') as f: def validate_header(self, height: int, current_hash: bytes,
f.seek(start * self.header_size) header: dict, previous_hash: bytes, target: ArithUint256):
f.write(headers)
f.truncate()
_old_size = self._size if previous_hash is None:
self._size = self.sync_read_length() if self.genesis_hash is not None and self.genesis_hash != current_hash:
change = self._size - _old_size raise InvalidHeader(
log.info( height, "genesis header doesn't match: {} vs expected {}".format(
'%s: added %s header blocks, final height %s', current_hash.decode(), self.genesis_hash.decode())
self.ledger.get_id(), change, self.height
) )
self._on_change_controller.add(change) return
def _iterate_headers(self, height, headers): if header['prev_block_hash'] != previous_hash:
raise InvalidHeader(
height, "previous hash mismatch: {} vs expected {}".format(
header['prev_block_hash'].decode(), previous_hash.decode())
)
if self.validate_difficulty:
if header['bits'] != target.compact:
raise InvalidHeader(
height, "bits mismatch: {} vs expected {}".format(
header['bits'], target.compact)
)
proof_of_work = self.get_proof_of_work(current_hash)
if proof_of_work > target:
raise InvalidHeader(
height, "insufficient proof of work: {} vs target {}".format(
proof_of_work.value, target.value)
)
@staticmethod
def get_proof_of_work(header_hash: bytes) -> ArithUint256:
    # header_hash is ASCII hex; int() with base 16 accepts the b'0x' prefix.
    return ArithUint256(int(b'0x' + header_hash, 16))
def _iterate_chunks(self, height: int, headers: bytes) -> Iterator[Tuple[int, bytes]]:
assert len(headers) % self.header_size == 0
start = 0
end = (self.chunk_size - height % self.chunk_size) * self.header_size
while start < end:
yield height + (start // self.header_size), headers[start:end]
start = end
end = min(len(headers), end + self.chunk_size * self.header_size)
def _iterate_headers(self, height: int, headers: bytes) -> Iterator[Tuple[bytes, dict]]:
assert len(headers) % self.header_size == 0 assert len(headers) % self.header_size == 0
for idx in range(len(headers) // self.header_size): for idx in range(len(headers) // self.header_size):
start, end = idx * self.header_size, (idx + 1) * self.header_size start, end = idx * self.header_size, (idx + 1) * self.header_size
header = headers[start:end] header = headers[start:end]
yield self._deserialize(height+idx, header) yield self.hash_header(header), self.deserialize(height+idx, header)
def _verify_header(self, height, header, previous_header):
previous_hash = self._hash_header(previous_header)
assert previous_hash == header['prev_block_hash'], \
"prev hash mismatch: {} vs {}".format(previous_hash, header['prev_block_hash'])
bits, _ = self._calculate_next_work_required(height, previous_header, header)
assert bits == header['bits'], \
"bits mismatch: {} vs {} (hash: {})".format(
bits, header['bits'], self._hash_header(header))
# TODO: FIX ME!!!
#_pow_hash = self._pow_hash_header(header)
#assert int(b'0x' + _pow_hash, 16) <= target, \
# "insufficient proof of work: {} vs target {}".format(
# int(b'0x' + _pow_hash, 16), target)
@staticmethod
def _serialize(header):
return b''.join([
int_to_hex(header['version'], 4),
rev_hex(header['prev_block_hash']),
rev_hex(header['merkle_root']),
int_to_hex(int(header['timestamp']), 4),
int_to_hex(int(header['bits']), 4),
int_to_hex(int(header['nonce']), 4)
])
@staticmethod
def _deserialize(height, header):
version, = struct.unpack('<I', header[:4])
timestamp, bits, nonce = struct.unpack('<III', header[68:80])
return {
'block_height': height,
'version': version,
'prev_block_hash': hash_encode(header[4:36]),
'merkle_root': hash_encode(header[36:68]),
'timestamp': timestamp,
'bits': bits,
'nonce': nonce,
}
def _hash_header(self, header):
if header is None:
return b'0' * 64
return hash_encode(double_sha256(unhexlify(self._serialize(header))))
def _pow_hash_header(self, header):
if header is None:
return b'0' * 64
return hash_encode(pow_hash(unhexlify(self._serialize(header))))
def _calculate_next_work_required(self, height, first, last):
if height == 0:
return self.ledger.genesis_bits, self.ledger.max_target
if self.verify_bits_to_target:
bits = last['bits']
bits_n = (bits >> 24) & 0xff
assert 0x03 <= bits_n <= 0x1d, \
"First part of bits should be in [0x03, 0x1d], but it was {}".format(hex(bits_n))
bits_base = bits & 0xffffff
assert 0x8000 <= bits_base <= 0x7fffff, \
"Second part of bits should be in [0x8000, 0x7fffff] but it was {}".format(bits_base)
# new target
retarget_timespan = self.ledger.target_timespan
n_actual_timespan = last['timestamp'] - first['timestamp']
n_modulated_timespan = retarget_timespan + (n_actual_timespan - retarget_timespan) // 8
n_min_timespan = retarget_timespan - (retarget_timespan // 8)
n_max_timespan = retarget_timespan + (retarget_timespan // 2)
# Limit adjustment step
if n_modulated_timespan < n_min_timespan:
n_modulated_timespan = n_min_timespan
elif n_modulated_timespan > n_max_timespan:
n_modulated_timespan = n_max_timespan
# Retarget
bn_pow_limit = _ArithUint256(self.ledger.max_target)
bn_new = _ArithUint256.set_compact(last['bits'])
bn_new *= n_modulated_timespan
bn_new //= n_modulated_timespan
if bn_new > bn_pow_limit:
bn_new = bn_pow_limit
return bn_new.get_compact(), bn_new._value
class _ArithUint256:
""" See: lbrycrd/src/arith_uint256.cpp """
def __init__(self, value):
self._value = value
def __str__(self):
return hex(self._value)
@staticmethod
def from_compact(n_compact):
"""Convert a compact representation into its value"""
n_size = n_compact >> 24
# the lower 23 bits
n_word = n_compact & 0x007fffff
if n_size <= 3:
return n_word >> 8 * (3 - n_size)
else:
return n_word << 8 * (n_size - 3)
@classmethod
def set_compact(cls, n_compact):
return cls(cls.from_compact(n_compact))
def bits(self):
"""Returns the position of the highest bit set plus one."""
bits = bin(self._value)[2:]
for i, d in enumerate(bits):
if d:
return (len(bits) - i) + 1
return 0
def get_low64(self):
return self._value & 0xffffffffffffffff
def get_compact(self):
"""Convert a value into its compact representation"""
n_size = (self.bits() + 7) // 8
if n_size <= 3:
n_compact = self.get_low64() << 8 * (3 - n_size)
else:
n = _ArithUint256(self._value >> 8 * (n_size - 3))
n_compact = n.get_low64()
# The 0x00800000 bit denotes the sign.
# Thus, if it is already set, divide the mantissa by 256 and increase the exponent.
if n_compact & 0x00800000:
n_compact >>= 8
n_size += 1
assert (n_compact & ~0x007fffff) == 0
assert n_size < 256
n_compact |= n_size << 24
return n_compact
def __mul__(self, x):
# Take the mod because we are limited to an unsigned 256 bit number
return _ArithUint256((self._value * x) % 2 ** 256)
def __ifloordiv__(self, x):
self._value = (self._value // x)
return self
def __gt__(self, x):
return self._value > x._value

View file

@ -8,10 +8,10 @@ from collections import namedtuple
from twisted.internet import defer from twisted.internet import defer
from torba import baseaccount from torba import baseaccount
from torba import basedatabase
from torba import baseheader
from torba import basenetwork from torba import basenetwork
from torba import basetransaction from torba import basetransaction
from torba.basedatabase import BaseDatabase
from torba.baseheader import BaseHeaders, InvalidHeader
from torba.coinselection import CoinSelector from torba.coinselection import CoinSelector
from torba.constants import COIN, NULL_HASH32 from torba.constants import COIN, NULL_HASH32
from torba.stream import StreamController from torba.stream import StreamController
@ -50,13 +50,13 @@ class BaseLedger(metaclass=LedgerRegistry):
symbol: str symbol: str
network_name: str network_name: str
database_class = BaseDatabase
account_class = baseaccount.BaseAccount account_class = baseaccount.BaseAccount
database_class = basedatabase.BaseDatabase
headers_class = baseheader.BaseHeaders
network_class = basenetwork.BaseNetwork network_class = basenetwork.BaseNetwork
transaction_class = basetransaction.BaseTransaction transaction_class = basetransaction.BaseTransaction
secret_prefix = None headers_class: Type[BaseHeaders]
pubkey_address_prefix: bytes pubkey_address_prefix: bytes
script_address_prefix: bytes script_address_prefix: bytes
extended_public_key_prefix: bytes extended_public_key_prefix: bytes
@ -66,14 +66,16 @@ class BaseLedger(metaclass=LedgerRegistry):
def __init__(self, config=None): def __init__(self, config=None):
self.config = config or {} self.config = config or {}
self.db = self.config.get('db') or self.database_class( self.db: BaseDatabase = self.config.get('db') or self.database_class(
os.path.join(self.path, "blockchain.db") os.path.join(self.path, "blockchain.db")
) # type: basedatabase.BaseDatabase )
self.headers: BaseHeaders = self.config.get('headers') or self.headers_class(
os.path.join(self.path, "headers")
)
self.network = self.config.get('network') or self.network_class(self) self.network = self.config.get('network') or self.network_class(self)
self.network.on_header.listen(self.process_header) self.network.on_header.listen(self.receive_header)
self.network.on_status.listen(self.process_status) self.network.on_status.listen(self.receive_status)
self.accounts = [] self.accounts = []
self.headers = self.config.get('headers') or self.headers_class(self)
self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte) self.fee_per_byte: int = self.config.get('fee_per_byte', self.default_fee_per_byte)
self._on_transaction_controller = StreamController() self._on_transaction_controller = StreamController()
@ -87,6 +89,12 @@ class BaseLedger(metaclass=LedgerRegistry):
self._on_header_controller = StreamController() self._on_header_controller = StreamController()
self.on_header = self._on_header_controller.stream self.on_header = self._on_header_controller.stream
self.on_header.listen(
lambda change: log.info(
'%s: added %s header blocks, final height %s',
self.get_id(), change, self.headers.height
)
)
self._transaction_processing_locks = {} self._transaction_processing_locks = {}
self._utxo_reservation_lock = defer.DeferredLock() self._utxo_reservation_lock = defer.DeferredLock()
@ -209,11 +217,13 @@ class BaseLedger(metaclass=LedgerRegistry):
def start(self): def start(self):
if not os.path.exists(self.path): if not os.path.exists(self.path):
os.mkdir(self.path) os.mkdir(self.path)
yield self.db.start() yield defer.gatherResults([
self.db.open(),
self.headers.open()
])
first_connection = self.network.on_connected.first first_connection = self.network.on_connected.first
self.network.start() self.network.start()
yield first_connection yield first_connection
self.headers.touch()
yield self.update_headers() yield self.update_headers()
yield self.network.subscribe_headers() yield self.network.subscribe_headers()
yield self.update_accounts() yield self.update_accounts()
@ -221,30 +231,69 @@ class BaseLedger(metaclass=LedgerRegistry):
@defer.inlineCallbacks @defer.inlineCallbacks
def stop(self): def stop(self):
yield self.network.stop() yield self.network.stop()
yield self.db.stop() yield self.db.close()
yield self.headers.close()
@defer.inlineCallbacks @defer.inlineCallbacks
def update_headers(self): def update_headers(self, height=None, headers=None, count=1, subscription_update=False):
rewound = 0
while True: while True:
height_sought = len(self.headers)
headers = yield self.network.get_headers(height_sought, 2000) height = len(self.headers) if height is None else height
if headers['count'] <= 0: if headers is None:
break header_response = yield self.network.get_headers(height, 2001)
yield self.headers.connect(height_sought, unhexlify(headers['hex'])) count = header_response['count']
self._on_header_controller.add(self.headers.height) headers = header_response['hex']
if count <= 0:
return
added = yield self.headers.connect(height, unhexlify(headers))
if added > 0:
self._on_header_controller.add(added)
if subscription_update and added == count:
# subscription updates are for latest header already
# so we don't need to check if there are newer / more
return
if added == 0:
# headers were invalid, start rewinding
height -= 1
rewound += 1
log.warning("Experiencing Blockchain Reorganization: Undoing header.")
else:
# added all headers, see if there are more
height += added
if height < 0:
raise IndexError(
"Blockchain reorganization rewound all the way back to genesis hash. "
"Something is very wrong. Maybe you are on the wrong blockchain?"
)
if rewound >= 50:
raise IndexError(
"Blockchain reorganization dropped {} headers. This is highly unusual. "
"Will not continue to attempt reorganizing."
.format(rewound)
)
headers = None
# if we made it this far and this was a subscription_update
# it means something was wrong and now we're doing a more
# robust sync, turn off subscription update shortcut
subscription_update = False
@defer.inlineCallbacks @defer.inlineCallbacks
def process_header(self, response): def receive_header(self, response):
yield self._header_processing_lock.acquire() yield self._header_processing_lock.acquire()
try: try:
header = response[0] header = response[0]
if header['height'] == len(self.headers): yield self.update_headers(
# New header from network directly connects after the last local header. height=header['height'], headers=header['hex'], subscription_update=True
yield self.headers.connect(len(self.headers), unhexlify(header['hex'])) )
self._on_header_controller.add(self.headers.height)
elif header['height'] > len(self.headers):
# New header is several heights ahead of local, do download instead.
yield self.update_headers()
finally: finally:
self._header_processing_lock.release() self._header_processing_lock.release()
@ -338,7 +387,7 @@ class BaseLedger(metaclass=LedgerRegistry):
yield self.update_history(address) yield self.update_history(address)
@defer.inlineCallbacks @defer.inlineCallbacks
def process_status(self, response): def receive_status(self, response):
address, remote_status = response address, remote_status = response
local_status = yield self.get_local_status(address) local_status = yield self.get_local_status(address)
if local_status != remote_status: if local_status != remote_status:

View file

@ -6,15 +6,60 @@ __node_url__ = (
) )
__electrumx__ = 'electrumx.lib.coins.BitcoinSegwitRegtest' __electrumx__ = 'electrumx.lib.coins.BitcoinSegwitRegtest'
from binascii import unhexlify import struct
from binascii import hexlify, unhexlify
from torba.baseledger import BaseLedger from torba.baseledger import BaseLedger
from torba.baseheader import BaseHeaders from torba.baseheader import BaseHeaders, ArithUint256
class MainHeaders(BaseHeaders):
    """ Bitcoin (SegWit) main chain header store: 80 byte headers with the
        standard two-week / 2016-block difficulty retargeting rule. """

    header_size = 80        # bytes per serialized block header
    chunk_size = 2016       # headers per difficulty retarget period
    max_target = 0x00000000ffffffffffffffffffffffffffffffffffffffffffffffffffffffff
    genesis_hash = b'000000000019d6689c085ae165831e934ff763ae46a2a6c172b3f1b60a8ce26f'
    target_timespan = 14 * 24 * 60 * 60  # two weeks, in seconds

    @staticmethod
    def serialize(header: dict) -> bytes:
        """ Pack a header dict back into the 80 byte wire format
            (hash fields are hex strings in display byte order). """
        return b''.join([
            struct.pack('<I', header['version']),
            unhexlify(header['prev_block_hash'])[::-1],
            unhexlify(header['merkle_root'])[::-1],
            struct.pack('<III', header['timestamp'], header['bits'], header['nonce'])
        ])

    @staticmethod
    def deserialize(height, header):
        """ Unpack an 80 byte header into a dict; hashes are returned as
            hex bytes in display (reversed) order. """
        version, = struct.unpack('<I', header[:4])
        timestamp, bits, nonce = struct.unpack('<III', header[68:80])
        return {
            'block_height': height,
            'version': version,
            'prev_block_hash': hexlify(header[4:36][::-1]),
            'merkle_root': hexlify(header[36:68][::-1]),
            'timestamp': timestamp,
            'bits': bits,
            'nonce': nonce
        }

    def get_next_chunk_target(self, chunk: int) -> ArithUint256:
        """ Difficulty target for the chunk following `chunk`, per bitcoin's
            retargeting rule (timespan clamped to [1/4, 4x] of the target). """
        if chunk == -1:
            return ArithUint256(self.max_target)
        # use chunk_size instead of hard-coding 2016 so subclasses that
        # override chunk_size retarget consistently
        previous = self[chunk * self.chunk_size]
        current = self[chunk * self.chunk_size + (self.chunk_size - 1)]
        actual_timespan = current['timestamp'] - previous['timestamp']
        actual_timespan = max(actual_timespan, int(self.target_timespan / 4))
        actual_timespan = min(actual_timespan, self.target_timespan * 4)
        target = ArithUint256.from_compact(current['bits'])
        new_target = min(ArithUint256(self.max_target), (target * actual_timespan) / self.target_timespan)
        return new_target
class MainNetLedger(BaseLedger): class MainNetLedger(BaseLedger):
name = 'BitcoinSegwit' name = 'BitcoinSegwit'
symbol = 'BTC' symbol = 'BTC'
network_name = 'mainnet' network_name = 'mainnet'
headers_class = MainHeaders
pubkey_address_prefix = bytes((0,)) pubkey_address_prefix = bytes((0,))
script_address_prefix = bytes((5,)) script_address_prefix = bytes((5,))
@ -24,20 +69,17 @@ class MainNetLedger(BaseLedger):
default_fee_per_byte = 50 default_fee_per_byte = 50
class UnverifiedHeaders(MainHeaders):
    """ Header store for regtest: proof-of-work is trivially easy and
        difficulty is not validated, so any chain can be connected. """
    max_target = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
    genesis_hash = None
    validate_difficulty = False
class RegTestLedger(MainNetLedger):
    """ Ledger for the bitcoin regtest network; chain-specific constants
        (genesis, targets) live on UnverifiedHeaders. """
    network_name = 'regtest'
    headers_class = UnverifiedHeaders
    pubkey_address_prefix = bytes((111,))
    script_address_prefix = bytes((196,))
    extended_public_key_prefix = unhexlify('043587cf')
    extended_private_key_prefix = unhexlify('04358394')

View file

@ -79,14 +79,6 @@ def ripemd160(x):
return h.digest() return h.digest()
def pow_hash(x):
h = sha512(double_sha256(x))
return double_sha256(
ripemd160(h[:len(h) // 2]) +
ripemd160(h[len(h) // 2:])
)
def double_sha256(x):
    """ SHA-256 of SHA-256, as used extensively in bitcoin. """
    return sha256(sha256(x))

View file

@ -45,19 +45,78 @@ def int_to_bytes(value):
return unhexlify(('0' * (len(s) % 2) + s).zfill(length * 2)) return unhexlify(('0' * (len(s) % 2) + s).zfill(length * 2))
class ArithUint256:
    """ Unsigned 256-bit integer with bitcoin's "compact" (nBits) encoding.

    Ported from:
    https://github.com/bitcoin/bitcoin/blob/master/src/arith_uint256.cpp
    """

    __slots__ = '_value', '_compact'

    def __init__(self, value: int) -> None:
        self._value = value
        self._compact = None  # lazily computed by the `compact` property

    @classmethod
    def from_compact(cls, compact) -> 'ArithUint256':
        """ Decode bitcoin's compact representation (arith_uint256::SetCompact). """
        size = compact >> 24
        word = compact & 0x007fffff
        if size <= 3:
            return cls(word >> 8 * (3 - size))
        else:
            return cls(word << 8 * (size - 3))

    @property
    def value(self) -> int:
        return self._value

    @property
    def compact(self) -> int:
        if self._compact is None:
            self._compact = self._calculate_compact()
        return self._compact

    @property
    def negative(self) -> int:
        return self._calculate_compact(negative=True)

    @property
    def bits(self) -> int:
        """ Returns the position of the highest bit set plus one, or zero. """
        # fix: the original looped over bin(value) with `if d:`, which is
        # always true for the characters '0'/'1' and so returned
        # bit_length + 1 (and 2 for zero). bit_length() matches the
        # documented semantics; for positive values the compact encoding
        # below is unchanged (the 0x00800000 sign adjustment absorbs the
        # byte-boundary difference) and zero now encodes to compact 0,
        # matching bitcoin.
        return self._value.bit_length()

    @property
    def low64(self) -> int:
        return self._value & 0xffffffffffffffff

    def _calculate_compact(self, negative=False) -> int:
        """ Encode to bitcoin's compact representation (arith_uint256::GetCompact). """
        size = (self.bits + 7) // 8
        if size <= 3:
            compact = self.low64 << 8 * (3 - size)
        else:
            compact = ArithUint256(self._value >> 8 * (size - 3)).low64
        # The 0x00800000 bit denotes the sign.
        # Thus, if it is already set, divide the mantissa by 256 and increase the exponent.
        if compact & 0x00800000:
            compact >>= 8
            size += 1
        assert (compact & ~0x007fffff) == 0
        assert size < 256
        compact |= size << 24
        if negative and compact & 0x007fffff:
            compact |= 0x00800000
        return compact

    def __mul__(self, x):
        # Take the mod because we are limited to an unsigned 256 bit number
        return ArithUint256((self._value * x) % 2 ** 256)

    def __truediv__(self, x):
        # fix: integer floor division; the original int(self._value / x)
        # went through a float and lost precision beyond 53 bits, which
        # matters for 256-bit difficulty targets
        return ArithUint256(self._value // x)

    def __gt__(self, other):
        # `other` may be an int or another ArithUint256; comparing an int
        # against an ArithUint256 falls through to the reflected operator
        return self._value > other

    def __lt__(self, other):
        return self._value < other

View file

@ -19,4 +19,5 @@ setenv =
commands =
    unit: coverage run -p --source={envsitepackagesdir}/torba -m twisted.trial unit
    integration: orchstr8 download
    integration: coverage run -p --source={envsitepackagesdir}/torba -m twisted.trial --reactor=asyncio integration.test_transactions
    integration: coverage run -p --source={envsitepackagesdir}/torba -m twisted.trial --reactor=asyncio integration.test_blockchain_reorganization