Compare commits

...
Sign in to create a new pull request.

3 commits

Author SHA1 Message Date
Jack Robison
a288adbfc6
fix test 2019-08-12 16:29:22 -04:00
Jack Robison
8871a611f9
throttle other blob download attempts once the first gets to 64k bytes written 2019-08-12 15:19:42 -04:00
Jack Robison
49c51267dc
update get_blob_writer to accept an asyncio.Transport instead of an address/port 2019-08-12 15:19:38 -04:00
6 changed files with 73 additions and 26 deletions

View file

@ -79,7 +79,7 @@ class AbstractBlob:
self.length = length self.length = length
self.blob_completed_callback = blob_completed_callback self.blob_completed_callback = blob_completed_callback
self.blob_directory = blob_directory self.blob_directory = blob_directory
self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {} self.writers: typing.Dict[typing.Optional[asyncio.Transport], HashBlobWriter] = {}
self.verified: asyncio.Event = asyncio.Event(loop=self.loop) self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
self.writing: asyncio.Event = asyncio.Event(loop=self.loop) self.writing: asyncio.Event = asyncio.Event(loop=self.loop)
self.readers: typing.List[typing.BinaryIO] = [] self.readers: typing.List[typing.BinaryIO] = []
@ -111,6 +111,16 @@ class AbstractBlob:
def _write_blob(self, blob_bytes: bytes): def _write_blob(self, blob_bytes: bytes):
raise NotImplementedError() raise NotImplementedError()
def pause_other_writers(self, transport: asyncio.Transport):
for other in self.writers:
if other and other is not transport:
other.pause_reading()
def resume_other_writers(self, transport: asyncio.Transport):
for other in self.writers:
if other and other is not transport:
other.resume_reading()
def set_length(self, length) -> None: def set_length(self, length) -> None:
if self.length is not None and length == self.length: if self.length is not None and length == self.length:
return return
@ -199,17 +209,20 @@ class AbstractBlob:
if self.blob_completed_callback: if self.blob_completed_callback:
self.blob_completed_callback(self) self.blob_completed_callback(self)
def get_blob_writer(self, peer_address: typing.Optional[str] = None, def get_blob_writer(self, transport: typing.Optional[asyncio.Transport] = None) -> HashBlobWriter:
peer_port: typing.Optional[int] = None) -> HashBlobWriter: if transport and transport in self.writers and not self.writers[transport].closed():
if (peer_address, peer_port) in self.writers and not self.writers[(peer_address, peer_port)].closed(): raise OSError(f"attempted to download blob twice from {transport.get_extra_info('peername')}")
raise OSError(f"attempted to download blob twice from {peer_address}:{peer_port}")
fut = asyncio.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
writer = HashBlobWriter(self.blob_hash, self.get_length, fut) writer = HashBlobWriter(
self.writers[(peer_address, peer_port)] = writer self.blob_hash, self.get_length, fut,
lambda: self.pause_other_writers(transport),
lambda: self.resume_other_writers(transport)
)
self.writers[transport] = writer
def remove_writer(_): def remove_writer(_):
if (peer_address, peer_port) in self.writers: if transport in self.writers:
del self.writers[(peer_address, peer_port)] del self.writers[transport]
fut.add_done_callback(remove_writer) fut.add_done_callback(remove_writer)
@ -299,11 +312,10 @@ class BlobFile(AbstractBlob):
def is_writeable(self) -> bool: def is_writeable(self) -> bool:
return super().is_writeable() and not os.path.isfile(self.file_path) return super().is_writeable() and not os.path.isfile(self.file_path)
def get_blob_writer(self, peer_address: typing.Optional[str] = None, def get_blob_writer(self, transport: typing.Optional[asyncio.Transport] = None) -> HashBlobWriter:
peer_port: typing.Optional[str] = None) -> HashBlobWriter:
if self.file_exists: if self.file_exists:
raise OSError(f"File already exists '{self.file_path}'") raise OSError(f"File already exists '{self.file_path}'")
return super().get_blob_writer(peer_address, peer_port) return super().get_blob_writer(transport)
@contextlib.contextmanager @contextlib.contextmanager
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]: def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:

