forked from LBRYCommunity/lbry-sdk
refactor and clean up
This commit is contained in:
parent
d0607b6fec
commit
d9460dcd9e
6 changed files with 48 additions and 46 deletions
|
@ -163,3 +163,6 @@ class ComponentManager:
|
|||
if component.component_name == component_name:
|
||||
return component.component
|
||||
raise NameError(component_name)
|
||||
|
||||
def has_component(self, component_name):
|
||||
return any(component.component_name == component_name for component in self.components)
|
||||
|
|
|
@ -109,31 +109,30 @@ class HeadersComponent(Component):
|
|||
self.headers_file = os.path.join(self.headers_dir, 'headers')
|
||||
self.old_file = os.path.join(self.conf.wallet_dir, 'blockchain_headers')
|
||||
self.headers = Headers(self.headers_file)
|
||||
self._downloading_headers = None
|
||||
self.is_downloading_headers = False
|
||||
self._headers_progress_percent = 0
|
||||
|
||||
@property
|
||||
def component(self):
|
||||
return self
|
||||
|
||||
def _round_progress(self, local_height, remote_height):
|
||||
return min(max(math.ceil(float(local_height) / float(remote_height) * 100), 0), 100)
|
||||
|
||||
async def get_status(self):
|
||||
if self._downloading_headers:
|
||||
progress = None
|
||||
if self.is_downloading_headers:
|
||||
progress = self._headers_progress_percent
|
||||
else:
|
||||
try:
|
||||
wallet_manager = self.component_manager.get_component(WALLET_COMPONENT)
|
||||
if wallet_manager and wallet_manager.ledger.network.remote_height > 0:
|
||||
local_height = wallet_manager.ledger.headers.height
|
||||
remote_height = wallet_manager.ledger.network.remote_height
|
||||
progress = max(math.ceil(float(local_height) / float(remote_height) * 100), 0)
|
||||
else:
|
||||
return {}
|
||||
except NameError:
|
||||
return {}
|
||||
elif self.component_manager.has_component(WALLET_COMPONENT):
|
||||
wallet_manager = self.component_manager.get_component(WALLET_COMPONENT)
|
||||
if wallet_manager and wallet_manager.ledger.network.remote_height > 0:
|
||||
local_height = wallet_manager.ledger.headers.height
|
||||
remote_height = wallet_manager.ledger.network.remote_height
|
||||
progress = self._round_progress(local_height, remote_height)
|
||||
return {
|
||||
'downloading_headers': True,
|
||||
'download_progress': progress
|
||||
} if progress < 100 else {}
|
||||
} if progress is not None and progress < 100 else {}
|
||||
|
||||
async def fetch_headers_from_s3(self):
|
||||
local_header_size = self.headers.bytes_size
|
||||
|
@ -158,8 +157,8 @@ class HeadersComponent(Component):
|
|||
if not await self.headers.connect(len(self.headers), chunk):
|
||||
log.warning("Error connecting downloaded headers from at %s.", self.headers.height)
|
||||
return
|
||||
self._headers_progress_percent = math.ceil(
|
||||
float(self.headers.bytes_size) / float(final_size_after_download) * 100
|
||||
self._headers_progress_percent = self._round_progress(
|
||||
self.headers.bytes_size, final_size_after_download
|
||||
)
|
||||
|
||||
def local_header_file_size(self):
|
||||
|
@ -167,12 +166,14 @@ class HeadersComponent(Component):
|
|||
return os.stat(self.headers_file).st_size
|
||||
return 0
|
||||
|
||||
async def get_download_height(self):
|
||||
async with utils.aiohttp_request('HEAD', HEADERS_URL) as response:
|
||||
if response.status != 200:
|
||||
log.warning("Header download error: %s", response.status)
|
||||
return 0
|
||||
return response.content_length // HEADER_SIZE
|
||||
async def get_downloadable_header_height(self) -> typing.Optional[int]:
|
||||
try:
|
||||
async with utils.aiohttp_request('HEAD', HEADERS_URL) as response:
|
||||
if response.status != 200:
|
||||
log.warning("Header download error, unexpected response code: %s", response.status)
|
||||
return response.content_length // HEADER_SIZE
|
||||
except OSError:
|
||||
log.exception("Failed to download headers using https.")
|
||||
|
||||
async def should_download_headers_from_s3(self):
|
||||
if self.conf.blockchain_name != "lbrycrd_main":
|
||||
|
@ -182,14 +183,11 @@ class HeadersComponent(Component):
|
|||
return False
|
||||
|
||||
local_height = self.headers.height
|
||||
try:
|
||||
remote_height = await self.get_download_height()
|
||||
except OSError:
|
||||
log.warning("Failed to download headers using https.")
|
||||
return False
|
||||
log.info("remote height: %i, local height: %i", remote_height, local_height)
|
||||
if remote_height > (local_height + s3_headers_depth):
|
||||
return True
|
||||
remote_height = await self.get_downloadable_header_height()
|
||||
if remote_height is not None:
|
||||
log.info("remote height: %i, local height: %i", remote_height, local_height)
|
||||
if remote_height > (local_height + s3_headers_depth):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def start(self):
|
||||
|
@ -199,15 +197,15 @@ class HeadersComponent(Component):
|
|||
log.warning("Moving old headers from %s to %s.", self.old_file, self.headers_file)
|
||||
os.rename(self.old_file, self.headers_file)
|
||||
await self.headers.open()
|
||||
self.headers.repair()
|
||||
self._downloading_headers = await self.should_download_headers_from_s3()
|
||||
if self._downloading_headers:
|
||||
await self.headers.repair()
|
||||
if await self.should_download_headers_from_s3():
|
||||
try:
|
||||
self.is_downloading_headers = True
|
||||
await self.fetch_headers_from_s3()
|
||||
except Exception as err:
|
||||
log.error("failed to fetch headers from s3: %s", err)
|
||||
finally:
|
||||
self._downloading_headers = False
|
||||
self.is_downloading_headers = False
|
||||
await self.headers.close()
|
||||
|
||||
async def stop(self):
|
||||
|
@ -401,10 +399,8 @@ class StreamManagerComponent(Component):
|
|||
blob_manager = self.component_manager.get_component(BLOB_COMPONENT)
|
||||
storage = self.component_manager.get_component(DATABASE_COMPONENT)
|
||||
wallet = self.component_manager.get_component(WALLET_COMPONENT)
|
||||
try:
|
||||
node = self.component_manager.get_component(DHT_COMPONENT)
|
||||
except NameError:
|
||||
node = None
|
||||
node = self.component_manager.get_component(DHT_COMPONENT)\
|
||||
if self.component_manager.has_component(DHT_COMPONENT) else None
|
||||
log.info('Starting the file manager')
|
||||
loop = asyncio.get_event_loop()
|
||||
self.stream_manager = StreamManager(
|
||||
|
|
|
@ -56,10 +56,14 @@ class ServerPickingTestCase(AsyncioTestCase):
|
|||
async def _make_fake_server(self, latency=1.0, port=1337):
|
||||
# local fake server with artificial latency
|
||||
proto = RPCSession()
|
||||
proto.handle_request = lambda _: asyncio.sleep(latency)
|
||||
async def __handler(_):
|
||||
await asyncio.sleep(latency)
|
||||
return {'height': 1}
|
||||
|
||||
proto.handle_request = __handler
|
||||
server = await self.loop.create_server(lambda: proto, host='127.0.0.1', port=port)
|
||||
self.addCleanup(server.close)
|
||||
return ('127.0.0.1', port)
|
||||
return '127.0.0.1', port
|
||||
|
||||
async def test_pick_fastest(self):
|
||||
ledger = Mock(config={
|
||||
|
|
|
@ -98,19 +98,19 @@ class BasicHeadersTests(BitcoinHeadersTestCase):
|
|||
headers = MainHeaders(':memory:')
|
||||
await headers.connect(0, self.get_bytes(block_bytes(3001)))
|
||||
self.assertEqual(headers.height, 3000)
|
||||
headers.repair()
|
||||
await headers.repair()
|
||||
self.assertEqual(headers.height, 3000)
|
||||
# corrupt the middle of it
|
||||
headers.io.seek(block_bytes(1500))
|
||||
headers.io.write(b"wtf")
|
||||
headers.repair()
|
||||
await headers.repair()
|
||||
self.assertEqual(headers.height, 1499)
|
||||
self.assertEqual(len(headers), 1500)
|
||||
# corrupt by appending
|
||||
headers.io.seek(block_bytes(len(headers)))
|
||||
headers.io.write(b"appending")
|
||||
headers._size = None
|
||||
headers.repair()
|
||||
await headers.repair()
|
||||
self.assertEqual(headers.height, 1499)
|
||||
await headers.connect(len(headers), self.get_bytes(block_bytes(3001 - 1500), after=block_bytes(1500)))
|
||||
self.assertEqual(headers.height, 3000)
|
||||
|
|
|
@ -168,7 +168,7 @@ class BaseHeaders:
|
|||
proof_of_work.value, target.value)
|
||||
)
|
||||
|
||||
def repair(self):
|
||||
async def repair(self):
|
||||
previous_header_hash = fail = None
|
||||
for height in range(self.height):
|
||||
raw = self.get_raw_header(height)
|
||||
|
|
|
@ -140,8 +140,7 @@ class BaseNetwork:
|
|||
return session
|
||||
|
||||
def _update_remote_height(self, header_args):
|
||||
if header_args and header_args[0]:
|
||||
self.remote_height = header_args[0]["height"]
|
||||
self.remote_height = header_args[0]["height"]
|
||||
|
||||
def ensure_server_version(self, required='1.2'):
|
||||
return self.rpc('server.version', [__version__, required])
|
||||
|
|
Loading…
Add table
Reference in a new issue