From d9460dcd9eef7d9dc7943906dac0817f2b8543b5 Mon Sep 17 00:00:00 2001
From: Victor Shyba <victor1984@riseup.net>
Date: Fri, 12 Jul 2019 19:54:04 -0300
Subject: [PATCH] refactor and clean up

---
 lbry/lbry/extras/daemon/ComponentManager.py   |  3 +
 lbry/lbry/extras/daemon/Components.py         | 72 +++++++++----------
 .../client_tests/integration/test_network.py  |  8 ++-
 torba/tests/client_tests/unit/test_headers.py |  6 +-
 torba/torba/client/baseheader.py              |  2 +-
 torba/torba/client/basenetwork.py             |  3 +-
 6 files changed, 48 insertions(+), 46 deletions(-)

diff --git a/lbry/lbry/extras/daemon/ComponentManager.py b/lbry/lbry/extras/daemon/ComponentManager.py
index 3206de83d..c9e2fa336 100644
--- a/lbry/lbry/extras/daemon/ComponentManager.py
+++ b/lbry/lbry/extras/daemon/ComponentManager.py
@@ -163,3 +163,6 @@ class ComponentManager:
             if component.component_name == component_name:
                 return component.component
         raise NameError(component_name)
+
+    def has_component(self, component_name):
+        return any(component.component_name == component_name for component in self.components)
diff --git a/lbry/lbry/extras/daemon/Components.py b/lbry/lbry/extras/daemon/Components.py
index c0b858865..b01e2f315 100644
--- a/lbry/lbry/extras/daemon/Components.py
+++ b/lbry/lbry/extras/daemon/Components.py
@@ -109,31 +109,30 @@ class HeadersComponent(Component):
         self.headers_file = os.path.join(self.headers_dir, 'headers')
         self.old_file = os.path.join(self.conf.wallet_dir, 'blockchain_headers')
         self.headers = Headers(self.headers_file)
-        self._downloading_headers = None
+        self.is_downloading_headers = False
         self._headers_progress_percent = 0
 
     @property
     def component(self):
         return self
 
+    def _round_progress(self, local_height, remote_height):
+        return min(max(math.ceil(float(local_height) / float(remote_height) * 100), 0), 100)
+
     async def get_status(self):
-        if self._downloading_headers:
+        progress = None
+        if self.is_downloading_headers:
             progress = self._headers_progress_percent
-        else:
-            try:
-                wallet_manager = self.component_manager.get_component(WALLET_COMPONENT)
-                if wallet_manager and wallet_manager.ledger.network.remote_height > 0:
-                    local_height = wallet_manager.ledger.headers.height
-                    remote_height = wallet_manager.ledger.network.remote_height
-                    progress = max(math.ceil(float(local_height) / float(remote_height) * 100), 0)
-                else:
-                    return {}
-            except NameError:
-                return {}
+        elif self.component_manager.has_component(WALLET_COMPONENT):
+            wallet_manager = self.component_manager.get_component(WALLET_COMPONENT)
+            if wallet_manager and wallet_manager.ledger.network.remote_height > 0:
+                local_height = wallet_manager.ledger.headers.height
+                remote_height = wallet_manager.ledger.network.remote_height
+                progress = self._round_progress(local_height, remote_height)
         return {
             'downloading_headers': True,
             'download_progress': progress
-        } if progress < 100 else {}
+        } if progress is not None and progress < 100 else {}
 
     async def fetch_headers_from_s3(self):
         local_header_size = self.headers.bytes_size
@@ -158,8 +157,8 @@ class HeadersComponent(Component):
                 if not await self.headers.connect(len(self.headers), chunk):
                     log.warning("Error connecting downloaded headers from at %s.", self.headers.height)
                     return
