Compare commits
3 commits
master
...
throttle-c
Author | SHA1 | Date | |
---|---|---|---|
|
a288adbfc6 | ||
|
8871a611f9 | ||
|
49c51267dc |
6 changed files with 73 additions and 26 deletions
|
@ -79,7 +79,7 @@ class AbstractBlob:
|
||||||
self.length = length
|
self.length = length
|
||||||
self.blob_completed_callback = blob_completed_callback
|
self.blob_completed_callback = blob_completed_callback
|
||||||
self.blob_directory = blob_directory
|
self.blob_directory = blob_directory
|
||||||
self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {}
|
self.writers: typing.Dict[typing.Optional[asyncio.Transport], HashBlobWriter] = {}
|
||||||
self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
|
self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
|
||||||
self.writing: asyncio.Event = asyncio.Event(loop=self.loop)
|
self.writing: asyncio.Event = asyncio.Event(loop=self.loop)
|
||||||
self.readers: typing.List[typing.BinaryIO] = []
|
self.readers: typing.List[typing.BinaryIO] = []
|
||||||
|
@ -111,6 +111,16 @@ class AbstractBlob:
|
||||||
def _write_blob(self, blob_bytes: bytes):
|
def _write_blob(self, blob_bytes: bytes):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def pause_other_writers(self, transport: asyncio.Transport):
|
||||||
|
for other in self.writers:
|
||||||
|
if other and other is not transport:
|
||||||
|
other.pause_reading()
|
||||||
|
|
||||||
|
def resume_other_writers(self, transport: asyncio.Transport):
|
||||||
|
for other in self.writers:
|
||||||
|
if other and other is not transport:
|
||||||
|
other.resume_reading()
|
||||||
|
|
||||||
def set_length(self, length) -> None:
|
def set_length(self, length) -> None:
|
||||||
if self.length is not None and length == self.length:
|
if self.length is not None and length == self.length:
|
||||||
return
|
return
|
||||||
|
@ -199,17 +209,20 @@ class AbstractBlob:
|
||||||
if self.blob_completed_callback:
|
if self.blob_completed_callback:
|
||||||
self.blob_completed_callback(self)
|
self.blob_completed_callback(self)
|
||||||
|
|
||||||
def get_blob_writer(self, peer_address: typing.Optional[str] = None,
|
def get_blob_writer(self, transport: typing.Optional[asyncio.Transport] = None) -> HashBlobWriter:
|
||||||
peer_port: typing.Optional[int] = None) -> HashBlobWriter:
|
if transport and transport in self.writers and not self.writers[transport].closed():
|
||||||
if (peer_address, peer_port) in self.writers and not self.writers[(peer_address, peer_port)].closed():
|
raise OSError(f"attempted to download blob twice from {transport.get_extra_info('peername')}")
|
||||||
raise OSError(f"attempted to download blob twice from {peer_address}:{peer_port}")
|
|
||||||
fut = asyncio.Future(loop=self.loop)
|
fut = asyncio.Future(loop=self.loop)
|
||||||
writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
|
writer = HashBlobWriter(
|
||||||
self.writers[(peer_address, peer_port)] = writer
|
self.blob_hash, self.get_length, fut,
|
||||||
|
lambda: self.pause_other_writers(transport),
|
||||||
|
lambda: self.resume_other_writers(transport)
|
||||||
|
)
|
||||||
|
self.writers[transport] = writer
|
||||||
|
|
||||||
def remove_writer(_):
|
def remove_writer(_):
|
||||||
if (peer_address, peer_port) in self.writers:
|
if transport in self.writers:
|
||||||
del self.writers[(peer_address, peer_port)]
|
del self.writers[transport]
|
||||||
|
|
||||||
fut.add_done_callback(remove_writer)
|
fut.add_done_callback(remove_writer)
|
||||||
|
|
||||||
|
@ -299,11 +312,10 @@ class BlobFile(AbstractBlob):
|
||||||
def is_writeable(self) -> bool:
|
def is_writeable(self) -> bool:
|
||||||
return super().is_writeable() and not os.path.isfile(self.file_path)
|
return super().is_writeable() and not os.path.isfile(self.file_path)
|
||||||
|
|
||||||
def get_blob_writer(self, peer_address: typing.Optional[str] = None,
|
def get_blob_writer(self, transport: typing.Optional[asyncio.Transport] = None) -> HashBlobWriter:
|
||||||
peer_port: typing.Optional[str] = None) -> HashBlobWriter:
|
|
||||||
if self.file_exists:
|
if self.file_exists:
|
||||||
raise OSError(f"File already exists '{self.file_path}'")
|
raise OSError(f"File already exists '{self.file_path}'")
|
||||||
return super().get_blob_writer(peer_address, peer_port)
|
return super().get_blob_writer(transport)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
|
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
|
||||||
|
|
|
@ -1,18 +1,23 @@
|
||||||
import typing
|
import typing
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from lbry.error import InvalidBlobHashError, InvalidDataError
|
from lbry.error import InvalidBlobHashError, InvalidDataError
|
||||||
from lbry.cryptoutils import get_lbry_hash_obj
|
from lbry.cryptoutils import get_lbry_hash_obj
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HashBlobWriter:
|
class HashBlobWriter:
|
||||||
def __init__(self, expected_blob_hash: str, get_length: typing.Callable[[], int],
|
def __init__(self, expected_blob_hash: str, get_length: typing.Callable[[], int],
|
||||||
finished: asyncio.Future):
|
finished: 'asyncio.Future[bytes]', pause_other_writers: typing.Callable[[], None],
|
||||||
|
resume_other_writers: typing.Callable[[], None]):
|
||||||
self.expected_blob_hash = expected_blob_hash
|
self.expected_blob_hash = expected_blob_hash
|
||||||
self.get_length = get_length
|
self.get_length = get_length
|
||||||
|
self.pause_other_writers = pause_other_writers
|
||||||
|
self.resume_other_writers = resume_other_writers
|
||||||
|
self.paused_others = False
|
||||||
self.buffer = BytesIO()
|
self.buffer = BytesIO()
|
||||||
self.finished = finished
|
self.finished = finished
|
||||||
self.finished.add_done_callback(lambda *_: self.close_handle())
|
self.finished.add_done_callback(lambda *_: self.close_handle())
|
||||||
|
@ -59,6 +64,9 @@ class HashBlobWriter:
|
||||||
elif self.finished and not (self.finished.done() or self.finished.cancelled()):
|
elif self.finished and not (self.finished.done() or self.finished.cancelled()):
|
||||||
self.finished.set_result(self.buffer.getvalue())
|
self.finished.set_result(self.buffer.getvalue())
|
||||||
self.close_handle()
|
self.close_handle()
|
||||||
|
if self.len_so_far >= 64000 and not self.paused_others:
|
||||||
|
self.paused_others = True
|
||||||
|
self.pause_other_writers()
|
||||||
|
|
||||||
def close_handle(self):
|
def close_handle(self):
|
||||||
if not self.finished.done():
|
if not self.finished.done():
|
||||||
|
@ -66,3 +74,5 @@ class HashBlobWriter:
|
||||||
if self.buffer is not None:
|
if self.buffer is not None:
|
||||||
self.buffer.close()
|
self.buffer.close()
|
||||||
self.buffer = None
|
self.buffer = None
|
||||||
|
if self.paused_others:
|
||||||
|
self.resume_other_writers()
|
||||||
|
|
|
@ -85,6 +85,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
||||||
self._blob_bytes_received += len(data)
|
self._blob_bytes_received += len(data)
|
||||||
try:
|
try:
|
||||||
self.writer.write(data)
|
self.writer.write(data)
|
||||||
|
|
||||||
except IOError as err:
|
except IOError as err:
|
||||||
log.error("error downloading blob from %s:%i: %s", self.peer_address, self.peer_port, err)
|
log.error("error downloading blob from %s:%i: %s", self.peer_address, self.peer_port, err)
|
||||||
if self._response_fut and not self._response_fut.done():
|
if self._response_fut and not self._response_fut.done():
|
||||||
|
@ -180,7 +181,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
||||||
return 0, self.transport
|
return 0, self.transport
|
||||||
try:
|
try:
|
||||||
self._blob_bytes_received = 0
|
self._blob_bytes_received = 0
|
||||||
self.blob, self.writer = blob, blob.get_blob_writer(self.peer_address, self.peer_port)
|
self.blob, self.writer = blob, blob.get_blob_writer(self.transport)
|
||||||
self._response_fut = asyncio.Future(loop=self.loop)
|
self._response_fut = asyncio.Future(loop=self.loop)
|
||||||
return await self._download_blob()
|
return await self._download_blob()
|
||||||
except OSError:
|
except OSError:
|
||||||
|
|
|
@ -72,7 +72,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
return
|
return
|
||||||
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
|
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
|
||||||
if not self.sd_blob.get_is_verified():
|
if not self.sd_blob.get_is_verified():
|
||||||
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
self.writer = self.sd_blob.get_blob_writer(self.transport)
|
||||||
self.incoming.set()
|
self.incoming.set()
|
||||||
self.send_response({"send_sd_blob": True})
|
self.send_response({"send_sd_blob": True})
|
||||||
try:
|
try:
|
||||||
|
@ -111,7 +111,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
return
|
return
|
||||||
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
|
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
|
||||||
if not blob.get_is_verified():
|
if not blob.get_is_verified():
|
||||||
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
self.writer = blob.get_blob_writer(self.transport)
|
||||||
self.incoming.set()
|
self.incoming.set()
|
||||||
self.send_response({"send_blob": True})
|
self.send_response({"send_blob": True})
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -10,6 +10,22 @@ from lbry.blob.blob_manager import BlobManager
|
||||||
from lbry.blob.blob_file import BlobFile, BlobBuffer, AbstractBlob
|
from lbry.blob.blob_file import BlobFile, BlobBuffer, AbstractBlob
|
||||||
|
|
||||||
|
|
||||||
|
class MockTransport(asyncio.Transport):
|
||||||
|
def __init__(self):
|
||||||
|
self.closed = asyncio.Event()
|
||||||
|
self.paused = asyncio.Event()
|
||||||
|
self._extra = {}
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.closed.set()
|
||||||
|
|
||||||
|
def pause_reading(self) -> None:
|
||||||
|
self.paused.set()
|
||||||
|
|
||||||
|
def resume_reading(self) -> None:
|
||||||
|
self.paused.clear()
|
||||||
|
|
||||||
|
|
||||||
class TestBlob(AsyncioTestCase):
|
class TestBlob(AsyncioTestCase):
|
||||||
blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed"
|
blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed"
|
||||||
blob_bytes = b'1' * ((2 * 2 ** 20) - 1)
|
blob_bytes = b'1' * ((2 * 2 ** 20) - 1)
|
||||||
|
@ -41,7 +57,7 @@ class TestBlob(AsyncioTestCase):
|
||||||
|
|
||||||
async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None):
|
async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None):
|
||||||
blob = self._get_blob(blob_class, blob_directory=blob_directory)
|
blob = self._get_blob(blob_class, blob_directory=blob_directory)
|
||||||
writers = [blob.get_blob_writer('1.2.3.4', port) for port in range(5)]
|
writers = [blob.get_blob_writer(MockTransport()) for _ in range(5)]
|
||||||
self.assertEqual(5, len(blob.writers))
|
self.assertEqual(5, len(blob.writers))
|
||||||
|
|
||||||
# test that writing too much causes the writer to fail with InvalidDataError and to be removed
|
# test that writing too much causes the writer to fail with InvalidDataError and to be removed
|
||||||
|
@ -137,14 +153,15 @@ class TestBlob(AsyncioTestCase):
|
||||||
blob_directory = tempfile.mkdtemp()
|
blob_directory = tempfile.mkdtemp()
|
||||||
self.addCleanup(lambda: shutil.rmtree(blob_directory))
|
self.addCleanup(lambda: shutil.rmtree(blob_directory))
|
||||||
blob = self._get_blob(BlobBuffer, blob_directory=blob_directory)
|
blob = self._get_blob(BlobBuffer, blob_directory=blob_directory)
|
||||||
writer = blob.get_blob_writer('1.1.1.1', 1337)
|
transport = MockTransport()
|
||||||
|
writer = blob.get_blob_writer(transport)
|
||||||
self.assertEqual(1, len(blob.writers))
|
self.assertEqual(1, len(blob.writers))
|
||||||
with self.assertRaises(OSError):
|
with self.assertRaises(OSError):
|
||||||
blob.get_blob_writer('1.1.1.1', 1337)
|
blob.get_blob_writer(transport)
|
||||||
writer.close_handle()
|
writer.close_handle()
|
||||||
self.assertTrue(blob.writers[('1.1.1.1', 1337)].closed())
|
self.assertTrue(blob.writers[(transport)].closed())
|
||||||
writer = blob.get_blob_writer('1.1.1.1', 1337)
|
writer = blob.get_blob_writer(transport)
|
||||||
self.assertEqual(blob.writers[('1.1.1.1', 1337)], writer)
|
self.assertEqual(blob.writers[transport], writer)
|
||||||
writer.close_handle()
|
writer.close_handle()
|
||||||
await asyncio.sleep(0.000000001) # flush callbacks
|
await asyncio.sleep(0.000000001) # flush callbacks
|
||||||
self.assertEqual(0, len(blob.writers))
|
self.assertEqual(0, len(blob.writers))
|
||||||
|
|
|
@ -135,12 +135,19 @@ class TestBlobExchange(BlobExchangeTestBase):
|
||||||
write_blob(blob_bytes)
|
write_blob(blob_bytes)
|
||||||
blob._write_blob = wrap_write_blob
|
blob._write_blob = wrap_write_blob
|
||||||
|
|
||||||
writer1 = blob.get_blob_writer(peer_port=1)
|
t1 = asyncio.Transport()
|
||||||
writer2 = blob.get_blob_writer(peer_port=2)
|
t2 = asyncio.Transport()
|
||||||
|
t1.pause_reading = lambda: None
|
||||||
|
t1.resume_reading = lambda: None
|
||||||
|
t2.pause_reading = lambda: None
|
||||||
|
t2.resume_reading = lambda: None
|
||||||
|
|
||||||
|
writer1 = blob.get_blob_writer(t1)
|
||||||
|
writer2 = blob.get_blob_writer(t2)
|
||||||
reader1_ctx_before_write = blob.reader_context()
|
reader1_ctx_before_write = blob.reader_context()
|
||||||
|
|
||||||
with self.assertRaises(OSError):
|
with self.assertRaises(OSError):
|
||||||
blob.get_blob_writer(peer_port=2)
|
blob.get_blob_writer(t2)
|
||||||
with self.assertRaises(OSError):
|
with self.assertRaises(OSError):
|
||||||
with blob.reader_context():
|
with blob.reader_context():
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Add table
Reference in a new issue