Merge pull request #1889 from lbryio/download_stuc

Refactor peer scoring and maintain connections to fix stuck downloads, remove locking from ping queue
This commit is contained in:
Jack Robison 2019-02-08 15:10:13 -05:00 committed by GitHub
commit 869a8b712b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 170 additions and 128 deletions

View file

@ -66,7 +66,7 @@ class BlobFile:
self.verified: asyncio.Event = asyncio.Event(loop=self.loop) self.verified: asyncio.Event = asyncio.Event(loop=self.loop)
self.finished_writing = asyncio.Event(loop=loop) self.finished_writing = asyncio.Event(loop=loop)
self.blob_write_lock = asyncio.Lock(loop=loop) self.blob_write_lock = asyncio.Lock(loop=loop)
if os.path.isfile(os.path.join(blob_dir, blob_hash)): if self.file_exists:
length = int(os.stat(os.path.join(blob_dir, blob_hash)).st_size) length = int(os.stat(os.path.join(blob_dir, blob_hash)).st_size)
self.length = length self.length = length
self.verified.set() self.verified.set()
@ -74,6 +74,10 @@ class BlobFile:
self.saved_verified_blob = False self.saved_verified_blob = False
self.blob_completed_callback = blob_completed_callback self.blob_completed_callback = blob_completed_callback
@property
def file_exists(self):
return os.path.isfile(self.file_path)
def writer_finished(self, writer: HashBlobWriter): def writer_finished(self, writer: HashBlobWriter):
def callback(finished: asyncio.Future): def callback(finished: asyncio.Future):
try: try:
@ -116,7 +120,7 @@ class BlobFile:
self.verified.set() self.verified.set()
def open_for_writing(self) -> HashBlobWriter: def open_for_writing(self) -> HashBlobWriter:
if os.path.exists(self.file_path): if self.file_exists:
raise OSError(f"File already exists '{self.file_path}'") raise OSError(f"File already exists '{self.file_path}'")
fut = asyncio.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
writer = HashBlobWriter(self.blob_hash, self.get_length, fut) writer = HashBlobWriter(self.blob_hash, self.get_length, fut)

View file

