fix .connect injected method
This commit is contained in:
parent
ea1e24d8f9
commit
6cbc545f84
3 changed files with 10 additions and 12 deletions
|
@ -130,15 +130,15 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
|
||||||
headers = MainHeaders(':memory:')
|
headers = MainHeaders(':memory:')
|
||||||
headers.checkpoint = 100, hexlify(sha256(self.get_bytes(block_bytes(100))))
|
headers.checkpoint = 100, hexlify(sha256(self.get_bytes(block_bytes(100))))
|
||||||
genblocks = lambda start, end: self.get_bytes(block_bytes(end - start), block_bytes(start))
|
genblocks = lambda start, end: self.get_bytes(block_bytes(end - start), block_bytes(start))
|
||||||
async with headers.checkpointed_connector() as connector:
|
async with headers.checkpointed_connector() as buff:
|
||||||
connector.connect(0, genblocks(0, 10))
|
buff.write(genblocks(0, 10))
|
||||||
self.assertEqual(len(headers), 10)
|
self.assertEqual(len(headers), 10)
|
||||||
async with headers.checkpointed_connector() as connector:
|
async with headers.checkpointed_connector() as buff:
|
||||||
connector.connect(10, genblocks(10, 100))
|
buff.write(genblocks(10, 100))
|
||||||
self.assertEqual(len(headers), 100)
|
self.assertEqual(len(headers), 100)
|
||||||
headers = MainHeaders(':memory:')
|
headers = MainHeaders(':memory:')
|
||||||
async with headers.checkpointed_connector() as connector:
|
async with headers.checkpointed_connector() as buff:
|
||||||
connector.connect(0, genblocks(0, 300))
|
buff.write(genblocks(0, 300))
|
||||||
self.assertEqual(len(headers), 300)
|
self.assertEqual(len(headers), 300)
|
||||||
|
|
||||||
async def test_concurrency(self):
|
async def test_concurrency(self):
|
||||||
|
|
|
@ -106,7 +106,6 @@ class BaseHeaders:
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def checkpointed_connector(self):
|
async def checkpointed_connector(self):
|
||||||
buf = BytesIO()
|
buf = BytesIO()
|
||||||
buf.connect = lambda _, headers: buf.write(headers)
|
|
||||||
try:
|
try:
|
||||||
yield buf
|
yield buf
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -313,20 +313,19 @@ class BaseLedger(metaclass=LedgerRegistry):
|
||||||
return max(self.headers.height, self._download_height)
|
return max(self.headers.height, self._download_height)
|
||||||
|
|
||||||
async def initial_headers_sync(self):
|
async def initial_headers_sync(self):
|
||||||
target = self.network.remote_height
|
target = self.network.remote_height + 1
|
||||||
current = len(self.headers)
|
current = len(self.headers)
|
||||||
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True)
|
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True)
|
||||||
chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)]
|
chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)]
|
||||||
total = 0
|
total = 0
|
||||||
async with self.headers.checkpointed_connector() as connector:
|
async with self.headers.checkpointed_connector() as buffer:
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
headers = await chunk
|
headers = await chunk
|
||||||
total += len(headers['base64'])
|
total += len(headers['base64'])
|
||||||
connector.connect(
|
buffer.write(
|
||||||
len(self.headers),
|
|
||||||
zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
|
zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
|
||||||
)
|
)
|
||||||
self._download_height = len(self.headers) + connector.tell() // self.headers.header_size
|
self._download_height = current + buffer.tell() // self.headers.header_size
|
||||||
log.info("Headers sync: %s / %s", self._download_height, target)
|
log.info("Headers sync: %s / %s", self._download_height, target)
|
||||||
|
|
||||||
async def update_headers(self, height=None, headers=None, subscription_update=False):
|
async def update_headers(self, height=None, headers=None, subscription_update=False):
|
||||||
|
|
Loading…
Reference in a new issue