fix .connect injected method

This commit is contained in:
Victor Shyba 2019-11-14 14:07:52 -03:00 committed by Lex Berezhny
parent ea1e24d8f9
commit 6cbc545f84
3 changed files with 10 additions and 12 deletions

View file

@ -130,15 +130,15 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
headers = MainHeaders(':memory:') headers = MainHeaders(':memory:')
headers.checkpoint = 100, hexlify(sha256(self.get_bytes(block_bytes(100)))) headers.checkpoint = 100, hexlify(sha256(self.get_bytes(block_bytes(100))))
genblocks = lambda start, end: self.get_bytes(block_bytes(end - start), block_bytes(start)) genblocks = lambda start, end: self.get_bytes(block_bytes(end - start), block_bytes(start))
async with headers.checkpointed_connector() as connector: async with headers.checkpointed_connector() as buff:
connector.connect(0, genblocks(0, 10)) buff.write(genblocks(0, 10))
self.assertEqual(len(headers), 10) self.assertEqual(len(headers), 10)
async with headers.checkpointed_connector() as connector: async with headers.checkpointed_connector() as buff:
connector.connect(10, genblocks(10, 100)) buff.write(genblocks(10, 100))
self.assertEqual(len(headers), 100) self.assertEqual(len(headers), 100)
headers = MainHeaders(':memory:') headers = MainHeaders(':memory:')
async with headers.checkpointed_connector() as connector: async with headers.checkpointed_connector() as buff:
connector.connect(0, genblocks(0, 300)) buff.write(genblocks(0, 300))
self.assertEqual(len(headers), 300) self.assertEqual(len(headers), 300)
async def test_concurrency(self): async def test_concurrency(self):

View file

@ -106,7 +106,6 @@ class BaseHeaders:
@asynccontextmanager @asynccontextmanager
async def checkpointed_connector(self): async def checkpointed_connector(self):
buf = BytesIO() buf = BytesIO()
buf.connect = lambda _, headers: buf.write(headers)
try: try:
yield buf yield buf
finally: finally:

View file

@ -313,20 +313,19 @@ class BaseLedger(metaclass=LedgerRegistry):
return max(self.headers.height, self._download_height) return max(self.headers.height, self._download_height)
async def initial_headers_sync(self): async def initial_headers_sync(self):
target = self.network.remote_height target = self.network.remote_height + 1
current = len(self.headers) current = len(self.headers)
get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True) get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=4096, b64=True)
chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)] chunks = [asyncio.create_task(get_chunk(height)) for height in range(current, target, 4096)]
total = 0 total = 0
async with self.headers.checkpointed_connector() as connector: async with self.headers.checkpointed_connector() as buffer:
for chunk in chunks: for chunk in chunks:
headers = await chunk headers = await chunk
total += len(headers['base64']) total += len(headers['base64'])
connector.connect( buffer.write(
len(self.headers),
zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000) zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000)
) )
self._download_height = len(self.headers) + connector.tell() // self.headers.header_size self._download_height = current + buffer.tell() // self.headers.header_size
log.info("Headers sync: %s / %s", self._download_height, target) log.info("Headers sync: %s / %s", self._download_height, target)
async def update_headers(self, height=None, headers=None, subscription_update=False): async def update_headers(self, height=None, headers=None, subscription_update=False):