@ -32,6 +32,8 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
if self._response_fut and not self._response_fut.done(): if self._response_fut and not self._response_fut.done():
self._response_fut.cancel() self._response_fut.cancel()
return return
if self._blob_bytes_received and not self.writer.closed():
return self._write(data)
response = BlobResponse.deserialize(data) response = BlobResponse.deserialize(data)
@ -51,25 +53,28 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
if response.blob_data and self.writer and not self.writer.closed(): if response.blob_data and self.writer and not self.writer.closed():
log.debug("got %i blob bytes from %s:%i", len(response.blob_data), self.peer_address, self.peer_port) log.debug("got %i blob bytes from %s:%i", len(response.blob_data), self.peer_address, self.peer_port)
# write blob bytes if we're writing a blob and have blob bytes to write # write blob bytes if we're writing a blob and have blob bytes to write
if len(response.blob_data) > (self.blob.get_length() - self._blob_bytes_received): self._write(response.blob_data)
data = response.blob_data[:(self.blob.get_length() - self._blob_bytes_received)]
log.warning("got more than asked from %s:%d, probable sendfile bug", self.peer_address, self.peer_port)
else:
data = response.blob_data
self._blob_bytes_received += len(data)
try:
self.writer.write(data)
return
except IOError as err:
log.error("error downloading blob from %s:%i: %s", self.peer_address, self.peer_port, err)
if self._response_fut and not self._response_fut.done():
self._response_fut.set_exception(err)
except (asyncio.CancelledError, asyncio.TimeoutError) as err: # TODO: is this needed?
log.error("%s downloading blob from %s:%i", str(err), self.peer_address, self.peer_port)
if self._response_fut and not self._response_fut.done():
self._response_fut.set_exception(err)
async def _download_blob(self) -> typing.Tuple[bool, bool]:
def _write(self, data):
if len(data) > (self.blob.get_length() - self._blob_bytes_received):
data = data[:(self.blob.get_length() - self._blob_bytes_received)]
log.warning("got more than asked from %s:%d, probable sendfile bug", self.peer_address, self.peer_port)
else:
data = data
self._blob_bytes_received += len(data)
try:
self.writer.write(data)
except IOError as err:
log.error("error downloading blob from %s:%i: %s", self.peer_address, self.peer_port, err)
if self._response_fut and not self._response_fut.done():
self._response_fut.set_exception(err)
except (asyncio.TimeoutError) as err: # TODO: is this needed?
log.error("%s downloading blob from %s:%i", str(err), self.peer_address, self.peer_port)
if self._response_fut and not self._response_fut.done():
self._response_fut.set_exception(err)
async def _download_blob(self) -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
""" """
:return: download success (bool), keep connection (bool) :return: download success (bool), keep connection (bool)
""" """
@ -87,43 +92,39 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
log.warning("%s not in availability response from %s:%i", self.blob.blob_hash, self.peer_address, log.warning("%s not in availability response from %s:%i", self.blob.blob_hash, self.peer_address,
self.peer_port) self.peer_port)
log.warning(response.to_dict()) log.warning(response.to_dict())
return False, False return self._blob_bytes_received, self.close()
elif availability_response.available_blobs and \ elif availability_response.available_blobs and \
availability_response.available_blobs != [self.blob.blob_hash]: availability_response.available_blobs != [self.blob.blob_hash]:
log.warning("blob availability response doesn't match our request from %s:%i", log.warning("blob availability response doesn't match our request from %s:%i",
self.peer_address, self.peer_port) self.peer_address, self.peer_port)
return False, False return self._blob_bytes_received, self.close()
if not price_response or price_response.blob_data_payment_rate != 'RATE_ACCEPTED': if not price_response or price_response.blob_data_payment_rate != 'RATE_ACCEPTED':
log.warning("data rate rejected by %s:%i", self.peer_address, self.peer_port) log.warning("data rate rejected by %s:%i", self.peer_address, self.peer_port)
return False, False return self._blob_bytes_received, self.close()
if not blob_response or blob_response.error: if not blob_response or blob_response.error:
log.warning("blob cant be downloaded from %s:%i", self.peer_address, self.peer_port) log.warning("blob cant be downloaded from %s:%i", self.peer_address, self.peer_port)
return False, True return self._blob_bytes_received, self.transport
if not blob_response.error and blob_response.blob_hash != self.blob.blob_hash: if not blob_response.error and blob_response.blob_hash != self.blob.blob_hash:
log.warning("incoming blob hash mismatch from %s:%i", self.peer_address, self.peer_port) log.warning("incoming blob hash mismatch from %s:%i", self.peer_address, self.peer_port)
return False, False return self._blob_bytes_received, self.close()
if self.blob.length is not None and self.blob.length != blob_response.length: if self.blob.length is not None and self.blob.length != blob_response.length:
log.warning("incoming blob unexpected length from %s:%i", self.peer_address, self.peer_port) log.warning("incoming blob unexpected length from %s:%i", self.peer_address, self.peer_port)
return False, False return self._blob_bytes_received, self.close()
msg = f"downloading {self.blob.blob_hash[:8]} from {self.peer_address}:{self.peer_port}," \ msg = f"downloading {self.blob.blob_hash[:8]} from {self.peer_address}:{self.peer_port}," \
f" timeout in {self.peer_timeout}" f" timeout in {self.peer_timeout}"
log.debug(msg) log.debug(msg)
msg = f"downloaded {self.blob.blob_hash[:8]} from {self.peer_address}:{self.peer_port}" msg = f"downloaded {self.blob.blob_hash[:8]} from {self.peer_address}:{self.peer_port}"
await asyncio.wait_for(self.writer.finished, self.peer_timeout, loop=self.loop) await asyncio.wait_for(self.writer.finished, self.peer_timeout, loop=self.loop)
log.info(msg) log.info(msg)
await self.blob.finished_writing.wait() # await self.blob.finished_writing.wait() not necessary, but a dangerous change. TODO: is it needed?
return True, True return self._blob_bytes_received, self.transport
except asyncio.CancelledError:
return False, True
except asyncio.TimeoutError: except asyncio.TimeoutError:
return False, False return self._blob_bytes_received, self.close()
except (InvalidBlobHashError, InvalidDataError): except (InvalidBlobHashError, InvalidDataError):
log.warning("invalid blob from %s:%i", self.peer_address, self.peer_port) log.warning("invalid blob from %s:%i", self.peer_address, self.peer_port)
return False, False return self._blob_bytes_received, self.close()
finally:
await self.close()
async def close(self): def close(self):
if self._response_fut and not self._response_fut.done(): if self._response_fut and not self._response_fut.done():
self._response_fut.cancel() self._response_fut.cancel()
if self.writer and not self.writer.closed(): if self.writer and not self.writer.closed():
@ -135,25 +136,26 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
self.transport.close() self.transport.close()
self.transport = None self.transport = None
async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[bool, bool]: async def download_blob(self, blob: 'BlobFile') -> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
if blob.get_is_verified(): if blob.get_is_verified() or blob.file_exists or blob.blob_write_lock.locked():
return False, True return 0, self.transport
try: try:
self.blob, self.writer, self._blob_bytes_received = blob, blob.open_for_writing(), 0 self.blob, self.writer, self._blob_bytes_received = blob, blob.open_for_writing(), 0
self._response_fut = asyncio.Future(loop=self.loop) self._response_fut = asyncio.Future(loop=self.loop)
return await self._download_blob() return await self._download_blob()
except OSError: except OSError as e:
log.error("race happened downloading from %s:%i", self.peer_address, self.peer_port) log.error("race happened downloading from %s:%i", self.peer_address, self.peer_port)
# i'm not sure how to fix this race condition - jack # i'm not sure how to fix this race condition - jack
return False, True log.exception(e)
return self._blob_bytes_received, self.transport
except asyncio.TimeoutError: except asyncio.TimeoutError:
if self._response_fut and not self._response_fut.done(): if self._response_fut and not self._response_fut.done():
self._response_fut.cancel() self._response_fut.cancel()
return False, False self.close()
return self._blob_bytes_received, None
except asyncio.CancelledError: except asyncio.CancelledError:
if self._response_fut and not self._response_fut.done(): self.close()
self._response_fut.cancel() raise
return False, True
def connection_made(self, transport: asyncio.Transport): def connection_made(self, transport: asyncio.Transport):
self.transport = transport self.transport = transport
@ -163,24 +165,31 @@ class BlobExchangeClientProtocol(asyncio.Protocol):
def connection_lost(self, reason): def connection_lost(self, reason):
log.debug("connection lost to %s:%i (reason: %s, %s)", self.peer_address, self.peer_port, str(reason), log.debug("connection lost to %s:%i (reason: %s, %s)", self.peer_address, self.peer_port, str(reason),
str(type(reason))) str(type(reason)))
self.transport = None self.close()
self.loop.create_task(self.close())
async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: str, tcp_port: int, async def request_blob(loop: asyncio.BaseEventLoop, blob: 'BlobFile', address: str, tcp_port: int,
peer_connect_timeout: float, blob_download_timeout: float) -> typing.Tuple[bool, bool]: peer_connect_timeout: float, blob_download_timeout: float,
connected_transport: asyncio.Transport = None)\
-> typing.Tuple[int, typing.Optional[asyncio.Transport]]:
""" """
Returns [<downloaded blob>, <keep connection>] Returns [<downloaded blob>, <keep connection>]
""" """
if blob.get_is_verified() or blob.file_exists:
# file exists but not verified means someone is writing right now, give it time, come back later
return 0, connected_transport
protocol = BlobExchangeClientProtocol(loop, blob_download_timeout) protocol = BlobExchangeClientProtocol(loop, blob_download_timeout)
if blob.get_is_verified(): if connected_transport and not connected_transport.is_closing():
return False, True connected_transport.set_protocol(protocol)
protocol.connection_made(connected_transport)
log.debug("reusing connection for %s:%d", address, tcp_port)
else:
connected_transport = None
try: try:
await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port), if not connected_transport:
peer_connect_timeout, loop=loop) await asyncio.wait_for(loop.create_connection(lambda: protocol, address, tcp_port),
peer_connect_timeout, loop=loop)
return await protocol.download_blob(blob) return await protocol.download_blob(blob)
except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionRefusedError, ConnectionAbortedError, OSError): except (asyncio.TimeoutError, ConnectionRefusedError, ConnectionAbortedError, OSError):
return False, False return 0, None
finally:
await protocol.close()

