From da392239a0da6abbde1510084eb540dce3184afc Mon Sep 17 00:00:00 2001 From: Doug Hoskisson Date: Wed, 2 Nov 2022 07:51:35 -0700 Subject: [PATCH] MultiServer and clients: async coroutine starter in Utils.py (#1143) * async coroutine starter in Utils.py * refactor from static class to function * async_start docstring --- CommonClient.py | 12 ++++++------ FF1Client.py | 5 +++-- FactorioClient.py | 9 +++++---- MultiServer.py | 30 +++++++++++++++--------------- OoTClient.py | 9 +++++---- PokemonClient.py | 9 +++++---- SNIClient.py | 18 +++++++++--------- Utils.py | 24 +++++++++++++++++++++++- ZillionClient.py | 15 ++++++++------- kvui.py | 5 +++-- 10 files changed, 82 insertions(+), 54 deletions(-) diff --git a/CommonClient.py b/CommonClient.py index 3703a581..3752a129 100644 --- a/CommonClient.py +++ b/CommonClient.py @@ -20,7 +20,7 @@ if __name__ == "__main__": from MultiServer import CommandProcessor from NetUtils import Endpoint, decode, NetworkItem, encode, JSONtoTextParser, \ ClientStatus, Permission, NetworkSlot, RawJSONtoTextParser -from Utils import Version, stream_input +from Utils import Version, stream_input, async_start from worlds import network_data_package, AutoWorldRegister import os @@ -46,14 +46,14 @@ class ClientCommandProcessor(CommandProcessor): """Connect to a MultiWorld Server""" self.ctx.server_address = None self.ctx.username = None - asyncio.create_task(self.ctx.connect(address if address else None), name="connecting") + async_start(self.ctx.connect(address if address else None), name="connecting") return True def _cmd_disconnect(self) -> bool: """Disconnect from a MultiWorld Server""" self.ctx.server_address = None self.ctx.username = None - asyncio.create_task(self.ctx.disconnect(), name="disconnecting") + async_start(self.ctx.disconnect(), name="disconnecting") return True def _cmd_received(self) -> bool: @@ -116,12 +116,12 @@ class ClientCommandProcessor(CommandProcessor): else: state = ClientStatus.CLIENT_CONNECTED self.output("Unreadied.") - asyncio.create_task(self.ctx.send_msgs([{"cmd": "StatusUpdate", "status": state}]), name="send StatusUpdate") + async_start(self.ctx.send_msgs([{"cmd": "StatusUpdate", "status": state}]), name="send StatusUpdate") def default(self, raw: str): raw = self.ctx.on_user_say(raw) if raw: - asyncio.create_task(self.ctx.send_msgs([{"cmd": "Say", "text": raw}]), name="send Say") + async_start(self.ctx.send_msgs([{"cmd": "Say", "text": raw}]), name="send Say") class CommonContext: @@ -562,7 +562,7 @@ async def server_loop(ctx: CommonContext, address: typing.Optional[str] = None) await ctx.connection_closed() if ctx.server_address: logger.info(f"... reconnecting in {ctx.current_reconnect_delay}s") - asyncio.create_task(server_autoreconnect(ctx), name="server auto reconnect") + async_start(server_autoreconnect(ctx), name="server auto reconnect") ctx.current_reconnect_delay *= 2 diff --git a/FF1Client.py b/FF1Client.py index 5a56d0dd..83c24846 100644 --- a/FF1Client.py +++ b/FF1Client.py @@ -7,6 +7,7 @@ from typing import List import Utils +from Utils import async_start from CommonClient import CommonContext, server_loop, gui_enabled, ClientCommandProcessor, logger, \ get_base_parser @@ -69,7 +70,7 @@ class FF1Context(CommonContext): def on_package(self, cmd: str, args: dict): if cmd == 'Connected': - asyncio.create_task(parse_locations(self.locations_array, self, True)) + async_start(parse_locations(self.locations_array, self, True)) elif cmd == 'Print': msg = args['text'] if ': !' not in msg: @@ -180,7 +181,7 @@ async def nes_sync_task(ctx: FF1Context): # print(data_decoded) if ctx.game is not None and 'locations' in data_decoded: # Not just a keep alive ping, parse - asyncio.create_task(parse_locations(data_decoded['locations'], ctx, False)) + async_start(parse_locations(data_decoded['locations'], ctx, False)) if not ctx.auth: ctx.auth = ''.join([chr(i) for i in data_decoded['playerName'] if i != 0]) if ctx.auth == '': diff --git a/FactorioClient.py b/FactorioClient.py index 12ec2291..73b4ad13 100644 --- a/FactorioClient.py +++ b/FactorioClient.py @@ -25,6 +25,7 @@ if __name__ == "__main__": from CommonClient import CommonContext, server_loop, ClientCommandProcessor, logger, gui_enabled, get_base_parser from MultiServer import mark_raw from NetUtils import NetworkItem, ClientStatus, JSONtoTextParser, JSONMessagePart +from Utils import async_start from worlds.factorio import Factorio @@ -124,7 +125,7 @@ class FactorioContext(CommonContext): self.rcon_client.send_commands({item_name: f'/ap-get-technology ap-{item_name}-\t-1' for item_name in args["checked_locations"]}) if cmd == "Connected" and self.energy_link_increment: - asyncio.create_task(self.send_msgs([{ + async_start(self.send_msgs([{ "cmd": "SetNotify", "keys": ["EnergyLink"] }])) elif cmd == "SetReply": @@ -232,7 +233,7 @@ async def game_watcher(ctx: FactorioContext): if death_link_tick != ctx.death_link_tick: ctx.death_link_tick = death_link_tick if "DeathLink" in ctx.tags: - asyncio.create_task(ctx.send_death()) + async_start(ctx.send_death()) if ctx.energy_link_increment: in_world_bridges = data["energy_bridges"] if in_world_bridges: @@ -240,7 +241,7 @@ async def game_watcher(ctx: FactorioContext): if in_world_energy < (ctx.energy_link_increment * in_world_bridges): # attempt to refill ctx.last_deplete = time.time() - asyncio.create_task(ctx.send_msgs([{ + async_start(ctx.send_msgs([{ "cmd": "Set", "key": "EnergyLink", "operations": [{"operation": "add", "value": -ctx.energy_link_increment * in_world_bridges}, {"operation": "max", "value": 0}], @@ -250,7 +251,7 @@ async def game_watcher(ctx: FactorioContext): elif in_world_energy > (in_world_bridges * ctx.energy_link_increment * 5) - \ ctx.energy_link_increment*in_world_bridges: value = ctx.energy_link_increment * in_world_bridges - asyncio.create_task(ctx.send_msgs([{ + async_start(ctx.send_msgs([{ "cmd": "Set", "key": "EnergyLink", "operations": [{"operation": "add", "value": value}] }])) diff --git a/MultiServer.py b/MultiServer.py index ebd46713..fc19f8fe 100644 --- a/MultiServer.py +++ b/MultiServer.py @@ -31,7 +31,7 @@ except ImportError: import NetUtils import Utils -from Utils import version_tuple, restricted_loads, Version +from Utils import version_tuple, restricted_loads, Version, async_start from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \ SlotType @@ -273,16 +273,16 @@ class Context: def broadcast_all(self, msgs: typing.List[dict]): msgs = self.dumper(msgs) endpoints = (endpoint for endpoint in self.endpoints if endpoint.auth) - asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs)) + async_start(self.broadcast_send_encoded_msgs(endpoints, msgs)) def broadcast_team(self, team: int, msgs: typing.List[dict]): msgs = self.dumper(msgs) endpoints = (endpoint for endpoint in itertools.chain.from_iterable(self.clients[team].values())) - asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs)) + async_start(self.broadcast_send_encoded_msgs(endpoints, msgs)) def broadcast(self, endpoints: typing.Iterable[Client], msgs: typing.List[dict]): msgs = self.dumper(msgs) - asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs)) + async_start(self.broadcast_send_encoded_msgs(endpoints, msgs)) async def disconnect(self, endpoint: Client): if endpoint in self.endpoints: @@ -302,18 +302,18 @@ class Context: return logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text)) if client.version >= print_command_compatability_threshold: - asyncio.create_task(self.send_msgs(client, [{"cmd": "PrintJSON", "data": [{ "text": text }]}])) + async_start(self.send_msgs(client, [{"cmd": "PrintJSON", "data": [{ "text": text }]}])) else: - asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}])) + async_start(self.send_msgs(client, [{"cmd": "Print", "text": text}])) def notify_client_multiple(self, client: Client, texts: typing.List[str]): if not client.auth: return if client.version >= print_command_compatability_threshold: - asyncio.create_task(self.send_msgs(client, + async_start(self.send_msgs(client, [{"cmd": "PrintJSON", "data": [{ "text": text }]} for text in texts])) else: - asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts])) + async_start(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts])) # loading @@ -627,7 +627,7 @@ def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint], onl continue client_hints = [datum[1] for datum in sorted(hint_data, key=lambda x: x[0].finding_player == slot)] for client in clients: - asyncio.create_task(ctx.send_msgs(client, client_hints)) + async_start(ctx.send_msgs(client, client_hints)) def update_aliases(ctx: Context, team: int): @@ -636,7 +636,7 @@ def update_aliases(ctx: Context, team: int): for clients in ctx.clients[team].values(): for client in clients: - asyncio.create_task(ctx.send_encoded_msgs(client, cmd)) + async_start(ctx.send_encoded_msgs(client, cmd)) async def server(websocket, path: str = "/", ctx: Context = None): @@ -814,7 +814,7 @@ def send_new_items(ctx: Context): items = get_received_items(ctx, team, slot, client.remote_items) if len(start_inventory) + len(items) > client.send_index: first_new_item = max(0, client.send_index - len(start_inventory)) - asyncio.create_task(ctx.send_msgs(client, [{ + async_start(ctx.send_msgs(client, [{ "cmd": "ReceivedItems", "index": client.send_index, "items": start_inventory[client.send_index:] + items[first_new_item:]}])) @@ -1090,7 +1090,7 @@ class CommonCommandProcessor(CommandProcessor): timer = int(seconds, 10) except ValueError: timer = 10 - asyncio.create_task(countdown(self.ctx, timer)) + async_start(countdown(self.ctx, timer)) return True def _cmd_options(self): @@ -1771,7 +1771,7 @@ class ServerCommandProcessor(CommonCommandProcessor): def _cmd_exit(self) -> bool: """Shutdown the server""" - asyncio.create_task(self.ctx.server.ws_server._close()) + async_start(self.ctx.server.ws_server._close()) if self.ctx.shutdown_task: self.ctx.shutdown_task.cancel() self.ctx.exit_event.set() @@ -2084,7 +2084,7 @@ async def auto_shutdown(ctx, to_cancel=None): await asyncio.sleep(ctx.auto_shutdown) while not ctx.exit_event.is_set(): if not ctx.client_activity_timers.values(): - asyncio.create_task(ctx.server.ws_server._close()) + async_start(ctx.server.ws_server._close()) ctx.exit_event.set() if to_cancel: for task in to_cancel: @@ -2095,7 +2095,7 @@ async def auto_shutdown(ctx, to_cancel=None): delta = datetime.datetime.now(datetime.timezone.utc) - newest_activity seconds = ctx.auto_shutdown - delta.total_seconds() if seconds < 0: - asyncio.create_task(ctx.server.ws_server._close()) + async_start(ctx.server.ws_server._close()) ctx.exit_event.set() if to_cancel: for task in to_cancel: diff --git a/OoTClient.py b/OoTClient.py index 22420e0e..de428c73 100644 --- a/OoTClient.py +++ b/OoTClient.py @@ -9,6 +9,7 @@ from asyncio import StreamReader, StreamWriter from CommonClient import CommonContext, server_loop, gui_enabled, \ ClientCommandProcessor, logger, get_base_parser import Utils +from Utils import async_start from worlds import network_data_package from worlds.oot.Rom import Rom, compress_rom_file from worlds.oot.N64Patch import apply_patch_file @@ -69,7 +70,7 @@ class OoTCommandProcessor(ClientCommandProcessor): if isinstance(self.ctx, OoTContext): self.ctx.deathlink_client_override = True self.ctx.deathlink_enabled = not self.ctx.deathlink_enabled - asyncio.create_task(self.ctx.update_death_link(self.ctx.deathlink_enabled), name="Update Deathlink") + async_start(self.ctx.update_death_link(self.ctx.deathlink_enabled), name="Update Deathlink") class OoTContext(CommonContext): @@ -203,7 +204,7 @@ async def n64_sync_task(ctx: OoTContext): if reported_version >= script_version: if ctx.game is not None and 'locations' in data_decoded: # Not just a keep alive ping, parse - asyncio.create_task(parse_payload(data_decoded, ctx, False)) + async_start(parse_payload(data_decoded, ctx, False)) if not ctx.auth: ctx.auth = data_decoded['playerName'] if ctx.awaiting_rom: @@ -279,7 +280,7 @@ async def patch_and_run_game(apz5_file): os.chdir(data_path("Compress")) compress_rom_file(decomp_path, comp_path) os.remove(decomp_path) - asyncio.create_task(run_game(comp_path)) + async_start(run_game(comp_path)) if __name__ == '__main__': @@ -295,7 +296,7 @@ if __name__ == '__main__': if args.apz5_file: logger.info("APZ5 file supplied, beginning patching process...") - asyncio.create_task(patch_and_run_game(args.apz5_file)) + async_start(patch_and_run_game(args.apz5_file)) ctx = OoTContext(args.connect, args.password) ctx.server_task = asyncio.create_task(server_loop(ctx), name="Server Loop") diff --git a/PokemonClient.py b/PokemonClient.py index 33e902a3..f8d26c0d 100644 --- a/PokemonClient.py +++ b/PokemonClient.py @@ -11,6 +11,7 @@ from typing import List import Utils +from Utils import async_start from CommonClient import CommonContext, server_loop, gui_enabled, ClientCommandProcessor, logger, \ get_base_parser @@ -185,7 +186,7 @@ async def gb_sync_task(ctx: GBContext): if 'locations' in data_decoded and ctx.game and ctx.gb_status == CONNECTION_CONNECTED_STATUS \ and not error_status and ctx.auth: # Not just a keep alive ping, parse - asyncio.create_task(parse_locations(data_decoded['locations'], ctx)) + async_start(parse_locations(data_decoded['locations'], ctx)) except asyncio.TimeoutError: logger.debug("Read Timed Out, Reconnecting") error_status = CONNECTION_TIMING_OUT_STATUS @@ -265,7 +266,7 @@ async def patch_and_run_game(game_version, patch_file, ctx): with open(comp_path, "wb") as patched_rom_file: patched_rom_file.write(patched_rom_data) - asyncio.create_task(run_game(comp_path)) + async_start(run_game(comp_path)) else: msg = "Patch supplied was not generated with the same base patch version as this client. Patching failed." logger.warning(msg) @@ -295,10 +296,10 @@ if __name__ == '__main__': ext = args.patch_file.split(".")[len(args.patch_file.split(".")) - 1].lower() if ext == "apred": logger.info("APRED file supplied, beginning patching process...") - asyncio.create_task(patch_and_run_game("red", args.patch_file, ctx)) + async_start(patch_and_run_game("red", args.patch_file, ctx)) elif ext == "apblue": logger.info("APBLUE file supplied, beginning patching process...") - asyncio.create_task(patch_and_run_game("blue", args.patch_file, ctx)) + async_start(patch_and_run_game("blue", args.patch_file, ctx)) else: logger.warning(f"Unknown patch file extension {ext}") diff --git a/SNIClient.py b/SNIClient.py index 03e1ff57..b89c588d 100644 --- a/SNIClient.py +++ b/SNIClient.py @@ -18,7 +18,7 @@ from json import loads, dumps from CommonClient import CommonContext, server_loop, ClientCommandProcessor, gui_enabled, get_base_parser import Utils - +from Utils import async_start from MultiServer import mark_raw if typing.TYPE_CHECKING: from worlds.AutoSNIClient import SNIClient @@ -84,7 +84,7 @@ class SNIClientCommandProcessor(ClientCommandProcessor): """Close connection to a currently connected snes""" self.ctx.snes_reconnect_address = None if self.ctx.snes_socket is not None and not self.ctx.snes_socket.closed: - asyncio.create_task(self.ctx.snes_socket.close()) + async_start(self.ctx.snes_socket.close()) return True else: return False @@ -96,7 +96,7 @@ class SNIClientCommandProcessor(ClientCommandProcessor): # self.output("No attached SNES Device.") # return False # snes_buffered_write(self.ctx, int(address, 16), bytes([int(data)])) - # asyncio.create_task(snes_flush_writes(self.ctx)) + # async_start(snes_flush_writes(self.ctx)) # self.output("Data Sent") # return True @@ -167,7 +167,7 @@ class SNIContext(CommonContext): def event_invalid_slot(self) -> typing.NoReturn: if self.snes_socket is not None and not self.snes_socket.closed: - asyncio.create_task(self.snes_socket.close()) + async_start(self.snes_socket.close()) raise Exception("Invalid ROM detected, " "please verify that you have loaded the correct rom and reconnect your snes (/snes)") @@ -230,7 +230,7 @@ class SNIContext(CommonContext): # since the player will likely need that item. # Once the games handled by SNIClient gets made to be remote items, # this will no longer be needed. - asyncio.create_task(self.send_msgs([{"cmd": "LocationScouts", "locations": list(new_locations)}])) + async_start(self.send_msgs([{"cmd": "LocationScouts", "locations": list(new_locations)}])) def run_gui(self) -> None: from kvui import GameManager @@ -443,7 +443,7 @@ async def snes_connect(ctx: SNIContext, address: str, deviceIndex: int = -1) -> snes_logger.error("Error connecting to snes (%s)" % e) else: snes_logger.error(f"Error connecting to snes, attempt again in {_global_snes_reconnect_delay}s") - asyncio.create_task(snes_autoreconnect(ctx)) + async_start(snes_autoreconnect(ctx)) _global_snes_reconnect_delay *= 2 else: @@ -488,7 +488,7 @@ async def snes_recv_loop(ctx: SNIContext) -> None: if ctx.snes_reconnect_address: snes_logger.info(f"...reconnecting in {_global_snes_reconnect_delay}s") - asyncio.create_task(snes_autoreconnect(ctx)) + async_start(snes_autoreconnect(ctx)) async def snes_read(ctx: SNIContext, address: int, size: int) -> typing.Optional[bytes]: @@ -674,9 +674,9 @@ async def main() -> None: elif args.diff_file.endswith(".aplttp"): from worlds.alttp.Client import get_alttp_settings adjustedromfile, adjusted = get_alttp_settings(romfile) - asyncio.create_task(run_game(adjustedromfile if adjusted else romfile)) + async_start(run_game(adjustedromfile if adjusted else romfile)) else: - asyncio.create_task(run_game(romfile)) + async_start(run_game(romfile)) ctx = SNIContext(args.snes, args.connect, args.password) if ctx.server_task is None: diff --git a/Utils.py b/Utils.py index 447d39d2..a88611b6 100644 --- a/Utils.py +++ b/Utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import typing import builtins import os @@ -11,7 +12,7 @@ import io import collections import importlib import logging -from typing import BinaryIO +from typing import BinaryIO, ClassVar, Coroutine, Optional, Set from yaml import load, load_all, dump, SafeLoader @@ -650,3 +651,24 @@ def read_snes_rom(stream: BinaryIO, strip_header: bool = True) -> bytearray: if strip_header and len(buffer) % 0x400 == 0x200: return buffer[0x200:] return buffer + + +_faf_tasks: "Set[asyncio.Task[None]]" = set() + + +def async_start(co: Coroutine[None, None, None], name: Optional[str] = None) -> None: + """ + Use this to start a task when you don't keep a reference to it or immediately await it, + to prevent early garbage collection. "fire-and-forget" + """ + # https://docs.python.org/3.10/library/asyncio-task.html#asyncio.create_task + # Python docs: + # ``` + # Important: Save a reference to the result of [asyncio.create_task], + # to avoid a task disappearing mid-execution. + # ``` + # This implementation follows the pattern given in that documentation. + + task = asyncio.create_task(co, name=name) + _faf_tasks.add(task) + task.add_done_callback(_faf_tasks.discard) diff --git a/ZillionClient.py b/ZillionClient.py index e2ce697c..d518f5ef 100644 --- a/ZillionClient.py +++ b/ZillionClient.py @@ -8,6 +8,7 @@ from CommonClient import CommonContext, server_loop, gui_enabled, \ ClientCommandProcessor, logger, get_base_parser from NetUtils import ClientStatus import Utils +from Utils import async_start import colorama # type: ignore @@ -263,7 +264,7 @@ class ZillionContext(CommonContext): "cmd": "Get", "keys": [f"zillion-{self.auth}-doors"] } - asyncio.create_task(self.send_msgs([payload])) + async_start(self.send_msgs([payload])) elif cmd == "Retrieved": if "keys" not in args: logger.warning(f"invalid Retrieved packet to ZillionClient: {args}") @@ -304,7 +305,7 @@ class ZillionContext(CommonContext): self.ap_local_count += 1 n_locations = len(self.missing_locations) + len(self.checked_locations) - 1 # -1 to ignore win logger.info(f'New Check: {loc_name} ({self.ap_local_count}/{n_locations})') - asyncio.create_task(self.send_msgs([ + async_start(self.send_msgs([ {"cmd": 'LocationChecks', "locations": [server_id]} ])) else: @@ -312,10 +313,10 @@ class ZillionContext(CommonContext): # because all the key words are local and unwatched by the server. logger.debug(f"DEBUG: {loc_name} not in missing") elif isinstance(event_from_game, events.DeathEventFromGame): - asyncio.create_task(self.send_death()) + async_start(self.send_death()) elif isinstance(event_from_game, events.WinEventFromGame): if not self.finished_game: - asyncio.create_task(self.send_msgs([ + async_start(self.send_msgs([ {"cmd": "StatusUpdate", "status": ClientStatus.CLIENT_GOAL} ])) self.finished_game = True @@ -327,7 +328,7 @@ class ZillionContext(CommonContext): "key": f"zillion-{self.auth}-doors", "operations": [{"operation": "replace", "value": doors_b64}] } - asyncio.create_task(self.send_msgs([payload])) + async_start(self.send_msgs([payload])) else: logger.warning(f"WARNING: unhandled event from game {event_from_game}") @@ -410,7 +411,7 @@ async def zillion_sync_task(ctx: ZillionContext) -> None: ctx.next_item = 0 ctx.ap_local_count = len(ctx.checked_locations) else: # no slot data yet - asyncio.create_task(ctx.send_connect()) + async_start(ctx.send_connect()) log_no_spam("logging in to server...") await asyncio.wait(( ctx.got_slot_data.wait(), @@ -434,7 +435,7 @@ async def zillion_sync_task(ctx: ZillionContext) -> None: memory.reset_game_state() ctx.auth = name - asyncio.create_task(ctx.connect()) + async_start(ctx.connect()) await asyncio.wait(( ctx.got_room_info.wait(), ctx.exit_event.wait(), diff --git a/kvui.py b/kvui.py index 94e0a091..071c07e6 100644 --- a/kvui.py +++ b/kvui.py @@ -49,6 +49,7 @@ fade_in_animation = Animation(opacity=0, duration=0) + Animation(opacity=1, dura from NetUtils import JSONtoTextParser, JSONMessagePart, SlotType +from Utils import async_start if typing.TYPE_CHECKING: import CommonClient @@ -427,9 +428,9 @@ class GameManager(App): if self.ctx.server: self.ctx.server_address = None self.ctx.username = None - asyncio.create_task(self.ctx.disconnect()) + async_start(self.ctx.disconnect()) else: - asyncio.create_task(self.ctx.connect(self.server_connect_bar.text.replace("/connect ", ""))) + async_start(self.ctx.connect(self.server_connect_bar.text.replace("/connect ", ""))) def on_stop(self): # "kill" input tasks