-                self._headers_progress_percent = math.ceil(
-                    float(self.headers.bytes_size) / float(final_size_after_download) * 100
+                self._headers_progress_percent = self._round_progress(
+                    self.headers.bytes_size, final_size_after_download
                 )
 
     def local_header_file_size(self):
@@ -167,12 +166,14 @@ class HeadersComponent(Component):
             return os.stat(self.headers_file).st_size
         return 0
 
-    async def get_download_height(self):
-        async with utils.aiohttp_request('HEAD', HEADERS_URL) as response:
-            if response.status != 200:
-                log.warning("Header download error: %s", response.status)
-                return 0
-            return response.content_length // HEADER_SIZE
+    async def get_downloadable_header_height(self) -> typing.Optional[int]:
+        try:
+            async with utils.aiohttp_request('HEAD', HEADERS_URL) as response:
+                if response.status != 200:
+                    log.warning("Header download error, unexpected response code: %s", response.status)
+                return response.content_length // HEADER_SIZE
+        except OSError:
+            log.exception("Failed to download headers using https.")
 
     async def should_download_headers_from_s3(self):
         if self.conf.blockchain_name != "lbrycrd_main":
@@ -182,14 +183,11 @@ class HeadersComponent(Component):
             return False
 
         local_height = self.headers.height
-        try:
-            remote_height = await self.get_download_height()
-        except OSError:
-            log.warning("Failed to download headers using https.")
-            return False
-        log.info("remote height: %i, local height: %i", remote_height, local_height)
-        if remote_height > (local_height + s3_headers_depth):
-            return True
+        remote_height = await self.get_downloadable_header_height()
+        if remote_height is not None:
+            log.info("remote height: %i, local height: %i", remote_height, local_height)
+            if remote_height > (local_height + s3_headers_depth):
+                return True
         return False
 
     async def start(self):
@@ -199,15 +197,15 @@ class HeadersComponent(Component):
             log.warning("Moving old headers from %s to %s.", self.old_file, self.headers_file)
             os.rename(self.old_file, self.headers_file)
         await self.headers.open()
-        self.headers.repair()
-        self._downloading_headers = await self.should_download_headers_from_s3()
-        if self._downloading_headers:
+        await self.headers.repair()
+        if await self.should_download_headers_from_s3():
             try:
+                self.is_downloading_headers = True
                 await self.fetch_headers_from_s3()
             except Exception as err:
                 log.error("failed to fetch headers from s3: %s", err)
             finally:
-                self._downloading_headers = False
+                self.is_downloading_headers = False
         await self.headers.close()
 
     async def stop(self):
@@ -401,10 +399,8 @@ class StreamManagerComponent(Component):
         blob_manager = self.component_manager.get_component(BLOB_COMPONENT)
         storage = self.component_manager.get_component(DATABASE_COMPONENT)
         wallet = self.component_manager.get_component(WALLET_COMPONENT)
-        try:
-            node = self.component_manager.get_component(DHT_COMPONENT)
-        except NameError:
-            node = None
+        node = self.component_manager.get_component(DHT_COMPONENT)\
+            if self.component_manager.has_component(DHT_COMPONENT) else None
         log.info('Starting the file manager')
         loop = asyncio.get_event_loop()
         self.stream_manager = StreamManager(
diff --git a/torba/tests/client_tests/integration/test_network.py b/torba/tests/client_tests/integration/test_network.py
index d5870ffb2..c74102fe5 100644
--- a/torba/tests/client_tests/integration/test_network.py
+++ b/torba/tests/client_tests/integration/test_network.py
@@ -56,10 +56,14 @@ class ServerPickingTestCase(AsyncioTestCase):
     async def _make_fake_server(self, latency=1.0, port=1337):
         # local fake server with artificial latency
         proto = RPCSession()
-        proto.handle_request = lambda _: asyncio.sleep(latency)
+        async def __handler(_):
+            await asyncio.sleep(latency)
+            return {'height': 1}
+
+        proto.handle_request = __handler
         server = await self.loop.create_server(lambda: proto, host='127.0.0.1', port=port)
         self.addCleanup(server.close)
-        return ('127.0.0.1', port)
+        return '127.0.0.1', port
 
     async def test_pick_fastest(self):
         ledger = Mock(config={
diff --git a/torba/tests/client_tests/unit/test_headers.py b/torba/tests/client_tests/unit/test_headers.py
index 9d3b65cc6..ca494b8e2 100644
--- a/torba/tests/client_tests/unit/test_headers.py
+++ b/torba/tests/client_tests/unit/test_headers.py
@@ -98,19 +98,19 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
         headers = MainHeaders(':memory:')
         await headers.connect(0, self.get_bytes(block_bytes(3001)))
         self.assertEqual(headers.height, 3000)
-        headers.repair()
+        await headers.repair()
         self.assertEqual(headers.height, 3000)
         # corrupt the middle of it
         headers.io.seek(block_bytes(1500))
         headers.io.write(b"wtf")
-        headers.repair()
+        await headers.repair()
         self.assertEqual(headers.height, 1499)
         self.assertEqual(len(headers), 1500)
         # corrupt by appending
         headers.io.seek(block_bytes(len(headers)))
         headers.io.write(b"appending")
         headers._size = None
-        headers.repair()
+        await headers.repair()
         self.assertEqual(headers.height, 1499)
         await headers.connect(len(headers), self.get_bytes(block_bytes(3001 - 1500), after=block_bytes(1500)))
         self.assertEqual(headers.height, 3000)
diff --git a/torba/torba/client/baseheader.py b/torba/torba/client/baseheader.py
index 59541d4ab..86477c067 100644
--- a/torba/torba/client/baseheader.py
+++ b/torba/torba/client/baseheader.py
@@ -168,7 +168,7 @@ class BaseHeaders:
                         proof_of_work.value, target.value)
                 )
 
-    def repair(self):
+    async def repair(self):
         previous_header_hash = fail = None
         for height in range(self.height):
             raw = self.get_raw_header(height)
diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py
index 524a01729..1f5a122a8 100644
--- a/torba/torba/client/basenetwork.py
+++ b/torba/torba/client/basenetwork.py
@@ -140,8 +140,7 @@ class BaseNetwork:
         return session
 
     def _update_remote_height(self, header_args):
-        if header_args and header_args[0]:
-            self.remote_height = header_args[0]["height"]
+        self.remote_height = header_args[0]["height"]
 
     def ensure_server_version(self, required='1.2'):
         return self.rpc('server.version', [__version__, required])