refactor and clean up

This commit is contained in:
Victor Shyba 2019-07-12 19:54:04 -03:00 committed by Lex Berezhny
parent d0607b6fec
commit d9460dcd9e
6 changed files with 48 additions and 46 deletions
lbry/lbry/extras/daemon
torba
tests/client_tests
torba/client

View file

@ -163,3 +163,6 @@ class ComponentManager:
if component.component_name == component_name: if component.component_name == component_name:
return component.component return component.component
raise NameError(component_name) raise NameError(component_name)
def has_component(self, component_name):
return any(component.component_name == component_name for component in self.components)

View file

@ -109,31 +109,30 @@ class HeadersComponent(Component):
self.headers_file = os.path.join(self.headers_dir, 'headers') self.headers_file = os.path.join(self.headers_dir, 'headers')
self.old_file = os.path.join(self.conf.wallet_dir, 'blockchain_headers') self.old_file = os.path.join(self.conf.wallet_dir, 'blockchain_headers')
self.headers = Headers(self.headers_file) self.headers = Headers(self.headers_file)
self._downloading_headers = None self.is_downloading_headers = False
self._headers_progress_percent = 0 self._headers_progress_percent = 0
@property @property
def component(self): def component(self):
return self return self
def _round_progress(self, local_height, remote_height):
return min(max(math.ceil(float(local_height) / float(remote_height) * 100), 0), 100)
async def get_status(self): async def get_status(self):
if self._downloading_headers: progress = None
if self.is_downloading_headers:
progress = self._headers_progress_percent progress = self._headers_progress_percent
else: elif self.component_manager.has_component(WALLET_COMPONENT):
try: wallet_manager = self.component_manager.get_component(WALLET_COMPONENT)
wallet_manager = self.component_manager.get_component(WALLET_COMPONENT) if wallet_manager and wallet_manager.ledger.network.remote_height > 0:
if wallet_manager and wallet_manager.ledger.network.remote_height > 0: local_height = wallet_manager.ledger.headers.height
local_height = wallet_manager.ledger.headers.height remote_height = wallet_manager.ledger.network.remote_height
remote_height = wallet_manager.ledger.network.remote_height progress = self._round_progress(local_height, remote_height)
progress = max(math.ceil(float(local_height) / float(remote_height) * 100), 0)
else:
return {}
except NameError:
return {}
return { return {
'downloading_headers': True, 'downloading_headers': True,
'download_progress': progress 'download_progress': progress
} if progress < 100 else {} } if progress is not None and progress < 100 else {}
async def fetch_headers_from_s3(self): async def fetch_headers_from_s3(self):
local_header_size = self.headers.bytes_size local_header_size = self.headers.bytes_size
@ -158,8 +157,8 @@ class HeadersComponent(Component):
if not await self.headers.connect(len(self.headers), chunk): if not await self.headers.connect(len(self.headers), chunk):
log.warning("Error connecting downloaded headers from at %s.", self.headers.height) log.warning("Error connecting downloaded headers from at %s.", self.headers.height)
return return
self._headers_progress_percent = math.ceil( self._headers_progress_percent = self._round_progress(
float(self.headers.bytes_size) / float(final_size_after_download) * 100 self.headers.bytes_size, final_size_after_download
) )
def local_header_file_size(self): def local_header_file_size(self):
@ -167,12 +166,14 @@ class HeadersComponent(Component):
return os.stat(self.headers_file).st_size return os.stat(self.headers_file).st_size
return 0 return 0
async def get_download_height(self): async def get_downloadable_header_height(self) -> typing.Optional[int]:
async with utils.aiohttp_request('HEAD', HEADERS_URL) as response: try:
if response.status != 200: async with utils.aiohttp_request('HEAD', HEADERS_URL) as response:
log.warning("Header download error: %s", response.status) if response.status != 200:
return 0 log.warning("Header download error, unexpected response code: %s", response.status)
return response.content_length // HEADER_SIZE return response.content_length // HEADER_SIZE
except OSError:
log.exception("Failed to download headers using https.")
async def should_download_headers_from_s3(self): async def should_download_headers_from_s3(self):
if self.conf.blockchain_name != "lbrycrd_main": if self.conf.blockchain_name != "lbrycrd_main":
@ -182,14 +183,11 @@ class HeadersComponent(Component):
return False return False
local_height = self.headers.height local_height = self.headers.height
try: remote_height = await self.get_downloadable_header_height()
remote_height = await self.get_download_height() if remote_height is not None:
except OSError: log.info("remote height: %i, local height: %i", remote_height, local_height)
log.warning("Failed to download headers using https.") if remote_height > (local_height + s3_headers_depth):
return False return True
log.info("remote height: %i, local height: %i", remote_height, local_height)
if remote_height > (local_height + s3_headers_depth):
return True
return False return False
async def start(self): async def start(self):
@ -199,15 +197,15 @@ class HeadersComponent(Component):
log.warning("Moving old headers from %s to %s.", self.old_file, self.headers_file) log.warning("Moving old headers from %s to %s.", self.old_file, self.headers_file)
os.rename(self.old_file, self.headers_file) os.rename(self.old_file, self.headers_file)
await self.headers.open() await self.headers.open()
self.headers.repair() await self.headers.repair()
self._downloading_headers = await self.should_download_headers_from_s3() if await self.should_download_headers_from_s3():
if self._downloading_headers:
try: try:
self.is_downloading_headers = True
await self.fetch_headers_from_s3() await self.fetch_headers_from_s3()
except Exception as err: except Exception as err:
log.error("failed to fetch headers from s3: %s", err) log.error("failed to fetch headers from s3: %s", err)
finally: finally:
self._downloading_headers = False self.is_downloading_headers = False
await self.headers.close() await self.headers.close()
async def stop(self): async def stop(self):
@ -401,10 +399,8 @@ class StreamManagerComponent(Component):
blob_manager = self.component_manager.get_component(BLOB_COMPONENT) blob_manager = self.component_manager.get_component(BLOB_COMPONENT)
storage = self.component_manager.get_component(DATABASE_COMPONENT) storage = self.component_manager.get_component(DATABASE_COMPONENT)
wallet = self.component_manager.get_component(WALLET_COMPONENT) wallet = self.component_manager.get_component(WALLET_COMPONENT)
try: node = self.component_manager.get_component(DHT_COMPONENT)\
node = self.component_manager.get_component(DHT_COMPONENT) if self.component_manager.has_component(DHT_COMPONENT) else None
except NameError:
node = None
log.info('Starting the file manager') log.info('Starting the file manager')
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
self.stream_manager = StreamManager( self.stream_manager = StreamManager(

View file

@ -56,10 +56,14 @@ class ServerPickingTestCase(AsyncioTestCase):
async def _make_fake_server(self, latency=1.0, port=1337): async def _make_fake_server(self, latency=1.0, port=1337):
# local fake server with artificial latency # local fake server with artificial latency
proto = RPCSession() proto = RPCSession()
proto.handle_request = lambda _: asyncio.sleep(latency) async def __handler(_):
await asyncio.sleep(latency)
return {'height': 1}
proto.handle_request = __handler
server = await self.loop.create_server(lambda: proto, host='127.0.0.1', port=port) server = await self.loop.create_server(lambda: proto, host='127.0.0.1', port=port)
self.addCleanup(server.close) self.addCleanup(server.close)
return ('127.0.0.1', port) return '127.0.0.1', port
async def test_pick_fastest(self): async def test_pick_fastest(self):
ledger = Mock(config={ ledger = Mock(config={

View file

@ -98,19 +98,19 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
headers = MainHeaders(':memory:') headers = MainHeaders(':memory:')
await headers.connect(0, self.get_bytes(block_bytes(3001))) await headers.connect(0, self.get_bytes(block_bytes(3001)))
self.assertEqual(headers.height, 3000) self.assertEqual(headers.height, 3000)
headers.repair() await headers.repair()
self.assertEqual(headers.height, 3000) self.assertEqual(headers.height, 3000)
# corrupt the middle of it # corrupt the middle of it
headers.io.seek(block_bytes(1500)) headers.io.seek(block_bytes(1500))
headers.io.write(b"wtf") headers.io.write(b"wtf")
headers.repair() await headers.repair()
self.assertEqual(headers.height, 1499) self.assertEqual(headers.height, 1499)
self.assertEqual(len(headers), 1500) self.assertEqual(len(headers), 1500)
# corrupt by appending # corrupt by appending
headers.io.seek(block_bytes(len(headers))) headers.io.seek(block_bytes(len(headers)))
headers.io.write(b"appending") headers.io.write(b"appending")
headers._size = None headers._size = None
headers.repair() await headers.repair()
self.assertEqual(headers.height, 1499) self.assertEqual(headers.height, 1499)
await headers.connect(len(headers), self.get_bytes(block_bytes(3001 - 1500), after=block_bytes(1500))) await headers.connect(len(headers), self.get_bytes(block_bytes(3001 - 1500), after=block_bytes(1500)))
self.assertEqual(headers.height, 3000) self.assertEqual(headers.height, 3000)

View file

@ -168,7 +168,7 @@ class BaseHeaders:
proof_of_work.value, target.value) proof_of_work.value, target.value)
) )
def repair(self): async def repair(self):
previous_header_hash = fail = None previous_header_hash = fail = None
for height in range(self.height): for height in range(self.height):
raw = self.get_raw_header(height) raw = self.get_raw_header(height)

View file

@ -140,8 +140,7 @@ class BaseNetwork:
return session return session
def _update_remote_height(self, header_args): def _update_remote_height(self, header_args):
if header_args and header_args[0]: self.remote_height = header_args[0]["height"]
self.remote_height = header_args[0]["height"]
def ensure_server_version(self, required='1.2'): def ensure_server_version(self, required='1.2'):
return self.rpc('server.version', [__version__, required]) return self.rpc('server.version', [__version__, required])