throttle other blob download attempts once the first gets to 64k bytes written

This commit is contained in:
Jack Robison 2019-08-12 13:50:26 -04:00
parent 49c51267dc
commit 8871a611f9
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
3 changed files with 29 additions and 2 deletions

View file

@ -111,6 +111,16 @@ class AbstractBlob:
def _write_blob(self, blob_bytes: bytes):
raise NotImplementedError()
def pause_other_writers(self, transport: asyncio.Transport):
for other in self.writers:
if other and other is not transport:
other.pause_reading()
def resume_other_writers(self, transport: asyncio.Transport):
for other in self.writers:
if other and other is not transport:
other.resume_reading()
def set_length(self, length) -> None:
if self.length is not None and length == self.length:
return
@ -203,7 +213,11 @@ class AbstractBlob:
if transport and transport in self.writers and not self.writers[transport].closed():
raise OSError(f"attempted to download blob twice from {transport.get_extra_info('peername')}")
fut = asyncio.Future(loop=self.loop)
writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
writer = HashBlobWriter(
self.blob_hash, self.get_length, fut,
lambda: self.pause_other_writers(transport),
lambda: self.resume_other_writers(transport)
)
self.writers[transport] = writer
def remove_writer(_):

View file

@ -11,9 +11,13 @@ log = logging.getLogger(__name__)
class HashBlobWriter:
def __init__(self, expected_blob_hash: str, get_length: typing.Callable[[], int],
finished: 'asyncio.Future[bytes]'):
finished: 'asyncio.Future[bytes]', pause_other_writers: typing.Callable[[], None],
resume_other_writers: typing.Callable[[], None]):
self.expected_blob_hash = expected_blob_hash
self.get_length = get_length
self.pause_other_writers = pause_other_writers
self.resume_other_writers = resume_other_writers
self.paused_others = False
self.buffer = BytesIO()
self.finished = finished
self.finished.add_done_callback(lambda *_: self.close_handle())
@ -60,6 +64,9 @@ class HashBlobWriter:
elif self.finished and not (self.finished.done() or self.finished.cancelled()):
self.finished.set_result(self.buffer.getvalue())
self.close_handle()
if self.len_so_far >= 64000 and not self.paused_others:
self.paused_others = True
self.pause_other_writers()
def close_handle(self):
if not self.finished.done():
@ -67,3 +74,5 @@ class HashBlobWriter:
if self.buffer is not None:
self.buffer.close()
self.buffer = None
if self.paused_others:
self.resume_other_writers()

View file

@ -137,6 +137,10 @@ class TestBlobExchange(BlobExchangeTestBase):
t1 = asyncio.Transport()
t2 = asyncio.Transport()
t1.pause_reading = lambda: None
t1.resume_reading = lambda: None
t2.pause_reading = lambda: None
t2.resume_reading = lambda: None
writer1 = blob.get_blob_writer(t1)
writer2 = blob.get_blob_writer(t2)