update get_blob_writer to accept an asyncio.Transport instead of an address/port
This commit is contained in:
parent
cb84abfb88
commit
49c51267dc
5 changed files with 22 additions and 19 deletions
|
@ -79,7 +79,7 @@ class AbstractBlob:
|
|||
self.length = length
|
||||
self.blob_completed_callback = blob_completed_callback
|
||||
self.blob_directory = blob_directory
|
||||
self.writers: typing.Dict[typing.Tuple[typing.Optional[str], typing.Optional[int]], HashBlobWriter] = {}
|
||||
self.writers: typing.Dict[typing.Optional[asyncio.Transport], HashBlobWriter] = {}
|
||||
self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
|
||||
self.writing: asyncio.Event = asyncio.Event(loop=self.loop)
|
||||
self.readers: typing.List[typing.BinaryIO] = []
|
||||
|
@ -199,17 +199,16 @@ class AbstractBlob:
|
|||
if self.blob_completed_callback:
|
||||
self.blob_completed_callback(self)
|
||||
|
||||
def get_blob_writer(self, peer_address: typing.Optional[str] = None,
|
||||
peer_port: typing.Optional[int] = None) -> HashBlobWriter:
|
||||
if (peer_address, peer_port) in self.writers and not self.writers[(peer_address, peer_port)].closed():
|
||||
raise OSError(f"attempted to download blob twice from {peer_address}:{peer_port}")
|
||||
def get_blob_writer(self, transport: typing.Optional[asyncio.Transport] = None) -> HashBlobWriter:
|
||||
if transport and transport in self.writers and not self.writers[transport].closed():
|
||||
raise OSError(f"attempted to download blob twice from {transport.get_extra_info('peername')}")
|
||||
fut = asyncio.Future(loop=self.loop)
|
||||
writer = HashBlobWriter(self.blob_hash, self.get_length, fut)
|
||||
self.writers[(peer_address, peer_port)] = writer
|
||||
self.writers[transport] = writer
|
||||
|
||||
def remove_writer(_):
|
||||
if (peer_address, peer_port) in self.writers:
|
||||
del self.writers[(peer_address, peer_port)]
|
||||
if transport in self.writers:
|
||||
del self.writers[transport]
|
||||
|
||||
fut.add_done_callback(remove_writer)
|
||||
|
||||
|
@ -299,11 +298,10 @@ class BlobFile(AbstractBlob):
|
|||
def is_writeable(self) -> bool:
|
||||
return super().is_writeable() and not os.path.isfile(self.file_path)
|
||||
|
||||
def get_blob_writer(self, peer_address: typing.Optional[str] = None,
|
||||
peer_port: typing.Optional[str] = None) -> HashBlobWriter:
|
||||
def get_blob_writer(self, transport: typing.Optional[asyncio.Transport] = None) -> HashBlobWriter:
|
||||
if self.file_exists:
|
||||
raise OSError(f"File already exists '{self.file_path}'")
|
||||
return super().get_blob_writer(peer_address, peer_port)
|
||||
return super().get_blob_writer(transport)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _reader_context(self) -> typing.ContextManager[typing.BinaryIO]:
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
import typing
|
||||
import logging
|
||||
import asyncio
|
||||
from io import BytesIO
|
||||
from lbry.error import InvalidBlobHashError, InvalidDataError
|
||||
from lbry.cryptoutils import get_lbry_hash_obj
|
||||
if typing.TYPE_CHECKING:
|
||||
import asyncio
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HashBlobWriter:
|
||||
def __init__(self, expected_blob_hash: str, get_length: typing.Callable[[], int],
|
||||
finished: asyncio.Future):
|
||||
finished: 'asyncio.Future[bytes]'):
|
||||
self.expected_blob_hash = expected_blob_hash
|
||||
self.get_length = get_length
|
||||
self.buffer = BytesIO()
|
||||
|
|
|
@ -85,6 +85,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
|||
self._blob_bytes_received += len(data)
|
||||
try:
|
||||
self.writer.write(data)
|
||||
|
||||
except IOError as err:
|
||||
log.error("error downloading blob from %s:%i: %s", self.peer_address, self.peer_port, err)
|
||||
if self._response_fut and not self._response_fut.done():
|
||||
|
@ -180,7 +181,7 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
|
|||
return 0, self.transport
|
||||
try:
|
||||
self._blob_bytes_received = 0
|
||||
self.blob, self.writer = blob, blob.get_blob_writer(self.peer_address, self.peer_port)
|
||||
self.blob, self.writer = blob, blob.get_blob_writer(self.transport)
|
||||
self._response_fut = asyncio.Future(loop=self.loop)
|
||||
return await self._download_blob()
|
||||
except OSError:
|
||||
|
|
|
@ -72,7 +72,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
return
|
||||
self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size'])
|
||||
if not self.sd_blob.get_is_verified():
|
||||
self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
||||
self.writer = self.sd_blob.get_blob_writer(self.transport)
|
||||
self.incoming.set()
|
||||
self.send_response({"send_sd_blob": True})
|
||||
try:
|
||||
|
@ -111,7 +111,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
return
|
||||
blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size'])
|
||||
if not blob.get_is_verified():
|
||||
self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername'))
|
||||
self.writer = blob.get_blob_writer(self.transport)
|
||||
self.incoming.set()
|
||||
self.send_response({"send_blob": True})
|
||||
try:
|
||||
|
|
|
@ -135,12 +135,15 @@ class TestBlobExchange(BlobExchangeTestBase):
|
|||
write_blob(blob_bytes)
|
||||
blob._write_blob = wrap_write_blob
|
||||
|
||||
writer1 = blob.get_blob_writer(peer_port=1)
|
||||
writer2 = blob.get_blob_writer(peer_port=2)
|
||||
t1 = asyncio.Transport()
|
||||
t2 = asyncio.Transport()
|
||||
|
||||
writer1 = blob.get_blob_writer(t1)
|
||||
writer2 = blob.get_blob_writer(t2)
|
||||
reader1_ctx_before_write = blob.reader_context()
|
||||
|
||||
with self.assertRaises(OSError):
|
||||
blob.get_blob_writer(peer_port=2)
|
||||
blob.get_blob_writer(t2)
|
||||
with self.assertRaises(OSError):
|
||||
with blob.reader_context():
|
||||
pass
|
||||
|
|
Loading…
Add table
Reference in a new issue