async def test_checkpointed_writer(self):
    """Headers fed through ``checkpointed_connector`` are all stored on exit.

    Exercises three runs: a batch that stops short of the checkpoint, a
    batch that lands exactly on it, and a fresh instance on which no
    checkpoint has been assigned.
    """

    def blocks_between(first, last):
        # Raw header bytes for the half-open height range [first, last).
        return self.get_bytes(block_bytes(last - first), block_bytes(first))

    checkpoint_height = 100
    checkpoint_digest = hexlify(sha256(self.get_bytes(block_bytes(checkpoint_height))))

    store = MainHeaders(':memory:')
    store.checkpoint = (checkpoint_height, checkpoint_digest)

    # Batch ends below the checkpoint height.
    async with store.checkpointed_connector() as sink:
        sink.connect(0, blocks_between(0, 10))
    self.assertEqual(len(store), 10)

    # Batch brings the store exactly up to the checkpoint height.
    async with store.checkpointed_connector() as sink:
        sink.connect(10, blocks_between(10, 100))
    self.assertEqual(len(store), 100)

    # New instance: the checkpoint above was set on the previous instance only.
    store = MainHeaders(':memory:')
    async with store.checkpointed_connector() as sink:
        sink.connect(0, blocks_between(0, 300))
    self.assertEqual(len(store), 300)
@asynccontextmanager
async def checkpointed_connector(self):
    """Collect raw headers in memory and store them when the block exits.

    Yields a ``BytesIO`` that has been given a ``connect(_, headers)``
    method; it ignores its first argument and writes ``headers`` into the
    buffer. Nothing is stored until the ``async with`` body finishes. The
    flush lives in ``finally``, so it also runs when the body raises.

    On exit:

    * If ``self.checkpoint`` is set, is still ahead of the stored headers,
      and the buffered headers reach it, the SHA-256 of the stored header
      file followed by the buffered bytes up to the checkpoint height is
      compared with ``self.checkpoint[1]`` (hex digest as ``bytes``).
      On a match those bytes are written directly to ``self.io`` and only
      the bytes after the checkpoint stay in the buffer. On a mismatch a
      warning is logged and the buffer is left whole.
    * Whatever is left in the buffer is passed to
      ``self.connect(len(self), ...)``.
    """
    buf = BytesIO()
    buf.connect = lambda _, headers: buf.write(headers)
    try:
        yield buf
    finally:
        # Give other tasks one turn on the event loop before the flush.
        await asyncio.sleep(0)
        checkpoint = self.checkpoint
        final_height = len(self) + buf.tell() // self.header_size
        verifiable_bytes = (checkpoint[0] - len(self)) * self.header_size if checkpoint else 0
        # Must be strictly positive: once the stored headers are past the
        # checkpoint this value is negative, and a plain truthiness test would
        # re-hash the whole file, make ``buf.read()`` return everything, and
        # log a bogus mismatch warning on every call.
        if verifiable_bytes > 0 and final_height >= checkpoint[0]:
            buf.seek(0)
            self.io.seek(0)
            digest = hashlib.sha256()
            digest.update(self.io.read())
            digest.update(buf.read(verifiable_bytes))
            if digest.hexdigest().encode() == checkpoint[1]:
                # Verified: append the bytes up to the checkpoint directly.
                buf.seek(0)
                self.io.seek(self.bytes_size, os.SEEK_SET)
                self.io.write(buf.read(verifiable_bytes))
                self.io.flush()
                self._size = None  # presumably a cached length; reset after the direct write
                # Keep only the bytes beyond the checkpoint in the buffer.
                remaining = buf.read()
                buf.seek(0)
                buf.write(remaining)
                buf.truncate()
            else:
                log.warning("Checkpoint mismatch, connecting headers through slow method.")
        if buf.tell() > 0:
            await self.connect(len(self), buf.getvalue())
BaseLedger(metaclass=LedgerRegistry): current = len(self.headers) get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=2000) chunks = [asyncio.ensure_future(get_chunk(height)) for height in range(current, target, 2000)] - for chunk in chunks: - headers = await chunk - if not headers: - continue - headers = headers['hex'] - await self.update_headers(height=len(self.headers), headers=headers, subscription_update=True) + async with self.headers.checkpointed_connector() as connector: + for chunk in chunks: + headers = await chunk + connector.connect(len(self.headers), unhexlify(headers['hex'])) + log.info("Headers sync: %s / %s", connector.tell() // self.headers.header_size, target) async def update_headers(self, height=None, headers=None, subscription_update=False): rewound = 0