View file

@ -13,11 +13,6 @@ if typing.TYPE_CHECKING:
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def drain_into(a: list, b: list):
while a:
b.append(a.pop())
class BlobDownloader: class BlobDownloader:
def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager', def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobFileManager',
peer_queue: asyncio.Queue): peer_queue: asyncio.Queue):
@ -28,34 +23,58 @@ class BlobDownloader:
self.active_connections: typing.Dict['KademliaPeer', asyncio.Task] = {} # active request_blob calls self.active_connections: typing.Dict['KademliaPeer', asyncio.Task] = {} # active request_blob calls
self.ignored: typing.Set['KademliaPeer'] = set() self.ignored: typing.Set['KademliaPeer'] = set()
self.scores: typing.Dict['KademliaPeer', int] = {} self.scores: typing.Dict['KademliaPeer', int] = {}
self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {}
self.rounds_won: typing.Dict['KademliaPeer', int] = {}
def should_race_continue(self):
if len(self.active_connections) >= self.config.max_connections_per_download:
return False
# if a peer won 3 or more blob races and is active as a downloader, stop the race so bandwidth improves
# the safe net side is that any failure will reset the peer score, triggering the race back
# TODO: this is a good idea for low bandwidth, but doesnt play nice on high bandwidth
# for peer, task in self.active_connections.items():
# if self.scores.get(peer, 0) >= 0 and self.rounds_won.get(peer, 0) >= 3 and not task.done():
# return False
return True
async def request_blob_from_peer(self, blob: 'BlobFile', peer: 'KademliaPeer'): async def request_blob_from_peer(self, blob: 'BlobFile', peer: 'KademliaPeer'):
if blob.get_is_verified(): if blob.get_is_verified():
return return
success, keep_connection = await request_blob( self.scores[peer] = self.scores.get(peer, 0) - 1 # starts losing score, to account for cancelled ones
transport = self.connections.get(peer)
start = self.loop.time()
bytes_received, transport = await request_blob(
self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout, self.loop, blob, peer.address, peer.tcp_port, self.config.peer_connect_timeout,
self.config.blob_download_timeout self.config.blob_download_timeout, connected_transport=transport
) )
if not keep_connection and peer not in self.ignored: if bytes_received == blob.get_length():
self.rounds_won[peer] = self.rounds_won.get(peer, 0) + 1
if not transport and peer not in self.ignored:
self.ignored.add(peer) self.ignored.add(peer)
log.debug("drop peer %s:%i", peer.address, peer.tcp_port) log.debug("drop peer %s:%i", peer.address, peer.tcp_port)
elif keep_connection: if peer in self.connections:
del self.connections[peer]
elif transport:
log.debug("keep peer %s:%i", peer.address, peer.tcp_port) log.debug("keep peer %s:%i", peer.address, peer.tcp_port)
if success: self.connections[peer] = transport
self.scores[peer] = self.scores.get(peer, 0) + 2 rough_speed = (bytes_received / (self.loop.time() - start)) if bytes_received else 0
else: self.scores[peer] = rough_speed
self.scores[peer] = self.scores.get(peer, 0) - 1
async def new_peer_or_finished(self, blob: 'BlobFile'): async def new_peer_or_finished(self, blob: 'BlobFile'):
async def get_and_re_add_peers(): async def get_and_re_add_peers():
new_peers = await self.peer_queue.get() new_peers = await self.peer_queue.get()
self.peer_queue.put_nowait(new_peers) self.peer_queue.put_nowait(new_peers)
tasks = [self.loop.create_task(get_and_re_add_peers()), self.loop.create_task(blob.verified.wait())] tasks = [self.loop.create_task(get_and_re_add_peers()), self.loop.create_task(blob.verified.wait())]
active_tasks = list(self.active_connections.values())
try: try:
await asyncio.wait(tasks, loop=self.loop, return_when='FIRST_COMPLETED') await asyncio.wait(tasks + active_tasks, loop=self.loop, return_when='FIRST_COMPLETED')
except asyncio.CancelledError: finally:
drain_tasks(tasks) drain_tasks(tasks)
raise
def cleanup_active(self):
to_remove = [peer for (peer, task) in self.active_connections.items() if task.done()]
for peer in to_remove:
del self.active_connections[peer]
async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile': async def download_blob(self, blob_hash: str, length: typing.Optional[int] = None) -> 'BlobFile':
blob = self.blob_manager.get_blob(blob_hash, length) blob = self.blob_manager.get_blob(blob_hash, length)
@ -65,23 +84,23 @@ class BlobDownloader:
while not blob.get_is_verified(): while not blob.get_is_verified():
batch: typing.List['KademliaPeer'] = [] batch: typing.List['KademliaPeer'] = []
while not self.peer_queue.empty(): while not self.peer_queue.empty():
batch.extend(await self.peer_queue.get()) batch.extend(self.peer_queue.get_nowait())
batch.sort(key=lambda peer: self.scores.get(peer, 0), reverse=True)
log.debug(
"running, %d peers, %d ignored, %d active",
len(batch), len(self.ignored), len(self.active_connections)
)
for peer in batch: for peer in batch:
if len(self.active_connections) >= self.config.max_connections_per_download: if not self.should_race_continue():
break break
if peer not in self.active_connections and peer not in self.ignored: if peer not in self.active_connections and peer not in self.ignored:
log.debug("request %s from %s:%i", blob_hash[:8], peer.address, peer.tcp_port) log.debug("request %s from %s:%i", blob_hash[:8], peer.address, peer.tcp_port)
t = self.loop.create_task(self.request_blob_from_peer(blob, peer)) t = self.loop.create_task(self.request_blob_from_peer(blob, peer))
self.active_connections[peer] = t self.active_connections[peer] = t
t.add_done_callback(
lambda _:
None if peer not in self.active_connections else self.active_connections.pop(peer)
)
await self.new_peer_or_finished(blob) await self.new_peer_or_finished(blob)
to_re_add = list(set(filter(lambda peer: peer not in self.ignored, batch))) self.cleanup_active()
to_re_add.sort(key=lambda peer: self.scores.get(peer, 0), reverse=True) if batch:
if to_re_add: self.peer_queue.put_nowait(set(batch).difference(self.ignored))
self.peer_queue.put_nowait(to_re_add)
while self.active_connections: while self.active_connections:
peer, task = self.active_connections.popitem() peer, task = self.active_connections.popitem()
if task and not task.done(): if task and not task.done():
@ -95,6 +114,13 @@ class BlobDownloader:
if task and not task.done(): if task and not task.done():
task.cancel() task.cancel()
raise raise
except (OSError, Exception) as e:
log.exception(e)
raise e
def close(self):
for transport in self.connections.values():
transport.close()
async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager', node: 'Node', async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager', node: 'Node',
@ -108,3 +134,4 @@ async def download_blob(loop, config: 'Config', blob_manager: 'BlobFileManager',
finally: finally:
if accumulate_task and not accumulate_task.done(): if accumulate_task and not accumulate_task.done():
accumulate_task.cancel() accumulate_task.cancel()
downloader.close()

View file

@ -475,8 +475,8 @@ class Config(CLIConfig):
# protocol timeouts # protocol timeouts
download_timeout = Float("Cumulative timeout for a stream to begin downloading before giving up", 30.0) download_timeout = Float("Cumulative timeout for a stream to begin downloading before giving up", 30.0)
blob_download_timeout = Float("Timeout to download a blob from a peer", 20.0) blob_download_timeout = Float("Timeout to download a blob from a peer", 30.0)
peer_connect_timeout = Float("Timeout to establish a TCP connection to a peer", 3.0) peer_connect_timeout = Float("Timeout to establish a TCP connection to a peer", 2.0)
node_rpc_timeout = Float("Timeout when making a DHT request", constants.rpc_timeout) node_rpc_timeout = Float("Timeout when making a DHT request", constants.rpc_timeout)
# blob announcement and download # blob announcement and download

View file

@ -64,7 +64,7 @@ class Node:
# ping the set of peers; upon success/failure the routing able and last replied/failed time will be updated # ping the set of peers; upon success/failure the routing able and last replied/failed time will be updated
to_ping = [peer for peer in set(total_peers) if self.protocol.peer_manager.peer_is_good(peer) is not True] to_ping = [peer for peer in set(total_peers) if self.protocol.peer_manager.peer_is_good(peer) is not True]
if to_ping: if to_ping:
await self.protocol.ping_queue.enqueue_maybe_ping(*to_ping, delay=0) self.protocol.ping_queue.enqueue_maybe_ping(*to_ping, delay=0)
fut = asyncio.Future(loop=self.loop) fut = asyncio.Future(loop=self.loop)
self.loop.call_later(constants.refresh_interval, fut.set_result, None) self.loop.call_later(constants.refresh_interval, fut.set_result, None)

View file

@ -192,23 +192,21 @@ class PingQueue:
self._process_task: asyncio.Task = None self._process_task: asyncio.Task = None
self._next_task: asyncio.Future = None self._next_task: asyncio.Future = None
self._next_timer: asyncio.TimerHandle = None self._next_timer: asyncio.TimerHandle = None
self._lock = asyncio.Lock()
self._running = False self._running = False
@property @property
def running(self): def running(self):
return self._running return self._running
async def enqueue_maybe_ping(self, *peers: 'KademliaPeer', delay: typing.Optional[float] = None): def enqueue_maybe_ping(self, *peers: 'KademliaPeer', delay: typing.Optional[float] = None):
delay = constants.check_refresh_interval if delay is None else delay delay = constants.check_refresh_interval if delay is None else delay
async with self._lock: for peer in peers:
for peer in peers: if delay and peer not in self._enqueued_contacts:
if delay and peer not in self._enqueued_contacts: self._pending_contacts[peer] = self._loop.time() + delay
self._pending_contacts[peer] = self._loop.time() + delay elif peer not in self._enqueued_contacts:
elif peer not in self._enqueued_contacts: self._enqueued_contacts.append(peer)
self._enqueued_contacts.append(peer) if peer in self._pending_contacts:
if peer in self._pending_contacts: del self._pending_contacts[peer]
del self._pending_contacts[peer]
async def _process(self): async def _process(self):
async def _ping(p: 'KademliaPeer'): async def _ping(p: 'KademliaPeer'):
@ -223,17 +221,16 @@ class PingQueue:
while True: while True:
tasks = [] tasks = []
async with self._lock: if self._enqueued_contacts or self._pending_contacts:
if self._enqueued_contacts or self._pending_contacts: now = self._loop.time()
now = self._loop.time() scheduled = [k for k, d in self._pending_contacts.items() if now >= d]
scheduled = [k for k, d in self._pending_contacts.items() if now >= d] for k in scheduled:
for k in scheduled: del self._pending_contacts[k]
del self._pending_contacts[k] if k not in self._enqueued_contacts:
if k not in self._enqueued_contacts: self._enqueued_contacts.append(k)
self._enqueued_contacts.append(k) while self._enqueued_contacts:
while self._enqueued_contacts: peer = self._enqueued_contacts.pop()
peer = self._enqueued_contacts.pop() tasks.append(self._loop.create_task(_ping(peer)))
tasks.append(self._loop.create_task(_ping(peer)))
if tasks: if tasks:
await asyncio.wait(tasks, loop=self._loop) await asyncio.wait(tasks, loop=self._loop)
@ -282,7 +279,6 @@ class KademliaProtocol(DatagramProtocol):
self.data_store = DictDataStore(self.loop, self.peer_manager) self.data_store = DictDataStore(self.loop, self.peer_manager)
self.ping_queue = PingQueue(self.loop, self) self.ping_queue = PingQueue(self.loop, self)
self.node_rpc = KademliaRPC(self, self.loop, self.peer_port) self.node_rpc = KademliaRPC(self, self.loop, self.peer_port)
self.lock = asyncio.Lock(loop=self.loop)
self.rpc_timeout = rpc_timeout self.rpc_timeout = rpc_timeout
self._split_lock = asyncio.Lock(loop=self.loop) self._split_lock = asyncio.Lock(loop=self.loop)
@ -424,7 +420,7 @@ class KademliaProtocol(DatagramProtocol):
# will be added to our routing table if successful # will be added to our routing table if successful
is_good = self.peer_manager.peer_is_good(peer) is_good = self.peer_manager.peer_is_good(peer)
if is_good is None: if is_good is None:
await self.ping_queue.enqueue_maybe_ping(peer) self.ping_queue.enqueue_maybe_ping(peer)
elif is_good is True: elif is_good is True:
await self.add_peer(peer) await self.add_peer(peer)
@ -553,26 +549,25 @@ class KademliaProtocol(DatagramProtocol):
if message.rpc_id in self.sent_messages: if message.rpc_id in self.sent_messages:
self.sent_messages.pop(message.rpc_id) self.sent_messages.pop(message.rpc_id)
async with self.lock: if isinstance(message, RequestDatagram):
response_fut = self.loop.create_future()
response_fut.add_done_callback(pop_from_sent_messages)
self.sent_messages[message.rpc_id] = (peer, response_fut, message)
try:
self.transport.sendto(data, (peer.address, peer.udp_port))
except OSError as err:
# TODO: handle ENETUNREACH
if err.errno == socket.EWOULDBLOCK:
# i'm scared this may swallow important errors, but i get a million of these
# on Linux and it doesn't seem to affect anything -grin
log.warning("Can't send data to dht: EWOULDBLOCK")
else:
log.error("DHT socket error sending %i bytes to %s:%i - %s (code %i)",
len(data), peer.address, peer.udp_port, str(err), err.errno)
if isinstance(message, RequestDatagram): if isinstance(message, RequestDatagram):
response_fut = self.loop.create_future() self.sent_messages[message.rpc_id][1].set_exception(err)
response_fut.add_done_callback(pop_from_sent_messages) else:
self.sent_messages[message.rpc_id] = (peer, response_fut, message) raise err
try:
self.transport.sendto(data, (peer.address, peer.udp_port))
except OSError as err:
# TODO: handle ENETUNREACH
if err.errno == socket.EWOULDBLOCK:
# i'm scared this may swallow important errors, but i get a million of these
# on Linux and it doesn't seem to affect anything -grin
log.warning("Can't send data to dht: EWOULDBLOCK")
else:
log.error("DHT socket error sending %i bytes to %s:%i - %s (code %i)",
len(data), peer.address, peer.udp_port, str(err), err.errno)
if isinstance(message, RequestDatagram):
self.sent_messages[message.rpc_id][1].set_exception(err)
else:
raise err
if isinstance(message, RequestDatagram): if isinstance(message, RequestDatagram):
self.peer_manager.report_last_sent(peer.address, peer.udp_port) self.peer_manager.report_last_sent(peer.address, peer.udp_port)
elif isinstance(message, ErrorDatagram): elif isinstance(message, ErrorDatagram):

View file

@ -86,6 +86,7 @@ class StreamAssembler:
) )
await self.blob_manager.blob_completed(self.sd_blob) await self.blob_manager.blob_completed(self.sd_blob)
written_blobs = None written_blobs = None
save_tasks = []
try: try:
with open(self.output_path, 'wb') as stream_handle: with open(self.output_path, 'wb') as stream_handle:
self.stream_handle = stream_handle self.stream_handle = stream_handle
@ -101,7 +102,7 @@ class StreamAssembler:
await self.blob_manager.delete_blobs([blob_info.blob_hash]) await self.blob_manager.delete_blobs([blob_info.blob_hash])
continue continue
if await self._decrypt_blob(blob, blob_info, self.descriptor.key): if await self._decrypt_blob(blob, blob_info, self.descriptor.key):
await self.blob_manager.blob_completed(blob) save_tasks.append(asyncio.ensure_future(self.blob_manager.blob_completed(blob)))
written_blobs = i written_blobs = i
if not self.wrote_bytes_event.is_set(): if not self.wrote_bytes_event.is_set():
self.wrote_bytes_event.set() self.wrote_bytes_event.set()
@ -115,6 +116,8 @@ class StreamAssembler:
self.descriptor.sd_hash) self.descriptor.sd_hash)
continue continue
finally: finally:
if save_tasks:
await asyncio.wait(save_tasks)
if written_blobs == len(self.descriptor.blobs) - 2: if written_blobs == len(self.descriptor.blobs) - 2:
log.debug("finished decrypting and assembling stream") log.debug("finished decrypting and assembling stream")
await self.after_finished() await self.after_finished()