View file

@ -1,18 +1,23 @@
import typing import typing
import logging import logging
import asyncio
from io import BytesIO from io import BytesIO
from lbry.error import InvalidBlobHashError, InvalidDataError from lbry.error import InvalidBlobHashError, InvalidDataError
from lbry.cryptoutils import get_lbry_hash_obj from lbry.cryptoutils import get_lbry_hash_obj
if typing.TYPE_CHECKING:
import asyncio
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class HashBlobWriter: class HashBlobWriter:
def __init__(self, expected_blob_hash: str, get_length: typing.Callable[[], int], def __init__(self, expected_blob_hash: str, get_length: typing.Callable[[], int],
finished: asyncio.Future): finished: 'asyncio.Future[bytes]', pause_other_writers: typing.Callable[[], None],
resume_other_writers: typing.Callable[[], None]):
self.expected_blob_hash = expected_blob_hash self.expected_blob_hash = expected_blob_hash
self.get_length = get_length self.get_length = get_length
self.pause_other_writers = pause_other_writers
self.resume_other_writers = resume_other_writers
self.paused_others = False
self.buffer = BytesIO() self.buffer = BytesIO()
self.finished = finished self.finished = finished
self.finished.add_done_callback(lambda *_: self.close_handle()) self.finished.add_done_callback(lambda *_: self.close_handle())
@ -59,6 +64,9 @@ class HashBlobWriter:
elif self.finished and not (self.finished.done() or self.finished.cancelled()): elif self.finished and not (self.finished.done() or self.finished.cancelled()):
self.finished.set_result(self.buffer.getvalue()) self.finished.set_result(self.buffer.getvalue())
self.close_handle() self.close_handle()
if self.len_so_far >= 64000 and not self.paused_others:
self.paused_others = True
self.pause_other_writers()
def close_handle(self): def close_handle(self):
if not self.finished.done(): if not self.finished.done():
@ -66,3 +74,5 @@ class HashBlobWriter:
if self.buffer is not None: if self.buffer is not None:
self.buffer.close() self.buffer.close()
self.buffer = None self.buffer = None
if self.paused_others:
self.resume_other_writers()

View file

@ -85,6 +85,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
self._blob_bytes_received += len(data) self._blob_bytes_received += len(data)
try: try:
self.writer.write(data) self.writer.write(data)
except IOError as err: except IOError as err:
log.error("error downloading blob from %s:%i: %s", self.peer_address, self.peer_port, err) log.error("error downloading blob from %s:%i: %s", self.peer_address, self.peer_port, err)
if self._response_fut and not self._response_fut.done(): if self._response_fut and not self._response_fut.done():
@ -180,7 +181,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
return 0, self.transport return 0, self.transport
try: try:
self._blob_bytes_received = 0 self._blob_bytes_received = 0
self.blob, self.writer = blob, blob.get_blob_writer(self.peer_address, self.peer_port) self.blob, self.writer = blob, blob.get_blob_writer(self.transport)
self._response_fut = asyncio.Future(loop=self.loop) self._response_fut = asyncio.Future(loop=self.loop)
return await self._download_blob() return await self._download_blob()
except OSError: except OSError:

View file

@ -72,7 +72,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
return return
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
if not self.sd_blob.get_is_verified(): if not self.sd_blob.get_is_verified():
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername')) self.writer = self.sd_blob.get_blob_writer(self.transport)
self.incoming.set() self.incoming.set()
self.send_response({"send_sd_blob": True}) self.send_response({"send_sd_blob": True})
try: try:
@ -111,7 +111,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
return return
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
if not blob.get_is_verified(): if not blob.get_is_verified():
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername')) self.writer = blob.get_blob_writer(self.transport)
self.incoming.set() self.incoming.set()
self.send_response({"send_blob": True}) self.send_response({"send_blob": True})
try: try:

View file

