fix race condition in reflector server

This commit is contained in:
Jack Robison 2019-01-28 16:51:11 -05:00
parent 330862e487
commit 71f9f8ae9c
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2
2 changed files with 11 additions and 4 deletions

View file

@ -15,6 +15,7 @@ log = logging.getLogger(__name__)
class StreamReflectorClient(asyncio.Protocol): class StreamReflectorClient(asyncio.Protocol):
def __init__(self, blob_manager: 'BlobFileManager', descriptor: 'StreamDescriptor'): def __init__(self, blob_manager: 'BlobFileManager', descriptor: 'StreamDescriptor'):
self.loop = asyncio.get_event_loop()
self.transport: asyncio.StreamWriter = None self.transport: asyncio.StreamWriter = None
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.descriptor = descriptor self.descriptor = descriptor
@ -45,7 +46,7 @@ class StreamReflectorClient(asyncio.Protocol):
msg = json.dumps(request_dict) msg = json.dumps(request_dict)
self.transport.write(msg.encode()) self.transport.write(msg.encode())
try: try:
self.pending_request = asyncio.get_event_loop().create_task(self.response_queue.get()) self.pending_request = self.loop.create_task(self.response_queue.get())
return await self.pending_request return await self.pending_request
finally: finally:
self.pending_request = None self.pending_request = None

View file

@ -37,7 +37,8 @@ class ReflectorServerProtocol(asyncio.Protocol):
try: try:
self.writer.write(data) self.writer.write(data)
except IOError as err: except IOError as err:
log.error("error downloading blob: %s", err) log.error("error receiving blob: %s", err)
self.transport.close()
return return
try: try:
request = json.loads(data.decode()) request = json.loads(data.decode())
@ -67,29 +68,34 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.send_response({"send_sd_blob": True}) self.send_response({"send_sd_blob": True})
try: try:
await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop) await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop)
self.send_response({"received_sd_blob": True})
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob( self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
self.loop, self.blob_manager.blob_dir, self.sd_blob self.loop, self.blob_manager.blob_dir, self.sd_blob
) )
self.incoming.clear() self.incoming.clear()
self.writer.close_handle() self.writer.close_handle()
self.writer = None self.writer = None
self.send_response({"received_sd_blob": True})
except (asyncio.TimeoutError, asyncio.CancelledError): except (asyncio.TimeoutError, asyncio.CancelledError):
self.send_response({"received_sd_blob": False})
self.incoming.clear() self.incoming.clear()
self.writer.close_handle() self.writer.close_handle()
self.writer = None self.writer = None
self.transport.close() self.transport.close()
self.send_response({"received_sd_blob": False})
return return
else: else:
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob( self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
self.loop, self.blob_manager.blob_dir, self.sd_blob self.loop, self.blob_manager.blob_dir, self.sd_blob
) )
self.incoming.clear()
if self.writer:
self.writer.close_handle()
self.writer = None
self.send_response({"send_sd_blob": False, 'needed': [ self.send_response({"send_sd_blob": False, 'needed': [
blob.blob_hash for blob in self.descriptor.blobs[:-1] blob.blob_hash for blob in self.descriptor.blobs[:-1]
if not self.blob_manager.get_blob(blob.blob_hash).get_is_verified() if not self.blob_manager.get_blob(blob.blob_hash).get_is_verified()
]}) ]})
return return
return
elif self.descriptor: elif self.descriptor:
if 'blob_hash' not in request: if 'blob_hash' not in request:
self.transport.close() self.transport.close()