refactor and clean up

This commit is contained in:
Victor Shyba 2019-07-12 19:54:04 -03:00 committed by Lex Berezhny
parent d0607b6fec
commit d9460dcd9e
6 changed files with 48 additions and 46 deletions

View file

@ -163,3 +163,6 @@ class ComponentManager:
if component.component_name == component_name:
return component.component
raise NameError(component_name)
def has_component(self, component_name):
return any(component.component_name == component_name for component in self.components)

View file

@ -109,31 +109,30 @@ class HeadersComponent(Component):
self.headers_file = os.path.join(self.headers_dir, 'headers')
self.old_file = os.path.join(self.conf.wallet_dir, 'blockchain_headers')
self.headers = Headers(self.headers_file)
self._downloading_headers = None
self.is_downloading_headers = False
self._headers_progress_percent = 0
@property
def component(self):
return self
def _round_progress(self, local_height, remote_height):
return min(max(math.ceil(float(local_height) / float(remote_height) * 100), 0), 100)
async def get_status(self):
if self._downloading_headers:
progress = None
if self.is_downloading_headers:
progress = self._headers_progress_percent
else:
try:
wallet_manager = self.component_manager.get_component(WALLET_COMPONENT)
if wallet_manager and wallet_manager.ledger.network.remote_height > 0:
local_height = wallet_manager.ledger.headers.height
remote_height = wallet_manager.ledger.network.remote_height
progress = max(math.ceil(float(local_height) / float(remote_height) * 100), 0)
else:
return {}
except NameError:
return {}
elif self.component_manager.has_component(WALLET_COMPONENT):
wallet_manager = self.component_manager.get_component(WALLET_COMPONENT)
if wallet_manager and wallet_manager.ledger.network.remote_height > 0:
local_height = wallet_manager.ledger.headers.height
remote_height = wallet_manager.ledger.network.remote_height
progress = self._round_progress(local_height, remote_height)
return {
'downloading_headers': True,
'download_progress': progress
} if progress < 100 else {}
} if progress is not None and progress < 100 else {}
async def fetch_headers_from_s3(self):
local_header_size = self.headers.bytes_size
@ -158,8 +157,8 @@ class HeadersComponent(Component):
if not await self.headers.connect(len(self.headers), chunk):
log.warning("Error connecting downloaded headers from at %s.", self.headers.height)
return
self._headers_progress_percent = math.ceil(
float(self.headers.bytes_size) / float(final_size_after_download) * 100
self._headers_progress_percent = self._round_progress(
self.headers.bytes_size, final_size_after_download
)
def local_header_file_size(self):
@ -167,12 +166,14 @@ class HeadersComponent(Component):
return os.stat(self.headers_file).st_size
return 0
async def get_download_height(self):
async with utils.aiohttp_request('HEAD', HEADERS_URL) as response:
if response.status != 200:
log.warning("Header download error: %s", response.status)
return 0
return response.content_length // HEADER_SIZE
async def get_downloadable_header_height(self) -> typing.Optional[int]:
try:
async with utils.aiohttp_request('HEAD', HEADERS_URL) as response:
if response.status != 200:
log.warning("Header download error, unexpected response code: %s", response.status)
return response.content_length // HEADER_SIZE
except OSError:
log.exception("Failed to download headers using https.")
async def should_download_headers_from_s3(self):
if self.conf.blockchain_name != "lbrycrd_main":
@ -182,14 +183,11 @@ class HeadersComponent(Component):
return False
local_height = self.headers.height
try:
remote_height = await self.get_download_height()
except OSError:
log.warning("Failed to download headers using https.")
return False
log.info("remote height: %i, local height: %i", remote_height, local_height)
if remote_height > (local_height + s3_headers_depth):
return True
remote_height = await self.get_downloadable_header_height()
if remote_height is not None:
log.info("remote height: %i, local height: %i", remote_height, local_height)
if remote_height > (local_height + s3_headers_depth):
return True
return False
async def start(self):
@ -199,15 +197,15 @@ class HeadersComponent(Component):
log.warning("Moving old headers from %s to %s.", self.old_file, self.headers_file)
os.rename(self.old_file, self.headers_file)
await self.headers.open()
self.headers.repair()
self._downloading_headers = await self.should_download_headers_from_s3()
if self._downloading_headers:
await self.headers.repair()
if await self.should_download_headers_from_s3():
try:
self.is_downloading_headers = True
await self.fetch_headers_from_s3()
except Exception as err:
log.error("failed to fetch headers from s3: %s", err)
finally:
self._downloading_headers = False
self.is_downloading_headers = False
await self.headers.close()
async def stop(self):
@ -401,10 +399,8 @@ class StreamManagerComponent(Component):
blob_manager = self.component_manager.get_component(BLOB_COMPONENT)
storage = self.component_manager.get_component(DATABASE_COMPONENT)
wallet = self.component_manager.get_component(WALLET_COMPONENT)
try:
node = self.component_manager.get_component(DHT_COMPONENT)
except NameError:
node = None
node = self.component_manager.get_component(DHT_COMPONENT)\
if self.component_manager.has_component(DHT_COMPONENT) else None
log.info('Starting the file manager')
loop = asyncio.get_event_loop()
self.stream_manager = StreamManager(

View file

@ -56,10 +56,14 @@ class ServerPickingTestCase(AsyncioTestCase):
async def _make_fake_server(self, latency=1.0, port=1337):
# local fake server with artificial latency
proto = RPCSession()
proto.handle_request = lambda _: asyncio.sleep(latency)
async def __handler(_):
await asyncio.sleep(latency)
return {'height': 1}
proto.handle_request = __handler
server = await self.loop.create_server(lambda: proto, host='127.0.0.1', port=port)
self.addCleanup(server.close)
return ('127.0.0.1', port)
return '127.0.0.1', port
async def test_pick_fastest(self):
ledger = Mock(config={

View file

@ -98,19 +98,19 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
headers = MainHeaders(':memory:')
await headers.connect(0, self.get_bytes(block_bytes(3001)))
self.assertEqual(headers.height, 3000)
headers.repair()
await headers.repair()
self.assertEqual(headers.height, 3000)
# corrupt the middle of it
headers.io.seek(block_bytes(1500))
headers.io.write(b"wtf")
headers.repair()
await headers.repair()
self.assertEqual(headers.height, 1499)
self.assertEqual(len(headers), 1500)
# corrupt by appending
headers.io.seek(block_bytes(len(headers)))
headers.io.write(b"appending")
headers._size = None
headers.repair()
await headers.repair()
self.assertEqual(headers.height, 1499)
await headers.connect(len(headers), self.get_bytes(block_bytes(3001 - 1500), after=block_bytes(1500)))
self.assertEqual(headers.height, 3000)

View file

@ -168,7 +168,7 @@ class BaseHeaders:
proof_of_work.value, target.value)
)
def repair(self):
async def repair(self):
previous_header_hash = fail = None
for height in range(self.height):
raw = self.get_raw_header(height)

View file

@ -140,8 +140,7 @@ class BaseNetwork:
return session
def _update_remote_height(self, header_args):
if header_args and header_args[0]:
self.remote_height = header_args[0]["height"]
self.remote_height = header_args[0]["height"]
def ensure_server_version(self, required='1.2'):
return self.rpc('server.version', [__version__, required])