@ -10,6 +10,22 @@ from lbry.blob.blob_manager import BlobManager
from lbry.blob.blob_file import BlobFile, BlobBuffer, AbstractBlob from lbry.blob.blob_file import BlobFile, BlobBuffer, AbstractBlob
class MockTransport(asyncio.Transport):
def __init__(self):
self.closed = asyncio.Event()
self.paused = asyncio.Event()
self._extra = {}
def close(self):
self.closed.set()
def pause_reading(self) -> None:
self.paused.set()
def resume_reading(self) -> None:
self.paused.clear()
class TestBlob(AsyncioTestCase): class TestBlob(AsyncioTestCase):
blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed" blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed"
blob_bytes = b'1' * ((2 * 2 ** 20) - 1) blob_bytes = b'1' * ((2 * 2 ** 20) - 1)
@ -41,7 +57,7 @@ class TestBlob(AsyncioTestCase):
async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None): async def _test_close_writers_on_finished(self, blob_class=AbstractBlob, blob_directory=None):
blob = self._get_blob(blob_class, blob_directory=blob_directory) blob = self._get_blob(blob_class, blob_directory=blob_directory)
writers = [blob.get_blob_writer('1.2.3.4', port) for port in range(5)] writers = [blob.get_blob_writer(MockTransport()) for _ in range(5)]
self.assertEqual(5, len(blob.writers)) self.assertEqual(5, len(blob.writers))
# test that writing too much causes the writer to fail with InvalidDataError and to be removed # test that writing too much causes the writer to fail with InvalidDataError and to be removed
@ -137,14 +153,15 @@ class TestBlob(AsyncioTestCase):
blob_directory = tempfile.mkdtemp() blob_directory = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(blob_directory)) self.addCleanup(lambda: shutil.rmtree(blob_directory))
blob = self._get_blob(BlobBuffer, blob_directory=blob_directory) blob = self._get_blob(BlobBuffer, blob_directory=blob_directory)
writer = blob.get_blob_writer('1.1.1.1', 1337) transport = MockTransport()
writer = blob.get_blob_writer(transport)
self.assertEqual(1, len(blob.writers)) self.assertEqual(1, len(blob.writers))
with self.assertRaises(OSError): with self.assertRaises(OSError):
blob.get_blob_writer('1.1.1.1', 1337) blob.get_blob_writer(transport)
writer.close_handle() writer.close_handle()
self.assertTrue(blob.writers[('1.1.1.1', 1337)].closed()) self.assertTrue(blob.writers[(transport)].closed())
writer = blob.get_blob_writer('1.1.1.1', 1337) writer = blob.get_blob_writer(transport)
self.assertEqual(blob.writers[('1.1.1.1', 1337)], writer) self.assertEqual(blob.writers[transport], writer)
writer.close_handle() writer.close_handle()
await asyncio.sleep(0.000000001) # flush callbacks await asyncio.sleep(0.000000001) # flush callbacks
self.assertEqual(0, len(blob.writers)) self.assertEqual(0, len(blob.writers))

View file

@ -135,12 +135,19 @@ class TestBlobExchange(BlobExchangeTestBase):
write_blob(blob_bytes) write_blob(blob_bytes)
blob._write_blob = wrap_write_blob blob._write_blob = wrap_write_blob
writer1 = blob.get_blob_writer(peer_port=1) t1 = asyncio.Transport()
writer2 = blob.get_blob_writer(peer_port=2) t2 = asyncio.Transport()
t1.pause_reading = lambda: None
t1.resume_reading = lambda: None
t2.pause_reading = lambda: None
t2.resume_reading = lambda: None
writer1 = blob.get_blob_writer(t1)
writer2 = blob.get_blob_writer(t2)
reader1_ctx_before_write = blob.reader_context() reader1_ctx_before_write = blob.reader_context()
with self.assertRaises(OSError): with self.assertRaises(OSError):
blob.get_blob_writer(peer_port=2) blob.get_blob_writer(t2)
with self.assertRaises(OSError): with self.assertRaises(OSError):
with blob.reader_context(): with blob.reader_context():
pass pass