CommonClient: consolidate shutdown handling

This commit is contained in:
Fabian Dill 2021-11-21 02:02:40 +01:00
parent 3e40de72b2
commit a27d09f81a
3 changed files with 30 additions and 43 deletions

View File

@ -39,13 +39,13 @@ class ClientCommandProcessor(CommandProcessor):
def _cmd_connect(self, address: str = "") -> bool: def _cmd_connect(self, address: str = "") -> bool:
"""Connect to a MultiWorld Server""" """Connect to a MultiWorld Server"""
self.ctx.server_address = None self.ctx.server_address = None
asyncio.create_task(self.ctx.connect(address if address else None)) asyncio.create_task(self.ctx.connect(address if address else None), name="connecting")
return True return True
def _cmd_disconnect(self) -> bool: def _cmd_disconnect(self) -> bool:
"""Disconnect from a MultiWorld Server""" """Disconnect from a MultiWorld Server"""
self.ctx.server_address = None self.ctx.server_address = None
asyncio.create_task(self.ctx.disconnect()) asyncio.create_task(self.ctx.disconnect(), name="disconnecting")
return True return True
def _cmd_received(self) -> bool: def _cmd_received(self) -> bool:
@ -89,10 +89,10 @@ class ClientCommandProcessor(CommandProcessor):
else: else:
state = ClientStatus.CLIENT_CONNECTED state = ClientStatus.CLIENT_CONNECTED
self.output("Unreadied.") self.output("Unreadied.")
asyncio.create_task(self.ctx.send_msgs([{"cmd": "StatusUpdate", "status": state}])) asyncio.create_task(self.ctx.send_msgs([{"cmd": "StatusUpdate", "status": state}]), name="send StatusUpdate")
def default(self, raw: str): def default(self, raw: str):
asyncio.create_task(self.ctx.send_msgs([{"cmd": "Say", "text": raw}])) asyncio.create_task(self.ctx.send_msgs([{"cmd": "Say", "text": raw}]), name="send Say")
class CommonContext(): class CommonContext():
@ -149,7 +149,7 @@ class CommonContext():
self.set_getters(network_data_package) self.set_getters(network_data_package)
# execution # execution
self.keep_alive_task = asyncio.create_task(keep_alive(self)) self.keep_alive_task = asyncio.create_task(keep_alive(self), name="Bouncy")
@property @property
def total_locations(self) -> typing.Optional[int]: def total_locations(self) -> typing.Optional[int]:
@ -236,7 +236,7 @@ class CommonContext():
async def connect(self, address=None): async def connect(self, address=None):
await self.disconnect() await self.disconnect()
self.server_task = asyncio.create_task(server_loop(self, address)) self.server_task = asyncio.create_task(server_loop(self, address), name="server loop")
def on_print(self, args: dict): def on_print(self, args: dict):
logger.info(args["text"]) logger.info(args["text"])
@ -282,6 +282,18 @@ class CommonContext():
} }
}]) }])
async def shutdown(self):
self.server_address = None
if self.server and not self.server.socket.closed:
await self.server.socket.close()
if self.server_task:
await self.server_task
while self.input_requests > 0:
self.input_queue.put_nowait(None)
self.input_requests -= 1
self.keep_alive_task.cancel()
async def keep_alive(ctx: CommonContext, seconds_between_checks=100): async def keep_alive(ctx: CommonContext, seconds_between_checks=100):
"""some ISPs/network configurations drop TCP connections if no payload is sent (ignore TCP-keep-alive) """some ISPs/network configurations drop TCP connections if no payload is sent (ignore TCP-keep-alive)
@ -340,14 +352,14 @@ async def server_loop(ctx: CommonContext, address=None):
await ctx.connection_closed() await ctx.connection_closed()
if ctx.server_address: if ctx.server_address:
logger.info(f"... reconnecting in {ctx.current_reconnect_delay}s") logger.info(f"... reconnecting in {ctx.current_reconnect_delay}s")
asyncio.create_task(server_autoreconnect(ctx)) asyncio.create_task(server_autoreconnect(ctx), name="server auto reconnect")
ctx.current_reconnect_delay *= 2 ctx.current_reconnect_delay *= 2
async def server_autoreconnect(ctx: CommonContext): async def server_autoreconnect(ctx: CommonContext):
await asyncio.sleep(ctx.current_reconnect_delay) await asyncio.sleep(ctx.current_reconnect_delay)
if ctx.server_address and ctx.server_task is None: if ctx.server_address and ctx.server_task is None:
ctx.server_task = asyncio.create_task(server_loop(ctx)) ctx.server_task = asyncio.create_task(server_loop(ctx), name="server loop")
async def process_server_cmd(ctx: CommonContext, args: dict): async def process_server_cmd(ctx: CommonContext, args: dict):
@ -555,7 +567,7 @@ if __name__ == '__main__':
async def main(args): async def main(args):
ctx = TextContext(args.connect, args.password) ctx = TextContext(args.connect, args.password)
ctx.server_task = asyncio.create_task(server_loop(ctx), name="ServerLoop") ctx.server_task = asyncio.create_task(server_loop(ctx), name="server loop")
if gui_enabled: if gui_enabled:
input_task = None input_task = None
from kvui import TextManager from kvui import TextManager
@ -566,16 +578,7 @@ if __name__ == '__main__':
ui_task = None ui_task = None
await ctx.exit_event.wait() await ctx.exit_event.wait()
ctx.server_address = None await ctx.shutdown()
if ctx.server and not ctx.server.socket.closed:
await ctx.server.socket.close()
if ctx.server_task:
await ctx.server_task
while ctx.input_requests > 0:
ctx.input_queue.put_nowait(None)
ctx.input_requests -= 1
if ui_task: if ui_task:
await ui_task await ui_task

