add connection id workaround

This commit is contained in:
Jack Robison 2019-05-05 20:22:10 -04:00
parent b2f63a1545
commit 24e073680b
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
4 changed files with 19 additions and 18 deletions

View file

@@ -189,7 +189,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
@cache_concurrent @cache_concurrent
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'AbstractBlob', address: str, tcp_port: int, async def request_blob(loop: asyncio.BaseEventLoop, blob: 'AbstractBlob', address: str, tcp_port: int,
peer_connect_timeout: float, blob_download_timeout: float, peer_connect_timeout: float, blob_download_timeout: float,
connected_transport: asyncio.Transport = None)\ connected_transport: asyncio.Transport = None, connection_id: int = 0)\
-> typing.Tuple[int, typing.Optional[asyncio.Transport]]: -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
""" """
Returns [<downloaded blob>, <keep connection>] Returns [<downloaded blob>, <keep connection>]

View file

@@ -33,7 +33,7 @@ class BlobDownloader:
return False return False
return not (blob.get_is_verified() or not blob.is_writeable()) return not (blob.get_is_verified() or not blob.is_writeable())
async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer'): async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer', connection_id: int = 0):
if blob.get_is_verified(): if blob.get_is_verified():
return return
self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones
@@ -41,7 +41,7 @@ class BlobDownloader:
start = self.loop.time() start = self.loop.time()
bytes_received, transport = await request_blob( bytes_received, transport = await request_blob(
self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout, self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout,
self.config.blob_download_timeout, connected_transport=transport self.config.blob_download_timeout, connected_transport=transport, connection_id=connection_id
) )
if not transport and peer not in self.ignored: if not transport and peer not in self.ignored:
self.ignored[peer] = self.loop.time() self.ignored[peer] = self.loop.time()
@@ -74,7 +74,8 @@ class BlobDownloader:
)) ))
@cache_concurrent @cache_concurrent
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob': async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None,
connection_id: int = 0) -> 'AbstractBlob':
blob = self.blob_manager.get_blob(blob_hash, length) blob = self.blob_manager.get_blob(blob_hash, length)
if blob.get_is_verified(): if blob.get_is_verified():
return blob return blob
@@ -94,7 +95,7 @@ class BlobDownloader:
break break
if peer not in self.active_connections and peer not in self.ignored: if peer not in self.active_connections and peer not in self.ignored:
log.debug("request %s from %s:%i", blob_hash[:8], peer.address, peer.tcp_port) log.debug("request %s from %s:%i", blob_hash[:8], peer.address, peer.tcp_port)
t = self.loop.create_task(self.request_blob_from_peer(blob, peer)) t = self.loop.create_task(self.request_blob_from_peer(blob, peer, connection_id))
self.active_connections[peer] = t self.active_connections[peer] = t
await self.new_peer_or_finished() await self.new_peer_or_finished()
self.cleanup_active() self.cleanup_active()

View file

