add connection id workaround

This commit is contained in:
Jack Robison 2019-05-05 20:22:10 -04:00
parent b2f63a1545
commit 24e073680b
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
4 changed files with 19 additions and 18 deletions

View file

@ -189,7 +189,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
@cache_concurrent
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'AbstractBlob', address: str, tcp_port: int,
peer_connect_timeout: float, blob_download_timeout: float,
connected_transport: asyncio.Transport = None)\
connected_transport: asyncio.Transport = None, connection_id: int = 0)\
-> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
"""
Returns [<downloaded blob>, <keep connection>]

View file

@ -33,7 +33,7 @@ class BlobDownloader:
return False
return not (blob.get_is_verified() or not blob.is_writeable())
async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer'):
async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer', connection_id: int = 0):
if blob.get_is_verified():
return
self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones
@ -41,7 +41,7 @@ class BlobDownloader:
start = self.loop.time()
bytes_received, transport = await request_blob(
self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout,
self.config.blob_download_timeout, connected_transport=transport
self.config.blob_download_timeout, connected_transport=transport, connection_id=connection_id
)
if not transport and peer not in self.ignored:
self.ignored[peer] = self.loop.time()
@ -74,7 +74,8 @@ class BlobDownloader:
))
@cache_concurrent
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob':
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None,
connection_id: int = 0) -> 'AbstractBlob':
blob = self.blob_manager.get_blob(blob_hash, length)
if blob.get_is_verified():
return blob
@ -94,7 +95,7 @@ class BlobDownloader:
break
if peer not in self.active_connections and peer not in self.ignored:
log.debug("request %s from %s:%i", blob_hash[:8], peer.address, peer.tcp_port)
t = self.loop.create_task(self.request_blob_from_peer(blob, peer))
t = self.loop.create_task(self.request_blob_from_peer(blob, peer, connection_id))
self.active_connections[peer] = t
await self.new_peer_or_finished()
self.cleanup_active()

View file

@ -58,14 +58,14 @@ class StreamDownloader:
self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _delayed_add_fixed_peers)
async def load_descriptor(self):
async def load_descriptor(self, connection_id: int = 0):
# download or get the sd blob
sd_blob = self.blob_manager.get_blob(self.sd_hash)
if not sd_blob.get_is_verified():
try:
now = self.loop.time()
sd_blob = await asyncio.wait_for(
self.blob_downloader.download_blob(self.sd_hash),
self.blob_downloader.download_blob(self.sd_hash, connection_id),
self.config.blob_download_timeout, loop=self.loop
)
log.info("downloaded sd blob %s", self.sd_hash)
@ -79,7 +79,7 @@ class StreamDownloader:
)
log.info("loaded stream manifest %s", self.sd_hash)
async def start(self, node: typing.Optional['Node'] = None):
async def start(self, node: typing.Optional['Node'] = None, connection_id: int = 0):
# set up peer accumulation
if node:
self.node = node
@ -90,7 +90,7 @@ class StreamDownloader:
log.info("searching for peers for stream %s", self.sd_hash)
if not self.descriptor:
await self.load_descriptor()
await self.load_descriptor(connection_id)
# add the head blob to the peer search
self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash)
@ -101,10 +101,10 @@ class StreamDownloader:
self.blob_manager.get_blob(self.sd_hash, length=self.descriptor.length), self.descriptor
)
async def download_stream_blob(self, blob_info: 'BlobInfo') -> 'AbstractBlob':
async def download_stream_blob(self, blob_info: 'BlobInfo', connection_id: int = 0) -> 'AbstractBlob':
if not filter(lambda blob: blob.blob_hash == blob_info.blob_hash, self.descriptor.blobs[:-1]):
raise ValueError(f"blob {blob_info.blob_hash} is not part of stream with sd hash {self.sd_hash}")
blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length)
blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length, connection_id)
return blob
def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes:
@ -112,11 +112,11 @@ class StreamDownloader:
binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode())
)
async def read_blob(self, blob_info: 'BlobInfo') -> bytes:
async def read_blob(self, blob_info: 'BlobInfo', connection_id: int = 0) -> bytes:
start = None
if self.time_to_first_bytes is None:
start = self.loop.time()
blob = await self.download_stream_blob(blob_info)
blob = await self.download_stream_blob(blob_info, connection_id)
decrypted = self.decrypt_blob(blob_info, blob)
if start:
self.time_to_first_bytes = self.loop.time() - start

View file

@ -255,7 +255,7 @@ class ManagedStream:
timeout = timeout or self.config.download_timeout
if self._running.is_set():
return
log.info("start downloader for lbry://%s#%s (sd hash %s...)", self.claim_name, self.claim_id, self.sd_hash[:6])
log.info("start downloader for stream (sd hash: %s)", self.sd_hash)
self._running.set()
try:
await asyncio.wait_for(self.downloader.start(node), timeout, loop=self.loop)
@ -286,13 +286,13 @@ class ManagedStream:
if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)
async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0)\
async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0, connection_id: int = 0)\
-> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]:
if start_blob_num >= len(self.descriptor.blobs[:-1]):
raise IndexError(start_blob_num)
for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]):
assert i + start_blob_num == blob_info.blob_num
decrypted = await self.downloader.read_blob(blob_info)
decrypted = await self.downloader.read_blob(blob_info, connection_id)
yield (blob_info, decrypted)
async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse:
@ -309,7 +309,7 @@ class ManagedStream:
self.streaming.set()
try:
wrote = 0
async for blob_info, decrypted in self._aiter_read_stream(skip_blobs):
async for blob_info, decrypted in self._aiter_read_stream(skip_blobs, connection_id=2):
if (blob_info.blob_num == len(self.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
decrypted += (b'\x00' * (size - len(decrypted) - wrote - (skip_blobs * 2097151)))
await response.write_eof(decrypted)
@ -336,7 +336,7 @@ class ManagedStream:
self.started_writing.clear()
try:
with open(output_path, 'wb') as file_write_handle:
async for blob_info, decrypted in self._aiter_read_stream():
async for blob_info, decrypted in self._aiter_read_stream(connection_id=1):
log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
file_write_handle.write(decrypted)
file_write_handle.flush()