fix race condition in reflector server
This commit is contained in:
parent
330862e487
commit
71f9f8ae9c
2 changed files with 11 additions and 4 deletions
|
@ -15,6 +15,7 @@ log = logging.getLogger(__name__)
|
|||
|
||||
class StreamReflectorClient(asyncio.Protocol):
|
||||
def __init__(self, blob_manager: 'BlobFileManager', descriptor: 'StreamDescriptor'):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.transport: asyncio.StreamWriter = None
|
||||
self.blob_manager = blob_manager
|
||||
self.descriptor = descriptor
|
||||
|
@ -45,7 +46,7 @@ class StreamReflectorClient(asyncio.Protocol):
|
|||
msg = json.dumps(request_dict)
|
||||
self.transport.write(msg.encode())
|
||||
try:
|
||||
self.pending_request = asyncio.get_event_loop().create_task(self.response_queue.get())
|
||||
self.pending_request = self.loop.create_task(self.response_queue.get())
|
||||
return await self.pending_request
|
||||
finally:
|
||||
self.pending_request = None
|
||||
|
|
|
@ -37,7 +37,8 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
try:
|
||||
self.writer.write(data)
|
||||
except IOError as err:
|
||||
log.error("error downloading blob: %s", err)
|
||||
log.error("error receiving blob: %s", err)
|
||||
self.transport.close()
|
||||
return
|
||||
try:
|
||||
request = json.loads(data.decode())
|
||||
|
@ -67,29 +68,34 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
|||
self.send_response({"send_sd_blob": True})
|
||||
try:
|
||||
await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop)
|
||||
self.send_response({"received_sd_blob": True})
|
||||
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
|
||||
self.loop, self.blob_manager.blob_dir, self.sd_blob
|
||||
)
|
||||
self.incoming.clear()
|
||||
self.writer.close_handle()
|
||||
self.writer = None
|
||||
self.send_response({"received_sd_blob": True})
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
self.send_response({"received_sd_blob": False})
|
||||
self.incoming.clear()
|
||||
self.writer.close_handle()
|
||||
self.writer = None
|
||||
self.transport.close()
|
||||
self.send_response({"received_sd_blob": False})
|
||||
return
|
||||
else:
|
||||
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
|
||||
self.loop, self.blob_manager.blob_dir, self.sd_blob
|
||||
)
|
||||
self.incoming.clear()
|
||||
if self.writer:
|
||||
self.writer.close_handle()
|
||||
self.writer = None
|
||||
self.send_response({"send_sd_blob": False, 'needed': [
|
||||
blob.blob_hash for blob in self.descriptor.blobs[:-1]
|
||||
if not self.blob_manager.get_blob(blob.blob_hash).get_is_verified()
|
||||
]})
|
||||
return
|
||||
return
|
||||
elif self.descriptor:
|
||||
if 'blob_hash' not in request:
|
||||
self.transport.close()
|
||||
|
|
Loading…
Reference in a new issue