improve test

This commit is contained in:
Victor Shyba 2021-05-20 18:49:20 -03:00
parent 9bdf3d23e1
commit 352bf69409
2 changed files with 14 additions and 10 deletions

View file

@ -17,7 +17,7 @@ log = logging.getLogger(__name__)
class ReflectorServerProtocol(asyncio.Protocol): class ReflectorServerProtocol(asyncio.Protocol):
def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000, def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000,
stop_event: asyncio.Event = None, incoming_event: asyncio.Event = None, stop_event: asyncio.Event = None, incoming_event: asyncio.Event = None,
not_incoming_event: asyncio.Event = None, partial_needs=False): not_incoming_event: asyncio.Event = None, partial_event: asyncio.Event = None):
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.blob_manager = blob_manager self.blob_manager = blob_manager
self.server_task: asyncio.Task = None self.server_task: asyncio.Task = None
@ -34,7 +34,7 @@ class ReflectorServerProtocol(asyncio.Protocol):
self.stop_event = stop_event or asyncio.Event(loop=self.loop) self.stop_event = stop_event or asyncio.Event(loop=self.loop)
self.chunk_size = response_chunk_size self.chunk_size = response_chunk_size
self.wait_for_stop_task: typing.Optional[asyncio.Task] = None self.wait_for_stop_task: typing.Optional[asyncio.Task] = None
self.partial_needs = partial_needs self.partial_event = partial_event
async def wait_for_stop(self): async def wait_for_stop(self):
await self.stop_event.wait() await self.stop_event.wait()
@ -120,11 +120,9 @@ class ReflectorServerProtocol(asyncio.Protocol):
needs = [blob.blob_hash needs = [blob.blob_hash
for blob in self.descriptor.blobs[:-1] for blob in self.descriptor.blobs[:-1]
if not self.blob_manager.get_blob(blob.blob_hash).get_is_verified()] if not self.blob_manager.get_blob(blob.blob_hash).get_is_verified()]
print(self.partial_needs, needs) if needs and not self.partial_event.is_set():
if needs and self.partial_needs:
needs = needs[:3] needs = needs[:3]
self.partial_needs = False self.partial_event.set()
print(self.partial_needs, needs)
self.send_response({"send_sd_blob": False, 'needed_blobs': needs}) self.send_response({"send_sd_blob": False, 'needed_blobs': needs})
return return
return return
@ -177,11 +175,12 @@ class ReflectorServer:
raise Exception("already running") raise Exception("already running")
async def _start_server(): async def _start_server():
proto = ReflectorServerProtocol( partial_event = asyncio.Event()
if not self.partial_needs:
partial_event.set()
server = await self.loop.create_server(lambda: ReflectorServerProtocol(
self.blob_manager, self.response_chunk_size, self.stop_event, self.incoming_event, self.blob_manager, self.response_chunk_size, self.stop_event, self.incoming_event,
self.not_incoming_event, self.partial_needs self.not_incoming_event, partial_event), interface, port)
)
server = await self.loop.create_server(lambda: proto, interface, port)
self.started_listening.set() self.started_listening.set()
self.stopped_listening.clear() self.stopped_listening.clear()
log.info("Reflector server listening on TCP %s:%i", interface, port) log.info("Reflector server listening on TCP %s:%i", interface, port)

View file

@ -60,6 +60,11 @@ class TestReflector(AsyncioTestCase):
self.assertEqual(0, self.stream.reflector_progress) self.assertEqual(0, self.stream.reflector_progress)
sent = await self.stream.upload_to_reflector('127.0.0.1', 5566) sent = await self.stream.upload_to_reflector('127.0.0.1', 5566)
self.assertEqual(100, self.stream.reflector_progress) self.assertEqual(100, self.stream.reflector_progress)
if partial_needs:
self.assertFalse(self.stream.is_fully_reflected)
send_more = await self.stream.upload_to_reflector('127.0.0.1', 5566)
self.assertGreater(0, len(send_more))
sent.extend(send_more)
self.assertSetEqual( self.assertSetEqual(
set(sent), set(sent),
set(map(lambda b: b.blob_hash, set(map(lambda b: b.blob_hash,