diff --git a/torba/tests/client_tests/unit/test_headers.py b/torba/tests/client_tests/unit/test_headers.py index c869a0556..5f43ca71f 100644 --- a/torba/tests/client_tests/unit/test_headers.py +++ b/torba/tests/client_tests/unit/test_headers.py @@ -130,15 +130,15 @@ class BasicHeadersTests(BitcoinHeadersTestCase): headers = MainHeaders(':memory:') headers.checkpoint = 100, hexlify(sha256(self.get_bytes(block_bytes(100)))) genblocks = lambda start, end: self.get_bytes(block_bytes(end - start), block_bytes(start)) - async with headers.checkpointed_connector() as connector: - connector.connect(0, genblocks(0, 10)) + async with headers.checkpointed_connector() as buff: + buff.write(genblocks(0, 10)) self.assertEqual(len(headers), 10) - async with headers.checkpointed_connector() as connector: - connector.connect(10, genblocks(10, 100)) + async with headers.checkpointed_connector() as buff: + buff.write(genblocks(10, 100)) self.assertEqual(len(headers), 100) headers = MainHeaders(':memory:') - async with headers.checkpointed_connector() as connector: - connector.connect(0, genblocks(0, 300)) + async with headers.checkpointed_connector() as buff: + buff.write(genblocks(0, 300)) self.assertEqual(len(headers), 300) async def test_concurrency(self): diff --git a/torba/torba/client/baseheader.py b/torba/torba/client/baseheader.py index ceeb97812..338c5ec8a 100644 --- a/torba/torba/client/baseheader.py +++ b/torba/torba/client/baseheader.py @@ -106,7 +106,6 @@ class BaseHeaders: @asynccontextmanager async def checkpointed_connector(self): buf = BytesIO() - buf.connect = lambda _, headers: buf.write(headers) try: yield buf finally: diff --git a/torba/torba/client/baseledger.py b/torba/torba/client/baseledger.py index 531c9f0a5..f19a35d58 100644 --- a/torba/torba/client/baseledger.py +++ b/torba/torba/client/baseledger.py @@ -313,20 +313,19 @@ class BaseLedger(metaclass=LedgerRegistry): return max(self.headers.height, self._download_height) async def initial_headers_sync(self): - target = self.network.remote_height + target = self.network.remote_height + 1 current = len(self.headers) get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True) chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)] total = 0 - async with self.headers.checkpointed_connector() as connector: + async with self.headers.checkpointed_connector() as buffer: for chunk in chunks: headers = await chunk total += len(headers['base64']) - connector.connect( - len(self.headers), + buffer.write( zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000) ) - self._download_height = len(self.headers) + connector.tell() // self.headers.header_size + self._download_height = current + buffer.tell() // self.headers.header_size log.info("Headers sync: %s / %s", self._download_height, target) async def update_headers(self, height=None, headers=None, subscription_update=False):