add peer address/port to hash blob writer

This commit is contained in:
Jack Robison 2019-04-23 09:52:20 -04:00
parent 432fe444f8
commit 3234d70270
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
4 changed files with 22 additions and 16 deletions

View file

@ -79,7 +79,7 @@ class AbstractBlob:
self.length = length self.length = length
self.blob_completed_callback = blob_completed_callback self.blob_completed_callback = blob_completed_callback
self.blob_directory = blob_directory self.blob_directory = blob_directory
self.writers: typing.List[HashBlobWriter] = [] self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {}
self.verified: asyncio.Event = asyncio.Event(loop=self.loop) self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
self.writing: asyncio.Event = asyncio.Event(loop=self.loop) self.writing: asyncio.Event = asyncio.Event(loop=self.loop)
self.readers: typing.List[typing.BinaryIO] = [] self.readers: typing.List[typing.BinaryIO] = []
@ -99,7 +99,7 @@ class AbstractBlob:
@contextlib.contextmanager @contextlib.contextmanager
def reader_context(self) -> typing.ContextManager[typing.BinaryIO]: def reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
if not self.is_readable(): if not self.is_readable():
raise OSError("not readable") raise OSError(f"{str(type(self))} not readable, {len(self.readers)} readers {len(self.writers)} writers")
with self._reader_context() as reader: with self._reader_context() as reader:
try: try:
self.readers.append(reader) self.readers.append(reader)
@ -142,7 +142,8 @@ class AbstractBlob:
def close(self): def close(self):
while self.writers: while self.writers:
self.writers.pop().finished.cancel() peer, writer = self.writers.popitem()
writer.finished.cancel()
while self.readers: while self.readers:
reader = self.readers.pop() reader = self.readers.pop()
if reader: if reader:
@ -198,10 +199,13 @@ class AbstractBlob:
else: else:
self.verified.set() self.verified.set()
def get_blob_writer(self) -> HashBlobWriter: def get_blob_writer(self, peer_address: typing.Optional[str] = None,
peer_port: typing.Optional[int] = None) -> HashBlobWriter:
if (peer_address, peer_port) in self.writers:
log.exception("attempted to download blob twice from %s:%s", peer_address, peer_port)
fut = asyncio.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
writer = HashBlobWriter(self.blob_hash, self.get_length, fut) writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
self.writers.append(writer) self.writers[(peer_address, peer_port)] = writer
def writer_finished_callback(finished: asyncio.Future): def writer_finished_callback(finished: asyncio.Future):
try: try:
@ -210,18 +214,18 @@ class AbstractBlob:
raise err raise err
verified_bytes = finished.result() verified_bytes = finished.result()
while self.writers: while self.writers:
other = self.writers.pop() _, other = self.writers.popitem()
if other is not writer: if other is not writer:
other.finished.cancel() other.finished.cancel()
self.save_verified_blob(verified_bytes) self.save_verified_blob(verified_bytes)
return return
except (InvalidBlobHashError, InvalidDataError) as error: except (InvalidBlobHashError, InvalidDataError) as error:
log.debug("writer error downloading %s: %s", self.blob_hash[:8], str(error)) log.warning("writer error downloading %s: %s", self.blob_hash[:8], str(error))
except (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError): except (DownloadCancelledError, asyncio.CancelledError, asyncio.TimeoutError):
pass pass
finally: finally:
if writer in self.writers: if (peer_address, peer_port) in self.writers:
self.writers.remove(writer) self.writers.pop((peer_address, peer_port))
fut.add_done_callback(writer_finished_callback) fut.add_done_callback(writer_finished_callback)
return writer return writer
@ -292,10 +296,11 @@ class BlobFile(AbstractBlob):
def is_writeable(self) -> bool: def is_writeable(self) -> bool:
return super().is_writeable() and not os.path.isfile(self.file_path) return super().is_writeable() and not os.path.isfile(self.file_path)
def get_blob_writer(self) -> HashBlobWriter: def get_blob_writer(self, peer_address: typing.Optional[str] = None,
peer_port: typing.Optional[str] = None) -> HashBlobWriter:
if self.file_exists: if self.file_exists:
raise OSError(f"File already exists '{self.file_path}'") raise OSError(f"File already exists '{self.file_path}'")
return super().get_blob_writer() return super().get_blob_writer(peer_address, peer_port)
@contextlib.contextmanager @contextlib.contextmanager
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]: def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:

View file

@ -148,8 +148,9 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
if blob.get_is_verified() or not blob.is_writeable(): if blob.get_is_verified() or not blob.is_writeable():
return 0, self.transport return 0, self.transport
try: try:
blob.get_blob_writer()
self.blob, self.writer, self._blob_bytes_received = blob, blob.get_blob_writer(), 0 self.blob, self.writer, self._blob_bytes_received = blob, blob.get_blob_writer(self.peer_address,
self.peer_port), 0
self._response_fut = asyncio.Future(loop=self.loop) self._response_fut = asyncio.Future(loop=self.loop)
return await self._download_blob() return await self._download_blob()
except OSError as e: except OSError as e:

View file

@ -63,7 +63,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
return return
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
if not self.sd_blob.get_is_verified(): if not self.sd_blob.get_is_verified():
self.writer = self.sd_blob.get_blob_writer() self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
self.incoming.set() self.incoming.set()
self.send_response({"send_sd_blob": True}) self.send_response({"send_sd_blob": True})
try: try:
@ -102,7 +102,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
return return
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
if not blob.get_is_verified(): if not blob.get_is_verified():
self.writer = blob.get_blob_writer() self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
self.incoming.set() self.incoming.set()
self.send_response({"send_blob": True}) self.send_response({"send_blob": True})
try: try:

View file

@ -41,7 +41,7 @@ class TestBlob(AsyncioTestCase):
async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None): async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None):
blob = self._get_blob(blob_class, blob_directory=blob_directory) blob = self._get_blob(blob_class, blob_directory=blob_directory)
writers = [blob.get_blob_writer() for _ in range(5)] writers = [blob.get_blob_writer('1.2.3.4', port) for port in range(5)]
self.assertEqual(5, len(blob.writers)) self.assertEqual(5, len(blob.writers))
# test that writing too much causes the writer to fail with InvalidDataError and to be removed # test that writing too much causes the writer to fail with InvalidDataError and to be removed