add peer address/port to hash blob writer
This commit is contained in:
parent
432fe444f8
commit
3234d70270
4 changed files with 22 additions and 16 deletions
|
@ -79,7 +79,7 @@ class AbstractBlob:
|
|||
self.length = length
|
||||
self.blob_completed_callback = blob_completed_callback
|
||||
self.blob_directory = blob_directory
|
||||
self.writers: typing.List[HashBlobWriter] = []
|
||||
self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {}
|
||||
self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
|
||||
self.writing: asyncio.Event = asyncio.Event(loop=self.loop)
|
||||
self.readers: typing.List[typing.BinaryIO] = []
|
||||
|
@ -99,7 +99,7 @@ class AbstractBlob:
|
|||
@contextlib.contextmanager
|
||||
def reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
|
||||
if not self.is_readable():
|
||||
raise OSError("not readable")
|
||||
raise OSError(f"{str(type(self))} not readable, {len(self.readers)} readers {len(self.writers)} writers")
|
||||
with self._reader_context() as reader:
|
||||
try:
|
||||
self.readers.append(reader)
|
||||
|
@ -142,7 +142,8 @@ class AbstractBlob:
|
|||
|
||||
def close(self):
|
||||
while self.writers:
|
||||
self.writers.pop().finished.cancel()
|
||||
peer, writer = self.writers.popitem()
|
||||
writer.finished.cancel()
|
||||
while self.readers:
|
||||
reader = self.readers.pop()
|
||||
if reader:
|
||||
|
@ -198,10 +199,13 @@ class AbstractBlob:
|
|||
else:
|
||||
self.verified.set()
|
||||
|
||||
def get_blob_writer(self) -> HashBlobWriter:
|
||||
def get_blob_writer(self, peer_address: typing.Optional[str] = None,
|
||||
peer_port: typing.Optional[int] = None) -> HashBlobWriter:
|
||||
if (peer_address, peer_port) in self.writers:
|
||||
log.exception("attempted to download blob twice from %s:%s", peer_address, peer_port)
|
||||
fut = asyncio.Future(loop=self.loop)
|
||||
writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
|
||||
self.writers.append(writer)
|
||||
self.writers[(peer_address, peer_port)] = writer
|
||||
|
||||
def writer_finished_callback(finished: asyncio.Future):
|
||||
try:
|
||||
|
@ -210,18 +214,18 @@ class AbstractBlob:
|
|||
raise err
|
||||
verified_bytes = finished.result()
|
||||
while self.writers:
|
||||
other = self.writers.pop()
|
||||
_, other = self.writers.popitem()
|
||||
if other is not writer:
|
||||
other.finished.cancel()
|
||||
self.save_verified_blob(verified_bytes)
|
||||
return
|
||||
except (InvalidBlobHashError, InvalidDataError) as error:
|
||||
log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
|
||||
log.warning("writer error downloading %s: %s", self.blob_hash[:8], str(error))
|
||||
except (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError):
|
||||
pass
|
||||
finally:
|
||||
if writer in self.writers:
|
||||
self.writers.remove(writer)
|
||||
if (peer_address, peer_port) in self.writers:
|
||||
self.writers.pop((peer_address, peer_port))
|
||||
fut.add_done_callback(writer_finished_callback)
|
||||
return writer
|
||||
|
||||
|
@ -292,10 +296,11 @@ class BlobFile(AbstractBlob):
|
|||
def is_writeable(self) -> bool:
|
||||
return super().is_writeable() and not os.path.isfile(self.file_path)
|
||||
|
||||
def get_blob_writer(self) -> HashBlobWriter:
|
||||
def get_blob_writer(self, peer_address: typing.Optional[str] = None,
|
||||
peer_port: typing.Optional[str] = None) -> HashBlobWriter:
|
||||
if self.file_exists:
|
||||
raise OSError(f"File already exists '{self.file_path}'")
|
||||
return super().get_blob_writer()
|
||||
return super().get_blob_writer(peer_address, peer_port)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
|
||||
|
|
|
@ -148,8 +148,9 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
|||
if blob.get_is_verified() or not blob.is_writeable():
|
||||
return 0, self.transport
|
||||
try:
|
||||
blob.get_blob_writer()
|
||||
self.blob, self.writer, self._blob_bytes_received = blob, blob.get_blob_writer(), 0
|
||||
|
||||
self.blob, self.writer, self._blob_bytes_received = blob, blob.get_blob_writer(self.peer_address,
|
||||
self.peer_port), 0
|
||||
self._response_fut = asyncio.Future(loop=self.loop)
|
||||
return await self._download_blob()
|
||||
except OSError as e:
|
||||
|
|
|
@ -63,7 +63,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
return
|
||||
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
|
||||
if not self.sd_blob.get_is_verified():
|
||||
self.writer = self.sd_blob.get_blob_writer()
|
||||
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
||||
self.incoming.set()
|
||||
self.send_response({"send_sd_blob": True})
|
||||
try:
|
||||
|
@ -102,7 +102,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
return
|
||||
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
|
||||
if not blob.get_is_verified():
|
||||
self.writer = blob.get_blob_writer()
|
||||
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
||||
self.incoming.set()
|
||||
self.send_response({"send_blob": True})
|
||||
try:
|
||||
|
|
|
@ -41,7 +41,7 @@ class TestBlob(AsyncioTestCase):
|
|||
|
||||
async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None):
|
||||
blob = self._get_blob(blob_class, blob_directory=blob_directory)
|
||||
writers = [blob.get_blob_writer() for _ in range(5)]
|
||||
writers = [blob.get_blob_writer('1.2.3.4', port) for port in range(5)]
|
||||
self.assertEqual(5, len(blob.writers))
|
||||
|
||||
# test that writing too much causes the writer to fail with InvalidDataError and to be removed
|
||||
|
|
Loading…
Reference in a new issue