fix race condition in reflector server
This commit is contained in:
parent
330862e487
commit
71f9f8ae9c
2 changed files with 11 additions and 4 deletions
|
@ -15,6 +15,7 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
class StreamReflectorClient(asyncio.Protocol):
|
class StreamReflectorClient(asyncio.Protocol):
|
||||||
def __init__(self, blob_manager: 'BlobFileManager', descriptor: 'StreamDescriptor'):
|
def __init__(self, blob_manager: 'BlobFileManager', descriptor: 'StreamDescriptor'):
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
self.transport: asyncio.StreamWriter = None
|
self.transport: asyncio.StreamWriter = None
|
||||||
self.blob_manager = blob_manager
|
self.blob_manager = blob_manager
|
||||||
self.descriptor = descriptor
|
self.descriptor = descriptor
|
||||||
|
@ -45,7 +46,7 @@ class StreamReflectorClient(asyncio.Protocol):
|
||||||
msg = json.dumps(request_dict)
|
msg = json.dumps(request_dict)
|
||||||
self.transport.write(msg.encode())
|
self.transport.write(msg.encode())
|
||||||
try:
|
try:
|
||||||
self.pending_request = asyncio.get_event_loop().create_task(self.response_queue.get())
|
self.pending_request = self.loop.create_task(self.response_queue.get())
|
||||||
return await self.pending_request
|
return await self.pending_request
|
||||||
finally:
|
finally:
|
||||||
self.pending_request = None
|
self.pending_request = None
|
||||||
|
|
|
@ -37,7 +37,8 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
try:
|
try:
|
||||||
self.writer.write(data)
|
self.writer.write(data)
|
||||||
except IOError as err:
|
except IOError as err:
|
||||||
log.error("error downloading blob: %s", err)
|
log.error("error receiving blob: %s", err)
|
||||||
|
self.transport.close()
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
request = json.loads(data.decode())
|
request = json.loads(data.decode())
|
||||||
|
@ -67,29 +68,34 @@ class ReflectorServerProtocol(asyncio.Protocol):
|
||||||
self.send_response({"send_sd_blob": True})
|
self.send_response({"send_sd_blob": True})
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop)
|
await asyncio.wait_for(self.sd_blob.finished_writing.wait(), 30, loop=self.loop)
|
||||||
self.send_response({"received_sd_blob": True})
|
|
||||||
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
|
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
|
||||||
self.loop, self.blob_manager.blob_dir, self.sd_blob
|
self.loop, self.blob_manager.blob_dir, self.sd_blob
|
||||||
)
|
)
|
||||||
self.incoming.clear()
|
self.incoming.clear()
|
||||||
self.writer.close_handle()
|
self.writer.close_handle()
|
||||||
self.writer = None
|
self.writer = None
|
||||||
|
self.send_response({"received_sd_blob": True})
|
||||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||||
self.send_response({"received_sd_blob": False})
|
|
||||||
self.incoming.clear()
|
self.incoming.clear()
|
||||||
self.writer.close_handle()
|
self.writer.close_handle()
|
||||||
self.writer = None
|
self.writer = None
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
|
self.send_response({"received_sd_blob": False})
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
|
self.descriptor = await StreamDescriptor.from_stream_descriptor_blob(
|
||||||
self.loop, self.blob_manager.blob_dir, self.sd_blob
|
self.loop, self.blob_manager.blob_dir, self.sd_blob
|
||||||
)
|
)
|
||||||
|
self.incoming.clear()
|
||||||
|
if self.writer:
|
||||||
|
self.writer.close_handle()
|
||||||
|
self.writer = None
|
||||||
self.send_response({"send_sd_blob": False, 'needed': [
|
self.send_response({"send_sd_blob": False, 'needed': [
|
||||||
blob.blob_hash for blob in self.descriptor.blobs[:-1]
|
blob.blob_hash for blob in self.descriptor.blobs[:-1]
|
||||||
if not self.blob_manager.get_blob(blob.blob_hash).get_is_verified()
|
if not self.blob_manager.get_blob(blob.blob_hash).get_is_verified()
|
||||||
]})
|
]})
|
||||||
return
|
return
|
||||||
|
return
|
||||||
elif self.descriptor:
|
elif self.descriptor:
|
||||||
if 'blob_hash' not in request:
|
if 'blob_hash' not in request:
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
|
|
Loading…
Reference in a new issue