update get_blob_writer to accept an asyncio.Transport instead of an address/port

This commit is contained in:
Jack Robison 2019-08-12 12:53:25 -04:00
parent cb84abfb88
commit 49c51267dc
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
5 changed files with 22 additions and 19 deletions

View file

@ -79,7 +79,7 @@ class AbstractBlob:
self.length = length self.length = length
self.blob_completed_callback = blob_completed_callback self.blob_completed_callback = blob_completed_callback
self.blob_directory = blob_directory self.blob_directory = blob_directory
self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {} self.writers: typing.Dict[typing.Optional[asyncio.Transport], HashBlobWriter] = {}
self.verified: asyncio.Event = asyncio.Event(loop=self.loop) self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
self.writing: asyncio.Event = asyncio.Event(loop=self.loop) self.writing: asyncio.Event = asyncio.Event(loop=self.loop)
self.readers: typing.List[typing.BinaryIO] = [] self.readers: typing.List[typing.BinaryIO] = []
@ -199,17 +199,16 @@ class AbstractBlob:
if self.blob_completed_callback: if self.blob_completed_callback:
self.blob_completed_callback(self) self.blob_completed_callback(self)
def get_blob_writer(self, peer_address: typing.Optional[str] = None, def get_blob_writer(self, transport: typing.Optional[asyncio.Transport] = None) -> HashBlobWriter:
peer_port: typing.Optional[int] = None) -> HashBlobWriter: if transport and transport in self.writers and not self.writers[transport].closed():
if (peer_address, peer_port) in self.writers and not self.writers[(peer_address, peer_port)].closed(): raise OSError(f"attempted to download blob twice from {transport.get_extra_info('peername')}")
raise OSError(f"attempted to download blob twice from {peer_address}:{peer_port}")
fut = asyncio.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
writer = HashBlobWriter(self.blob_hash, self.get_length, fut) writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
self.writers[(peer_address, peer_port)] = writer self.writers[transport] = writer
def remove_writer(_): def remove_writer(_):
if (peer_address, peer_port) in self.writers: if transport in self.writers:
del self.writers[(peer_address, peer_port)] del self.writers[transport]
fut.add_done_callback(remove_writer) fut.add_done_callback(remove_writer)
@ -299,11 +298,10 @@ class BlobFile(AbstractBlob):
def is_writeable(self) -> bool: def is_writeable(self) -> bool:
return super().is_writeable() and not os.path.isfile(self.file_path) return super().is_writeable() and not os.path.isfile(self.file_path)
def get_blob_writer(self, peer_address: typing.Optional[str] = None, def get_blob_writer(self, transport: typing.Optional[asyncio.Transport] = None) -> HashBlobWriter:
peer_port: typing.Optional[str] = None) -> HashBlobWriter:
if self.file_exists: if self.file_exists:
raise OSError(f"File already exists '{self.file_path}'") raise OSError(f"File already exists '{self.file_path}'")
return super().get_blob_writer(peer_address, peer_port) return super().get_blob_writer(transport)
@contextlib.contextmanager @contextlib.contextmanager
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]: def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:

View file

@ -1,16 +1,17 @@
import typing import typing
import logging import logging
import asyncio
from io import BytesIO from io import BytesIO
from lbry.error import InvalidBlobHashError, InvalidDataError from lbry.error import InvalidBlobHashError, InvalidDataError
from lbry.cryptoutils import get_lbry_hash_obj from lbry.cryptoutils import get_lbry_hash_obj
if typing.TYPE_CHECKING:
import asyncio
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class HashBlobWriter: class HashBlobWriter:
def __init__(self, expected_blob_hash: str, get_length: typing.Callable[[], int], def __init__(self, expected_blob_hash: str, get_length: typing.Callable[[], int],
finished: asyncio.Future): finished: 'asyncio.Future[bytes]'):
self.expected_blob_hash = expected_blob_hash self.expected_blob_hash = expected_blob_hash
self.get_length = get_length self.get_length = get_length
self.buffer = BytesIO() self.buffer = BytesIO()

View file

@ -85,6 +85,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
self._blob_bytes_received += len(data) self._blob_bytes_received += len(data)
try: try:
self.writer.write(data) self.writer.write(data)
except IOError as err: except IOError as err:
log.error("error downloading blob from %s:%i: %s", self.peer_address, self.peer_port, err) log.error("error downloading blob from %s:%i: %s", self.peer_address, self.peer_port, err)
if self._response_fut and not self._response_fut.done(): if self._response_fut and not self._response_fut.done():
@ -180,7 +181,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
return 0, self.transport return 0, self.transport
try: try:
self._blob_bytes_received = 0 self._blob_bytes_received = 0
self.blob, self.writer = blob, blob.get_blob_writer(self.peer_address, self.peer_port) self.blob, self.writer = blob, blob.get_blob_writer(self.transport)
self._response_fut = asyncio.Future(loop=self.loop) self._response_fut = asyncio.Future(loop=self.loop)
return await self._download_blob() return await self._download_blob()
except OSError: except OSError:

View file

@ -72,7 +72,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
return return
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
if not self.sd_blob.get_is_verified(): if not self.sd_blob.get_is_verified():
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername')) self.writer = self.sd_blob.get_blob_writer(self.transport)
self.incoming.set() self.incoming.set()
self.send_response({"send_sd_blob": True}) self.send_response({"send_sd_blob": True})
try: try:
@ -111,7 +111,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
return return
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
if not blob.get_is_verified(): if not blob.get_is_verified():
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername')) self.writer = blob.get_blob_writer(self.transport)
self.incoming.set() self.incoming.set()
self.send_response({"send_blob": True}) self.send_response({"send_blob": True})
try: try:

View file

@ -135,12 +135,15 @@ class TestBlobExchange(BlobExchangeTestBase):
write_blob(blob_bytes) write_blob(blob_bytes)
blob._write_blob = wrap_write_blob blob._write_blob = wrap_write_blob
writer1 = blob.get_blob_writer(peer_port=1) t1 = asyncio.Transport()
writer2 = blob.get_blob_writer(peer_port=2) t2 = asyncio.Transport()
writer1 = blob.get_blob_writer(t1)
writer2 = blob.get_blob_writer(t2)
reader1_ctx_before_write = blob.reader_context() reader1_ctx_before_write = blob.reader_context()
with self.assertRaises(OSError): with self.assertRaises(OSError):
blob.get_blob_writer(peer_port=2) blob.get_blob_writer(t2)
with self.assertRaises(OSError): with self.assertRaises(OSError):
with blob.reader_context(): with blob.reader_context():
pass pass