View file

@ -51,6 +51,7 @@ class StreamDownloader(StreamAssembler):
async def after_finished(self): async def after_finished(self):
log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path) log.info("downloaded stream %s -> %s", self.sd_hash, self.output_path)
await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished') await self.blob_manager.storage.change_file_status(self.descriptor.stream_hash, 'finished')
self.blob_downloader.close()
def stop(self): def stop(self):
if self.accumulate_task: if self.accumulate_task:

View file

@ -70,6 +70,7 @@ class TestBlobExchange(BlobExchangeTestBase):
# download the blob # download the blob
downloaded = await request_blob(self.loop, client_blob, self.server_from_client.address, downloaded = await request_blob(self.loop, client_blob, self.server_from_client.address,
self.server_from_client.tcp_port, 2, 3) self.server_from_client.tcp_port, 2, 3)
await client_blob.finished_writing.wait()
self.assertEqual(client_blob.get_is_verified(), True) self.assertEqual(client_blob.get_is_verified(), True)
self.assertTrue(downloaded) self.assertTrue(downloaded)
@ -111,6 +112,7 @@ class TestBlobExchange(BlobExchangeTestBase):
), ),
self._test_transfer_blob(blob_hash) self._test_transfer_blob(blob_hash)
) )
await second_client_blob.finished_writing.wait()
self.assertEqual(second_client_blob.get_is_verified(), True) self.assertEqual(second_client_blob.get_is_verified(), True)
async def test_host_different_blobs_to_multiple_peers_at_once(self): async def test_host_different_blobs_to_multiple_peers_at_once(self):
@ -140,7 +142,8 @@ class TestBlobExchange(BlobExchangeTestBase):
self.loop, second_client_blob, server_from_second_client.address, self.loop, second_client_blob, server_from_second_client.address,
server_from_second_client.tcp_port, 2, 3 server_from_second_client.tcp_port, 2, 3
), ),
self._test_transfer_blob(sd_hash) self._test_transfer_blob(sd_hash),
second_client_blob.finished_writing.wait()
) )
self.assertEqual(second_client_blob.get_is_verified(), True) self.assertEqual(second_client_blob.get_is_verified(), True)