throttle other blob download attempts once the first gets to 64k bytes written
This commit is contained in:
parent
49c51267dc
commit
8871a611f9
3 changed files with 29 additions and 2 deletions
|
@ -111,6 +111,16 @@ class AbstractBlob:
|
|||
def _write_blob(self, blob_bytes: bytes):
|
||||
raise NotImplementedError()
|
||||
|
||||
def pause_other_writers(self, transport: asyncio.Transport):
|
||||
for other in self.writers:
|
||||
if other and other is not transport:
|
||||
other.pause_reading()
|
||||
|
||||
def resume_other_writers(self, transport: asyncio.Transport):
|
||||
for other in self.writers:
|
||||
if other and other is not transport:
|
||||
other.resume_reading()
|
||||
|
||||
def set_length(self, length) -> None:
|
||||
if self.length is not None and length == self.length:
|
||||
return
|
||||
|
@ -203,7 +213,11 @@ class AbstractBlob:
|
|||
if transport and transport in self.writers and not self.writers[transport].closed():
|
||||
raise OSError(f"attempted to download blob twice from {transport.get_extra_info('peername')}")
|
||||
fut = asyncio.Future(loop=self.loop)
|
||||
writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
|
||||
writer = HashBlobWriter(
|
||||
self.blob_hash, self.get_length, fut,
|
||||
lambda: self.pause_other_writers(transport),
|
||||
lambda: self.resume_other_writers(transport)
|
||||
)
|
||||
self.writers[transport] = writer
|
||||
|
||||
def remove_writer(_):
|
||||
|
|
|
@ -11,9 +11,13 @@ log = logging.getLogger(__name__)
|
|||
|
||||
class HashBlobWriter:
|
||||
def __init__(self, expected_blob_hash: str, get_length: typing.Callable[[], int],
|
||||
finished: 'asyncio.Future[bytes]'):
|
||||
finished: 'asyncio.Future[bytes]', pause_other_writers: typing.Callable[[], None],
|
||||
resume_other_writers: typing.Callable[[], None]):
|
||||
self.expected_blob_hash = expected_blob_hash
|
||||
self.get_length = get_length
|
||||
self.pause_other_writers = pause_other_writers
|
||||
self.resume_other_writers = resume_other_writers
|
||||
self.paused_others = False
|
||||
self.buffer = BytesIO()
|
||||
self.finished = finished
|
||||
self.finished.add_done_callback(lambda *_: self.close_handle())
|
||||
|
@ -60,6 +64,9 @@ class HashBlobWriter:
|
|||
elif self.finished and not (self.finished.done() or self.finished.cancelled()):
|
||||
self.finished.set_result(self.buffer.getvalue())
|
||||
self.close_handle()
|
||||
if self.len_so_far >= 64000 and not self.paused_others:
|
||||
self.paused_others = True
|
||||
self.pause_other_writers()
|
||||
|
||||
def close_handle(self):
|
||||
if not self.finished.done():
|
||||
|
@ -67,3 +74,5 @@ class HashBlobWriter:
|
|||
if self.buffer is not None:
|
||||
self.buffer.close()
|
||||
self.buffer = None
|
||||
if self.paused_others:
|
||||
self.resume_other_writers()
|
||||
|
|
|
@ -137,6 +137,10 @@ class TestBlobExchange(BlobExchangeTestBase):
|
|||
|
||||
t1 = asyncio.Transport()
|
||||
t2 = asyncio.Transport()
|
||||
t1.pause_reading = lambda: None
|
||||
t1.resume_reading = lambda: None
|
||||
t2.pause_reading = lambda: None
|
||||
t2.resume_reading = lambda: None
|
||||
|
||||
writer1 = blob.get_blob_writer(t1)
|
||||
writer2 = blob.get_blob_writer(t2)
|
||||
|
|
Loading…
Add table
Reference in a new issue