forked from LBRYCommunity/lbry-sdk
add connection id workaround
This commit is contained in:
parent
b2f63a1545
commit
24e073680b
4 changed files with 19 additions and 18 deletions
|
@ -189,7 +189,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
|||
@cache_concurrent
|
||||
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'AbstractBlob', address: str, tcp_port: int,
|
||||
peer_connect_timeout: float, blob_download_timeout: float,
|
||||
connected_transport: asyncio.Transport = None)\
|
||||
connected_transport: asyncio.Transport = None, connection_id: int = 0)\
|
||||
-> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
|
||||
"""
|
||||
Returns [<downloaded blob>, <keep connection>]
|
||||
|
|
|
@ -33,7 +33,7 @@ class BlobDownloader:
|
|||
return False
|
||||
return not (blob.get_is_verified() or not blob.is_writeable())
|
||||
|
||||
async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer'):
|
||||
async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer', connection_id: int = 0):
|
||||
if blob.get_is_verified():
|
||||
return
|
||||
self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones
|
||||
|
@ -41,7 +41,7 @@ class BlobDownloader:
|
|||
start = self.loop.time()
|
||||
bytes_received, transport = await request_blob(
|
||||
self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout,
|
||||
self.config.blob_download_timeout, connected_transport=transport
|
||||
self.config.blob_download_timeout, connected_transport=transport, connection_id=connection_id
|
||||
)
|
||||
if not transport and peer not in self.ignored:
|
||||
self.ignored[peer] = self.loop.time()
|
||||
|
@ -74,7 +74,8 @@ class BlobDownloader:
|
|||
))
|
||||
|
||||
@cache_concurrent
|
||||
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob':
|
||||
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None,
|
||||
connection_id: int = 0) -> 'AbstractBlob':
|
||||
blob = self.blob_manager.get_blob(blob_hash, length)
|
||||
if blob.get_is_verified():
|
||||
return blob
|
||||
|
@ -94,7 +95,7 @@ class BlobDownloader:
|
|||
break
|
||||
if peer not in self.active_connections and peer not in self.ignored:
|
||||
log.debug("request %s from %s:%i", blob_hash[:8], peer.address, peer.tcp_port)
|
||||
t = self.loop.create_task(self.request_blob_from_peer(blob, peer))
|
||||
t = self.loop.create_task(self.request_blob_from_peer(blob, peer, connection_id))
|
||||
self.active_connections[peer] = t
|
||||
await self.new_peer_or_finished()
|
||||
self.cleanup_active()
|
||||
|
|
|
@ -58,14 +58,14 @@ class StreamDownloader:
|
|||
|
||||
self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _delayed_add_fixed_peers)
|
||||
|
||||
async def load_descriptor(self):
|
||||
async def load_descriptor(self, connection_id: int = 0):
|
||||
# download or get the sd blob
|
||||
sd_blob = self.blob_manager.get_blob(self.sd_hash)
|
||||
if not sd_blob.get_is_verified():
|
||||
try:
|
||||
now = self.loop.time()
|
||||
sd_blob = await asyncio.wait_for(
|
||||
self.blob_downloader.download_blob(self.sd_hash),
|
||||
self.blob_downloader.download_blob(self.sd_hash, connection_id),
|
||||
self.config.blob_download_timeout, loop=self.loop
|
||||
)
|
||||
log.info("downloaded sd blob %s", self.sd_hash)
|
||||
|
@ -79,7 +79,7 @@ class StreamDownloader:
|
|||
)
|
||||
log.info("loaded stream manifest %s", self.sd_hash)
|
||||
|
||||
async def start(self, node: typing.Optional['Node'] = None):
|
||||
async def start(self, node: typing.Optional['Node'] = None, connection_id: int = 0):
|
||||
# set up peer accumulation
|
||||
if node:
|
||||
self.node = node
|
||||
|
@ -90,7 +90,7 @@ class StreamDownloader:
|
|||
log.info("searching for peers for stream %s", self.sd_hash)
|
||||
|
||||
if not self.descriptor:
|
||||
await self.load_descriptor()
|
||||
await self.load_descriptor(connection_id)
|
||||
|
||||
# add the head blob to the peer search
|
||||
self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash)
|
||||
|
@ -101,10 +101,10 @@ class StreamDownloader:
|
|||
self.blob_manager.get_blob(self.sd_hash, length=self.descriptor.length), self.descriptor
|
||||
)
|
||||
|
||||
async def download_stream_blob(self, blob_info: 'BlobInfo') -> 'AbstractBlob':
|
||||
async def download_stream_blob(self, blob_info: 'BlobInfo', connection_id: int = 0) -> 'AbstractBlob':
|
||||
if not filter(lambda blob: blob.blob_hash == blob_info.blob_hash, self.descriptor.blobs[:-1]):
|
||||
raise ValueError(f"blob {blob_info.blob_hash} is not part of stream with sd hash {self.sd_hash}")
|
||||
blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length)
|
||||
blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length, connection_id)
|
||||
return blob
|
||||
|
||||
def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes:
|
||||
|
@ -112,11 +112,11 @@ class StreamDownloader:
|
|||
binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode())
|
||||
)
|
||||
|
||||
async def read_blob(self, blob_info: 'BlobInfo') -> bytes:
|
||||
async def read_blob(self, blob_info: 'BlobInfo', connection_id: int = 0) -> bytes:
|
||||
start = None
|
||||
if self.time_to_first_bytes is None:
|
||||
start = self.loop.time()
|
||||
blob = await self.download_stream_blob(blob_info)
|
||||
blob = await self.download_stream_blob(blob_info, connection_id)
|
||||
decrypted = self.decrypt_blob(blob_info, blob)
|
||||
if start:
|
||||
self.time_to_first_bytes = self.loop.time() - start
|
||||
|
|
|
@ -255,7 +255,7 @@ class ManagedStream:
|
|||
timeout = timeout or self.config.download_timeout
|
||||
if self._running.is_set():
|
||||
return
|
||||
log.info("start downloader for lbry://%s#%s (sd hash %s...)", self.claim_name, self.claim_id, self.sd_hash[:6])
|
||||
log.info("start downloader for stream (sd hash: %s)", self.sd_hash)
|
||||
self._running.set()
|
||||
try:
|
||||
await asyncio.wait_for(self.downloader.start(node), timeout, loop=self.loop)
|
||||
|
@ -286,13 +286,13 @@ class ManagedStream:
|
|||
if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
|
||||
await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)
|
||||
|
||||
async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0)\
|
||||
async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0, connection_id: int = 0)\
|
||||
-> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]:
|
||||
if start_blob_num >= len(self.descriptor.blobs[:-1]):
|
||||
raise IndexError(start_blob_num)
|
||||
for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]):
|
||||
assert i + start_blob_num == blob_info.blob_num
|
||||
decrypted = await self.downloader.read_blob(blob_info)
|
||||
decrypted = await self.downloader.read_blob(blob_info, connection_id)
|
||||
yield (blob_info, decrypted)
|
||||
|
||||
async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse:
|
||||
|
@ -309,7 +309,7 @@ class ManagedStream:
|
|||
self.streaming.set()
|
||||
try:
|
||||
wrote = 0
|
||||
async for blob_info, decrypted in self._aiter_read_stream(skip_blobs):
|
||||
async for blob_info, decrypted in self._aiter_read_stream(skip_blobs, connection_id=2):
|
||||
if (blob_info.blob_num == len(self.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
|
||||
decrypted += (b'\x00' * (size - len(decrypted) - wrote - (skip_blobs * 2097151)))
|
||||
await response.write_eof(decrypted)
|
||||
|
@ -336,7 +336,7 @@ class ManagedStream:
|
|||
self.started_writing.clear()
|
||||
try:
|
||||
with open(output_path, 'wb') as file_write_handle:
|
||||
async for blob_info, decrypted in self._aiter_read_stream():
|
||||
async for blob_info, decrypted in self._aiter_read_stream(connection_id=1):
|
||||
log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
|
||||
file_write_handle.write(decrypted)
|
||||
file_write_handle.flush()
|
||||
|
|
Loading…
Add table
Reference in a new issue