add connection id workaround
This commit is contained in:
parent
b2f63a1545
commit
24e073680b
4 changed files with 19 additions and 18 deletions
|
@ -189,7 +189,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
||||||
@cache_concurrent
|
@cache_concurrent
|
||||||
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'AbstractBlob', address: str, tcp_port: int,
|
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'AbstractBlob', address: str, tcp_port: int,
|
||||||
peer_connect_timeout: float, blob_download_timeout: float,
|
peer_connect_timeout: float, blob_download_timeout: float,
|
||||||
connected_transport: asyncio.Transport = None)\
|
connected_transport: asyncio.Transport = None, connection_id: int = 0)\
|
||||||
-> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
|
-> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
|
||||||
"""
|
"""
|
||||||
Returns [<downloaded blob>, <keep connection>]
|
Returns [<downloaded blob>, <keep connection>]
|
||||||
|
|
|
@ -33,7 +33,7 @@ class BlobDownloader:
|
||||||
return False
|
return False
|
||||||
return not (blob.get_is_verified() or not blob.is_writeable())
|
return not (blob.get_is_verified() or not blob.is_writeable())
|
||||||
|
|
||||||
async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer'):
|
async def request_blob_from_peer(self, blob: 'AbstractBlob', peer: 'KademliaPeer', connection_id: int = 0):
|
||||||
if blob.get_is_verified():
|
if blob.get_is_verified():
|
||||||
return
|
return
|
||||||
self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones
|
self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones
|
||||||
|
@ -41,7 +41,7 @@ class BlobDownloader:
|
||||||
start = self.loop.time()
|
start = self.loop.time()
|
||||||
bytes_received, transport = await request_blob(
|
bytes_received, transport = await request_blob(
|
||||||
self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout,
|
self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout,
|
||||||
self.config.blob_download_timeout, connected_transport=transport
|
self.config.blob_download_timeout, connected_transport=transport, connection_id=connection_id
|
||||||
)
|
)
|
||||||
if not transport and peer not in self.ignored:
|
if not transport and peer not in self.ignored:
|
||||||
self.ignored[peer] = self.loop.time()
|
self.ignored[peer] = self.loop.time()
|
||||||
|
@ -74,7 +74,8 @@ class BlobDownloader:
|
||||||
))
|
))
|
||||||
|
|
||||||
@cache_concurrent
|
@cache_concurrent
|
||||||
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'AbstractBlob':
|
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None,
|
||||||
|
connection_id: int = 0) -> 'AbstractBlob':
|
||||||
blob = self.blob_manager.get_blob(blob_hash, length)
|
blob = self.blob_manager.get_blob(blob_hash, length)
|
||||||
if blob.get_is_verified():
|
if blob.get_is_verified():
|
||||||
return blob
|
return blob
|
||||||
|
@ -94,7 +95,7 @@ class BlobDownloader:
|
||||||
break
|
break
|
||||||
if peer not in self.active_connections and peer not in self.ignored:
|
if peer not in self.active_connections and peer not in self.ignored:
|
||||||
log.debug("request %s from %s:%i", blob_hash[:8], peer.address, peer.tcp_port)
|
log.debug("request %s from %s:%i", blob_hash[:8], peer.address, peer.tcp_port)
|
||||||
t = self.loop.create_task(self.request_blob_from_peer(blob, peer))
|
t = self.loop.create_task(self.request_blob_from_peer(blob, peer, connection_id))
|
||||||
self.active_connections[peer] = t
|
self.active_connections[peer] = t
|
||||||
await self.new_peer_or_finished()
|
await self.new_peer_or_finished()
|
||||||
self.cleanup_active()
|
self.cleanup_active()
|
||||||
|
|
|
@ -58,14 +58,14 @@ class StreamDownloader:
|
||||||
|
|
||||||
self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _delayed_add_fixed_peers)
|
self.fixed_peers_handle = self.loop.call_later(self.fixed_peers_delay, _delayed_add_fixed_peers)
|
||||||
|
|
||||||
async def load_descriptor(self):
|
async def load_descriptor(self, connection_id: int = 0):
|
||||||
# download or get the sd blob
|
# download or get the sd blob
|
||||||
sd_blob = self.blob_manager.get_blob(self.sd_hash)
|
sd_blob = self.blob_manager.get_blob(self.sd_hash)
|
||||||
if not sd_blob.get_is_verified():
|
if not sd_blob.get_is_verified():
|
||||||
try:
|
try:
|
||||||
now = self.loop.time()
|
now = self.loop.time()
|
||||||
sd_blob = await asyncio.wait_for(
|
sd_blob = await asyncio.wait_for(
|
||||||
self.blob_downloader.download_blob(self.sd_hash),
|
self.blob_downloader.download_blob(self.sd_hash, connection_id),
|
||||||
self.config.blob_download_timeout, loop=self.loop
|
self.config.blob_download_timeout, loop=self.loop
|
||||||
)
|
)
|
||||||
log.info("downloaded sd blob %s", self.sd_hash)
|
log.info("downloaded sd blob %s", self.sd_hash)
|
||||||
|
@ -79,7 +79,7 @@ class StreamDownloader:
|
||||||
)
|
)
|
||||||
log.info("loaded stream manifest %s", self.sd_hash)
|
log.info("loaded stream manifest %s", self.sd_hash)
|
||||||
|
|
||||||
async def start(self, node: typing.Optional['Node'] = None):
|
async def start(self, node: typing.Optional['Node'] = None, connection_id: int = 0):
|
||||||
# set up peer accumulation
|
# set up peer accumulation
|
||||||
if node:
|
if node:
|
||||||
self.node = node
|
self.node = node
|
||||||
|
@ -90,7 +90,7 @@ class StreamDownloader:
|
||||||
log.info("searching for peers for stream %s", self.sd_hash)
|
log.info("searching for peers for stream %s", self.sd_hash)
|
||||||
|
|
||||||
if not self.descriptor:
|
if not self.descriptor:
|
||||||
await self.load_descriptor()
|
await self.load_descriptor(connection_id)
|
||||||
|
|
||||||
# add the head blob to the peer search
|
# add the head blob to the peer search
|
||||||
self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash)
|
self.search_queue.put_nowait(self.descriptor.blobs[0].blob_hash)
|
||||||
|
@ -101,10 +101,10 @@ class StreamDownloader:
|
||||||
self.blob_manager.get_blob(self.sd_hash, length=self.descriptor.length), self.descriptor
|
self.blob_manager.get_blob(self.sd_hash, length=self.descriptor.length), self.descriptor
|
||||||
)
|
)
|
||||||
|
|
||||||
async def download_stream_blob(self, blob_info: 'BlobInfo') -> 'AbstractBlob':
|
async def download_stream_blob(self, blob_info: 'BlobInfo', connection_id: int = 0) -> 'AbstractBlob':
|
||||||
if not filter(lambda blob: blob.blob_hash == blob_info.blob_hash, self.descriptor.blobs[:-1]):
|
if not filter(lambda blob: blob.blob_hash == blob_info.blob_hash, self.descriptor.blobs[:-1]):
|
||||||
raise ValueError(f"blob {blob_info.blob_hash} is not part of stream with sd hash {self.sd_hash}")
|
raise ValueError(f"blob {blob_info.blob_hash} is not part of stream with sd hash {self.sd_hash}")
|
||||||
blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length)
|
blob = await self.blob_downloader.download_blob(blob_info.blob_hash, blob_info.length, connection_id)
|
||||||
return blob
|
return blob
|
||||||
|
|
||||||
def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes:
|
def decrypt_blob(self, blob_info: 'BlobInfo', blob: 'AbstractBlob') -> bytes:
|
||||||
|
@ -112,11 +112,11 @@ class StreamDownloader:
|
||||||
binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode())
|
binascii.unhexlify(self.descriptor.key.encode()), binascii.unhexlify(blob_info.iv.encode())
|
||||||
)
|
)
|
||||||
|
|
||||||
async def read_blob(self, blob_info: 'BlobInfo') -> bytes:
|
async def read_blob(self, blob_info: 'BlobInfo', connection_id: int = 0) -> bytes:
|
||||||
start = None
|
start = None
|
||||||
if self.time_to_first_bytes is None:
|
if self.time_to_first_bytes is None:
|
||||||
start = self.loop.time()
|
start = self.loop.time()
|
||||||
blob = await self.download_stream_blob(blob_info)
|
blob = await self.download_stream_blob(blob_info, connection_id)
|
||||||
decrypted = self.decrypt_blob(blob_info, blob)
|
decrypted = self.decrypt_blob(blob_info, blob)
|
||||||
if start:
|
if start:
|
||||||
self.time_to_first_bytes = self.loop.time() - start
|
self.time_to_first_bytes = self.loop.time() - start
|
||||||
|
|
|
@ -255,7 +255,7 @@ class ManagedStream:
|
||||||
timeout = timeout or self.config.download_timeout
|
timeout = timeout or self.config.download_timeout
|
||||||
if self._running.is_set():
|
if self._running.is_set():
|
||||||
return
|
return
|
||||||
log.info("start downloader for lbry://%s#%s (sd hash %s...)", self.claim_name, self.claim_id, self.sd_hash[:6])
|
log.info("start downloader for stream (sd hash: %s)", self.sd_hash)
|
||||||
self._running.set()
|
self._running.set()
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.downloader.start(node), timeout, loop=self.loop)
|
await asyncio.wait_for(self.downloader.start(node), timeout, loop=self.loop)
|
||||||
|
@ -286,13 +286,13 @@ class ManagedStream:
|
||||||
if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
|
if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING:
|
||||||
await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)
|
await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED)
|
||||||
|
|
||||||
async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0)\
|
async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0, connection_id: int = 0)\
|
||||||
-> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]:
|
-> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]:
|
||||||
if start_blob_num >= len(self.descriptor.blobs[:-1]):
|
if start_blob_num >= len(self.descriptor.blobs[:-1]):
|
||||||
raise IndexError(start_blob_num)
|
raise IndexError(start_blob_num)
|
||||||
for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]):
|
for i, blob_info in enumerate(self.descriptor.blobs[start_blob_num:-1]):
|
||||||
assert i + start_blob_num == blob_info.blob_num
|
assert i + start_blob_num == blob_info.blob_num
|
||||||
decrypted = await self.downloader.read_blob(blob_info)
|
decrypted = await self.downloader.read_blob(blob_info, connection_id)
|
||||||
yield (blob_info, decrypted)
|
yield (blob_info, decrypted)
|
||||||
|
|
||||||
async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse:
|
async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse:
|
||||||
|
@ -309,7 +309,7 @@ class ManagedStream:
|
||||||
self.streaming.set()
|
self.streaming.set()
|
||||||
try:
|
try:
|
||||||
wrote = 0
|
wrote = 0
|
||||||
async for blob_info, decrypted in self._aiter_read_stream(skip_blobs):
|
async for blob_info, decrypted in self._aiter_read_stream(skip_blobs, connection_id=2):
|
||||||
if (blob_info.blob_num == len(self.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
|
if (blob_info.blob_num == len(self.descriptor.blobs) - 2) or (len(decrypted) + wrote >= size):
|
||||||
decrypted += (b'\x00' * (size - len(decrypted) - wrote - (skip_blobs * 2097151)))
|
decrypted += (b'\x00' * (size - len(decrypted) - wrote - (skip_blobs * 2097151)))
|
||||||
await response.write_eof(decrypted)
|
await response.write_eof(decrypted)
|
||||||
|
@ -336,7 +336,7 @@ class ManagedStream:
|
||||||
self.started_writing.clear()
|
self.started_writing.clear()
|
||||||
try:
|
try:
|
||||||
with open(output_path, 'wb') as file_write_handle:
|
with open(output_path, 'wb') as file_write_handle:
|
||||||
async for blob_info, decrypted in self._aiter_read_stream():
|
async for blob_info, decrypted in self._aiter_read_stream(connection_id=1):
|
||||||
log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
|
log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1)
|
||||||
file_write_handle.write(decrypted)
|
file_write_handle.write(decrypted)
|
||||||
file_write_handle.flush()
|
file_write_handle.flush()
|
||||||
|
|
Loading…
Reference in a new issue