forked from LBRYCommunity/lbry-sdk
fix .connect injected method
This commit is contained in:
parent
ea1e24d8f9
commit
6cbc545f84
3 changed files with 10 additions and 12 deletions
|
@ -130,15 +130,15 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
|
|||
headers = MainHeaders(':memory:')
|
||||
headers.checkpoint = 100, hexlify(sha256(self.get_bytes(block_bytes(100))))
|
||||
genblocks = lambda start, end: self.get_bytes(block_bytes(end - start), block_bytes(start))
|
||||
async with headers.checkpointed_connector() as connector:
|
||||
connector.connect(0, genblocks(0, 10))
|
||||
async with headers.checkpointed_connector() as buff:
|
||||
buff.write(genblocks(0, 10))
|
||||
self.assertEqual(len(headers), 10)
|
||||
async with headers.checkpointed_connector() as connector:
|
||||
connector.connect(10, genblocks(10, 100))
|
||||
async with headers.checkpointed_connector() as buff:
|
||||
buff.write(genblocks(10, 100))
|
||||
self.assertEqual(len(headers), 100)
|
||||
headers = MainHeaders(':memory:')
|
||||
async with headers.checkpointed_connector() as connector:
|
||||
connector.connect(0, genblocks(0, 300))
|
||||
async with headers.checkpointed_connector() as buff:
|
||||
buff.write(genblocks(0, 300))
|
||||
self.assertEqual(len(headers), 300)
|
||||
|
||||
async def test_concurrency(self):
|
||||
|
|
|
@ -106,7 +106,6 @@ class BaseHeaders:
|
|||
@asynccontextmanager
|
||||
async def checkpointed_connector(self):
|
||||
buf = BytesIO()
|
||||
buf.connect = lambda _, headers: buf.write(headers)
|
||||
try:
|
||||
yield buf
|
||||
finally:
|
||||
|
|
|
@ -313,20 +313,19 @@ class BaseLedger(metaclass=LedgerRegistry):
|
|||
return max(self.headers.height, self._download_height)
|
||||
|
||||
async def initial_headers_sync(self):
|
||||
target = self.network.remote_height
|
||||
target = self.network.remote_height + 1
|
||||
current = len(self.headers)
|
||||
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True)
|
||||
chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)]
|
||||
total = 0
|
||||
async with self.headers.checkpointed_connector() as connector:
|
||||
async with self.headers.checkpointed_connector() as buffer:
|
||||
for chunk in chunks:
|
||||
headers = await chunk
|
||||
total += len(headers['base64'])
|
||||
connector.connect(
|
||||
len(self.headers),
|
||||
buffer.write(
|
||||
zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
|
||||
)
|
||||
self._download_height = len(self.headers) + connector.tell() // self.headers.header_size
|
||||
self._download_height = current + buffer.tell() // self.headers.header_size
|
||||
log.info("Headers sync: %s / %s", self._download_height, target)
|
||||
|
||||
async def update_headers(self, height=None, headers=None, subscription_update=False):
|
||||
|
|
Loading…
Reference in a new issue