@@ -58,14 +58,14 @@ class StreamDownloader:
self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _delayed_add_fixed_peers) self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _delayed_add_fixed_peers)
async def load_descriptor(self): async def load_descriptor(self, connection_id: int = 0):
# download or get the sd blob # download or get the sd blob
sd_blob = self.blob_manager.get_blob(self.sd_hash) sd_blob = self.blob_manager.get_blob(self.sd_hash)
if not sd_blob.get_is_verified(): if not sd_blob.get_is_verified():
try: try:
now = self.loop.time() now = self.loop.time()
sd_blob = await asyncio.wait_for( sd_blob = await asyncio.wait_for(
self.blob_downloader.download_blob(self.sd_hash), self.blob_downloader.download_blob(self.sd_hash, connection_id),
self.config.blob_download_timeout, loop=self.loop self.config.blob_download_timeout, loop=self.loop
) )
log.info("downloaded sd blob %s", self.sd_hash) log.info("downloaded sd blob %s", self.sd_hash)
@@ -79,7 +79,7 @@ class StreamDownloader:
) )
log.info("loaded stream manifest %s", self.sd_hash) log.info("loaded stream manifest %s", self.sd_hash)
async def start(self, node: typing.Optional['Node'] = None): async def start(self, node: typing.Optional['Node'] = None, connection_id: int = 0):
# set up peer accumulation # set up peer accumulation
if node: if node:
self.node = node self.node = node
@@ -90,7 +90,7 @@ class StreamDownloader:
log.info("searching for peers for stream %s", self.sd_hash) log.info("searching for peers for stream %s", self.sd_hash)
if not self.descriptor: if not self.descriptor:
await self.load_descriptor() await self.load_descriptor(connection_id)
# add the head blob to the peer search # add the head blob to the peer search
self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash) self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash)
@@ -101,10 +101,10 @@ class StreamDownloader:
self.blob_manager.get_blob(self.sd_hash, length=self.descriptor.length), self.descriptor self.blob_manager.get_blob(self.sd_hash, length=self.descriptor.length), self.descriptor
) )
async def download_stream_blob(self, blob_info: 'BlobInfo') -> 'AbstractBlob': async def download_stream_blob(self, blob_info: 'BlobInfo', connection_id: int = 0) -> 'AbstractBlob':
if not filter(lambda blob: blob.blob_hash == blob_info.blob_hash, self.descriptor.blobs[:-1]): if not filter(lambda blob: blob.blob_hash == blob_info.blob_hash, self.descriptor.blobs[:-1]):
raise ValueError(f"blob {blob_info.blob_hash} is not part of stream with sd hash {self.sd_hash}") raise ValueError(f"blob {blob_info.blob_hash} is not part of stream with sd hash {self.sd_hash}")
blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length) blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length, connection_id)
return blob return blob
def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes: def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes:
@@ -112,11 +112,11 @@ class StreamDownloader:
binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode()) binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode())
) )
async def read_blob(self, blob_info: 'BlobInfo') -> bytes: async def read_blob(self, blob_info: 'BlobInfo', connection_id: int = 0) -> bytes:
start = None start = None
if self.time_to_first_bytes is None: if self.time_to_first_bytes is None:
start = self.loop.time() start = self.loop.time()
blob = await self.download_stream_blob(blob_info) blob = await self.download_stream_blob(blob_info, connection_id)
decrypted = self.decrypt_blob(blob_info, blob) decrypted = self.decrypt_blob(blob_info, blob)
if start: if start:
self.time_to_first_bytes = self.loop.time() - start self.time_to_first_bytes = self.loop.time() - start

View file

@@ -255,7 +255,7 @@ class ManagedStream:
timeout = timeout or self.config.download_timeout timeout = timeout or self.config.download_timeout
if self._running.is_set(): if self._running.is_set():
return return
log.info("start downloader for lbry://%s#%s (sd hash %s...)", self.claim_name, self.claim_id, self.sd_hash[:6]) log.info("start downloader for stream (sd hash: %s)", self.sd_hash)
self._running.set() self._running.set()
try: try:
await asyncio.wait_for(self.downloader.start(node), timeout, loop=self.loop) await asyncio.wait_for(self.downloader.start(node), timeout, loop=self.loop)
@@ -286,13 +286,13 @@ class ManagedStream:
if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING: if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED) await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)
async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0)\ async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0, connection_id: int = 0)\
-> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]: -> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]:
if start_blob_num >= len(self.descriptor.blobs[:-1]): if start_blob_num >= len(self.descriptor.blobs[:-1]):
raise IndexError(start_blob_num) raise IndexError(start_blob_num)
for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]): for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]):
assert i + start_blob_num == blob_info.blob_num assert i + start_blob_num == blob_info.blob_num
decrypted = await self.downloader.read_blob(blob_info) decrypted = await self.downloader.read_blob(blob_info, connection_id)
yield (blob_info, decrypted) yield (blob_info, decrypted)
async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse: async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse:
@@ -309,7 +309,7 @@ class ManagedStream:
self.streaming.set() self.streaming.set()
try: try:
wrote = 0 wrote = 0
async for blob_info, decrypted in self._aiter_read_stream(skip_blobs): async for blob_info, decrypted in self._aiter_read_stream(skip_blobs, connection_id=2):
if (blob_info.blob_num == len(self.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size): if (blob_info.blob_num == len(self.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
decrypted += (b'\x00' * (size - len(decrypted) - wrote - (skip_blobs * 2097151))) decrypted += (b'\x00' * (size - len(decrypted) - wrote - (skip_blobs * 2097151)))
await response.write_eof(decrypted) await response.write_eof(decrypted)
@@ -336,7 +336,7 @@ class ManagedStream:
self.started_writing.clear() self.started_writing.clear()
try: try:
with open(output_path, 'wb') as file_write_handle: with open(output_path, 'wb') as file_write_handle:
async for blob_info, decrypted in self._aiter_read_stream(): async for blob_info, decrypted in self._aiter_read_stream(connection_id=1):
log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1) log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
file_write_handle.write(decrypted) file_write_handle.write(decrypted)
file_write_handle.flush() file_write_handle.flush()