View File

@ -322,14 +322,7 @@ async def main(args):
await progression_watcher await progression_watcher
await factorio_server_task await factorio_server_task
if ctx.server and not ctx.server.socket.closed: await ctx.shutdown()
await ctx.server.socket.close()
if ctx.server_task:
await ctx.server_task
while ctx.input_requests > 0:
ctx.input_queue.put_nowait(None)
ctx.input_requests -= 1
if ui_task: if ui_task:
await ui_task await ui_task

View File

@ -73,7 +73,7 @@ class LttPCommandProcessor(ClientCommandProcessor):
pass pass
self.ctx.snes_reconnect_address = None self.ctx.snes_reconnect_address = None
asyncio.create_task(snes_connect(self.ctx, snes_address, snes_device_number)) asyncio.create_task(snes_connect(self.ctx, snes_address, snes_device_number), name="SNES Connect")
return True return True
def _cmd_snes_close(self) -> bool: def _cmd_snes_close(self) -> bool:
@ -1113,28 +1113,19 @@ async def main():
input_task = asyncio.create_task(console_loop(ctx), name="Input") input_task = asyncio.create_task(console_loop(ctx), name="Input")
ui_task = None ui_task = None
snes_connect_task = asyncio.create_task(snes_connect(ctx, ctx.snes_address)) snes_connect_task = asyncio.create_task(snes_connect(ctx, ctx.snes_address), name="SNES Connect")
watcher_task = asyncio.create_task(game_watcher(ctx), name="GameWatcher") watcher_task = asyncio.create_task(game_watcher(ctx), name="GameWatcher")
await ctx.exit_event.wait() await ctx.exit_event.wait()
if snes_connect_task:
snes_connect_task.cancel()
ctx.server_address = None ctx.server_address = None
ctx.snes_reconnect_address = None ctx.snes_reconnect_address = None
await watcher_task
if ctx.server and not ctx.server.socket.closed:
await ctx.server.socket.close()
if ctx.server_task:
await ctx.server_task
if ctx.snes_socket is not None and not ctx.snes_socket.closed: if ctx.snes_socket is not None and not ctx.snes_socket.closed:
await ctx.snes_socket.close() await ctx.snes_socket.close()
if snes_connect_task:
while ctx.input_requests > 0: snes_connect_task.cancel()
ctx.input_queue.put_nowait(None) await watcher_task
ctx.input_requests -= 1 await ctx.shutdown()
if ui_task: if ui_task:
await ui_task await ui_task