add peer address/port to hash blob writer
This commit is contained in:
parent
432fe444f8
commit
3234d70270
4 changed files with 22 additions and 16 deletions
|
@ -79,7 +79,7 @@ class AbstractBlob:
|
||||||
self.length = length
|
self.length = length
|
||||||
self.blob_completed_callback = blob_completed_callback
|
self.blob_completed_callback = blob_completed_callback
|
||||||
self.blob_directory = blob_directory
|
self.blob_directory = blob_directory
|
||||||
self.writers: typing.List[HashBlobWriter] = []
|
self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {}
|
||||||
self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
|
self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
|
||||||
self.writing: asyncio.Event = asyncio.Event(loop=self.loop)
|
self.writing: asyncio.Event = asyncio.Event(loop=self.loop)
|
||||||
self.readers: typing.List[typing.BinaryIO] = []
|
self.readers: typing.List[typing.BinaryIO] = []
|
||||||
|
@ -99,7 +99,7 @@ class AbstractBlob:
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
|
def reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
|
||||||
if not self.is_readable():
|
if not self.is_readable():
|
||||||
raise OSError("not readable")
|
raise OSError(f"{str(type(self))} not readable, {len(self.readers)} readers {len(self.writers)} writers")
|
||||||
with self._reader_context() as reader:
|
with self._reader_context() as reader:
|
||||||
try:
|
try:
|
||||||
self.readers.append(reader)
|
self.readers.append(reader)
|
||||||
|
@ -142,7 +142,8 @@ class AbstractBlob:
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
while self.writers:
|
while self.writers:
|
||||||
self.writers.pop().finished.cancel()
|
peer, writer = self.writers.popitem()
|
||||||
|
writer.finished.cancel()
|
||||||
while self.readers:
|
while self.readers:
|
||||||
reader = self.readers.pop()
|
reader = self.readers.pop()
|
||||||
if reader:
|
if reader:
|
||||||
|
@ -198,10 +199,13 @@ class AbstractBlob:
|
||||||
else:
|
else:
|
||||||
self.verified.set()
|
self.verified.set()
|
||||||
|
|
||||||
def get_blob_writer(self) -> HashBlobWriter:
|
def get_blob_writer(self, peer_address: typing.Optional[str] = None,
|
||||||
|
peer_port: typing.Optional[int] = None) -> HashBlobWriter:
|
||||||
|
if (peer_address, peer_port) in self.writers:
|
||||||
|
log.exception("attempted to download blob twice from %s:%s", peer_address, peer_port)
|
||||||
fut = asyncio.Future(loop=self.loop)
|
fut = asyncio.Future(loop=self.loop)
|
||||||
writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
|
writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
|
||||||
self.writers.append(writer)
|
self.writers[(peer_address, peer_port)] = writer
|
||||||
|
|
||||||
def writer_finished_callback(finished: asyncio.Future):
|
def writer_finished_callback(finished: asyncio.Future):
|
||||||
try:
|
try:
|
||||||
|
@ -210,18 +214,18 @@ class AbstractBlob:
|
||||||
raise err
|
raise err
|
||||||
verified_bytes = finished.result()
|
verified_bytes = finished.result()
|
||||||
while self.writers:
|
while self.writers:
|
||||||
other = self.writers.pop()
|
_, other = self.writers.popitem()
|
||||||
if other is not writer:
|
if other is not writer:
|
||||||
other.finished.cancel()
|
other.finished.cancel()
|
||||||
self.save_verified_blob(verified_bytes)
|
self.save_verified_blob(verified_bytes)
|
||||||
return
|
return
|
||||||
except (InvalidBlobHashError, InvalidDataError) as error:
|
except (InvalidBlobHashError, InvalidDataError) as error:
|
||||||
log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error))
|
log.warning("writer error downloading %s: %s", self.blob_hash[:8], str(error))
|
||||||
except (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError):
|
except (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
if writer in self.writers:
|
if (peer_address, peer_port) in self.writers:
|
||||||
self.writers.remove(writer)
|
self.writers.pop((peer_address, peer_port))
|
||||||
fut.add_done_callback(writer_finished_callback)
|
fut.add_done_callback(writer_finished_callback)
|
||||||
return writer
|
return writer
|
||||||
|
|
||||||
|
@ -292,10 +296,11 @@ class BlobFile(AbstractBlob):
|
||||||
def is_writeable(self) -> bool:
|
def is_writeable(self) -> bool:
|
||||||
return super().is_writeable() and not os.path.isfile(self.file_path)
|
return super().is_writeable() and not os.path.isfile(self.file_path)
|
||||||
|
|
||||||
def get_blob_writer(self) -> HashBlobWriter:
|
def get_blob_writer(self, peer_address: typing.Optional[str] = None,
|
||||||
|
peer_port: typing.Optional[str] = None) -> HashBlobWriter:
|
||||||
if self.file_exists:
|
if self.file_exists:
|
||||||
raise OSError(f"File already exists '{self.file_path}'")
|
raise OSError(f"File already exists '{self.file_path}'")
|
||||||
return super().get_blob_writer()
|
return super().get_blob_writer(peer_address, peer_port)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
|
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
|
||||||
|
|
|
@ -148,8 +148,9 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
||||||
if blob.get_is_verified() or not blob.is_writeable():
|
if blob.get_is_verified() or not blob.is_writeable():
|
||||||
return 0, self.transport
|
return 0, self.transport
|
||||||
try:
|
try:
|
||||||
blob.get_blob_writer()
|
|
||||||
self.blob, self.writer, self._blob_bytes_received = blob, blob.get_blob_writer(), 0
|
self.blob, self.writer, self._blob_bytes_received = blob, blob.get_blob_writer(self.peer_address,
|
||||||
|
self.peer_port), 0
|
||||||
self._response_fut = asyncio.Future(loop=self.loop)
|
self._response_fut = asyncio.Future(loop=self.loop)
|
||||||
return await self._download_blob()
|
return await self._download_blob()
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
|
|
|
@ -63,7 +63,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
return
|
return
|
||||||
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
|
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
|
||||||
if not self.sd_blob.get_is_verified():
|
if not self.sd_blob.get_is_verified():
|
||||||
self.writer = self.sd_blob.get_blob_writer()
|
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
||||||
self.incoming.set()
|
self.incoming.set()
|
||||||
self.send_response({"send_sd_blob": True})
|
self.send_response({"send_sd_blob": True})
|
||||||
try:
|
try:
|
||||||
|
@ -102,7 +102,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
return
|
return
|
||||||
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
|
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
|
||||||
if not blob.get_is_verified():
|
if not blob.get_is_verified():
|
||||||
self.writer = blob.get_blob_writer()
|
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
||||||
self.incoming.set()
|
self.incoming.set()
|
||||||
self.send_response({"send_blob": True})
|
self.send_response({"send_blob": True})
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -41,7 +41,7 @@ class TestBlob(AsyncioTestCase):
|
||||||
|
|
||||||
async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None):
|
async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None):
|
||||||
blob = self._get_blob(blob_class, blob_directory=blob_directory)
|
blob = self._get_blob(blob_class, blob_directory=blob_directory)
|
||||||
writers = [blob.get_blob_writer() for _ in range(5)]
|
writers = [blob.get_blob_writer('1.2.3.4', port) for port in range(5)]
|
||||||
self.assertEqual(5, len(blob.writers))
|
self.assertEqual(5, len(blob.writers))
|
||||||
|
|
||||||
# test that writing too much causes the writer to fail with InvalidDataError and to be removed
|
# test that writing too much causes the writer to fail with InvalidDataError and to be removed
|
||||||
|
|
Loading…
Reference in a new issue