MultiServer and clients: async coroutine starter in Utils.py (#1143)
* async coroutine starter in Utils.py * refactor from static class to function * async_start docstring
This commit is contained in:
		
							parent
							
								
									a6e1e14fee
								
							
						
					
					
						commit
						da392239a0
					
				|  | @ -20,7 +20,7 @@ if __name__ == "__main__": | |||
| from MultiServer import CommandProcessor | ||||
| from NetUtils import Endpoint, decode, NetworkItem, encode, JSONtoTextParser, \ | ||||
|     ClientStatus, Permission, NetworkSlot, RawJSONtoTextParser | ||||
| from Utils import Version, stream_input | ||||
| from Utils import Version, stream_input, async_start | ||||
| from worlds import network_data_package, AutoWorldRegister | ||||
| import os | ||||
| 
 | ||||
|  | @ -46,14 +46,14 @@ class ClientCommandProcessor(CommandProcessor): | |||
|         """Connect to a MultiWorld Server""" | ||||
|         self.ctx.server_address = None | ||||
|         self.ctx.username = None | ||||
|         asyncio.create_task(self.ctx.connect(address if address else None), name="connecting") | ||||
|         async_start(self.ctx.connect(address if address else None), name="connecting") | ||||
|         return True | ||||
| 
 | ||||
|     def _cmd_disconnect(self) -> bool: | ||||
|         """Disconnect from a MultiWorld Server""" | ||||
|         self.ctx.server_address = None | ||||
|         self.ctx.username = None | ||||
|         asyncio.create_task(self.ctx.disconnect(), name="disconnecting") | ||||
|         async_start(self.ctx.disconnect(), name="disconnecting") | ||||
|         return True | ||||
| 
 | ||||
|     def _cmd_received(self) -> bool: | ||||
|  | @ -116,12 +116,12 @@ class ClientCommandProcessor(CommandProcessor): | |||
|         else: | ||||
|             state = ClientStatus.CLIENT_CONNECTED | ||||
|             self.output("Unreadied.") | ||||
|         asyncio.create_task(self.ctx.send_msgs([{"cmd": "StatusUpdate", "status": state}]), name="send StatusUpdate") | ||||
|         async_start(self.ctx.send_msgs([{"cmd": "StatusUpdate", "status": state}]), name="send StatusUpdate") | ||||
| 
 | ||||
|     def default(self, raw: str): | ||||
|         raw = self.ctx.on_user_say(raw) | ||||
|         if raw: | ||||
|             asyncio.create_task(self.ctx.send_msgs([{"cmd": "Say", "text": raw}]), name="send Say") | ||||
|             async_start(self.ctx.send_msgs([{"cmd": "Say", "text": raw}]), name="send Say") | ||||
| 
 | ||||
| 
 | ||||
| class CommonContext: | ||||
|  | @ -562,7 +562,7 @@ async def server_loop(ctx: CommonContext, address: typing.Optional[str] = None) | |||
|         await ctx.connection_closed() | ||||
|         if ctx.server_address: | ||||
|             logger.info(f"... reconnecting in {ctx.current_reconnect_delay}s") | ||||
|             asyncio.create_task(server_autoreconnect(ctx), name="server auto reconnect") | ||||
|             async_start(server_autoreconnect(ctx), name="server auto reconnect") | ||||
|         ctx.current_reconnect_delay *= 2 | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -7,6 +7,7 @@ from typing import List | |||
| 
 | ||||
| 
 | ||||
| import Utils | ||||
| from Utils import async_start | ||||
| from CommonClient import CommonContext, server_loop, gui_enabled, ClientCommandProcessor, logger, \ | ||||
|     get_base_parser | ||||
| 
 | ||||
|  | @ -69,7 +70,7 @@ class FF1Context(CommonContext): | |||
| 
 | ||||
|     def on_package(self, cmd: str, args: dict): | ||||
|         if cmd == 'Connected': | ||||
|             asyncio.create_task(parse_locations(self.locations_array, self, True)) | ||||
|             async_start(parse_locations(self.locations_array, self, True)) | ||||
|         elif cmd == 'Print': | ||||
|             msg = args['text'] | ||||
|             if ': !' not in msg: | ||||
|  | @ -180,7 +181,7 @@ async def nes_sync_task(ctx: FF1Context): | |||
|                     # print(data_decoded) | ||||
|                     if ctx.game is not None and 'locations' in data_decoded: | ||||
|                         # Not just a keep alive ping, parse | ||||
|                         asyncio.create_task(parse_locations(data_decoded['locations'], ctx, False)) | ||||
|                         async_start(parse_locations(data_decoded['locations'], ctx, False)) | ||||
|                     if not ctx.auth: | ||||
|                         ctx.auth = ''.join([chr(i) for i in data_decoded['playerName'] if i != 0]) | ||||
|                         if ctx.auth == '': | ||||
|  |  | |||
|  | @ -25,6 +25,7 @@ if __name__ == "__main__": | |||
| from CommonClient import CommonContext, server_loop, ClientCommandProcessor, logger, gui_enabled, get_base_parser | ||||
| from MultiServer import mark_raw | ||||
| from NetUtils import NetworkItem, ClientStatus, JSONtoTextParser, JSONMessagePart | ||||
| from Utils import async_start | ||||
| 
 | ||||
| from worlds.factorio import Factorio | ||||
| 
 | ||||
|  | @ -124,7 +125,7 @@ class FactorioContext(CommonContext): | |||
|                 self.rcon_client.send_commands({item_name: f'/ap-get-technology ap-{item_name}-\t-1' for | ||||
|                                                 item_name in args["checked_locations"]}) | ||||
|             if cmd == "Connected" and self.energy_link_increment: | ||||
|                 asyncio.create_task(self.send_msgs([{ | ||||
|                 async_start(self.send_msgs([{ | ||||
|                     "cmd": "SetNotify", "keys": ["EnergyLink"] | ||||
|                 }])) | ||||
|         elif cmd == "SetReply": | ||||
|  | @ -232,7 +233,7 @@ async def game_watcher(ctx: FactorioContext): | |||
|                     if death_link_tick != ctx.death_link_tick: | ||||
|                         ctx.death_link_tick = death_link_tick | ||||
|                         if "DeathLink" in ctx.tags: | ||||
|                             asyncio.create_task(ctx.send_death()) | ||||
|                             async_start(ctx.send_death()) | ||||
|                     if ctx.energy_link_increment: | ||||
|                         in_world_bridges = data["energy_bridges"] | ||||
|                         if in_world_bridges: | ||||
|  | @ -240,7 +241,7 @@ async def game_watcher(ctx: FactorioContext): | |||
|                             if in_world_energy < (ctx.energy_link_increment * in_world_bridges): | ||||
|                                 # attempt to refill | ||||
|                                 ctx.last_deplete = time.time() | ||||
|                                 asyncio.create_task(ctx.send_msgs([{ | ||||
|                                 async_start(ctx.send_msgs([{ | ||||
|                                     "cmd": "Set", "key": "EnergyLink", "operations": | ||||
|                                         [{"operation": "add", "value": -ctx.energy_link_increment * in_world_bridges}, | ||||
|                                          {"operation": "max", "value": 0}], | ||||
|  | @ -250,7 +251,7 @@ async def game_watcher(ctx: FactorioContext): | |||
|                             elif in_world_energy > (in_world_bridges * ctx.energy_link_increment * 5) - \ | ||||
|                                 ctx.energy_link_increment*in_world_bridges: | ||||
|                                 value = ctx.energy_link_increment * in_world_bridges | ||||
|                                 asyncio.create_task(ctx.send_msgs([{ | ||||
|                                 async_start(ctx.send_msgs([{ | ||||
|                                     "cmd": "Set", "key": "EnergyLink", "operations": | ||||
|                                         [{"operation": "add", "value": value}] | ||||
|                                 }])) | ||||
|  |  | |||
|  | @ -31,7 +31,7 @@ except ImportError: | |||
| 
 | ||||
| import NetUtils | ||||
| import Utils | ||||
| from Utils import version_tuple, restricted_loads, Version | ||||
| from Utils import version_tuple, restricted_loads, Version, async_start | ||||
| from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \ | ||||
|     SlotType | ||||
| 
 | ||||
|  | @ -273,16 +273,16 @@ class Context: | |||
|     def broadcast_all(self, msgs: typing.List[dict]): | ||||
|         msgs = self.dumper(msgs) | ||||
|         endpoints = (endpoint for endpoint in self.endpoints if endpoint.auth) | ||||
|         asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs)) | ||||
|         async_start(self.broadcast_send_encoded_msgs(endpoints, msgs)) | ||||
| 
 | ||||
|     def broadcast_team(self, team: int, msgs: typing.List[dict]): | ||||
|         msgs = self.dumper(msgs) | ||||
|         endpoints = (endpoint for endpoint in itertools.chain.from_iterable(self.clients[team].values())) | ||||
|         asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs)) | ||||
|         async_start(self.broadcast_send_encoded_msgs(endpoints, msgs)) | ||||
| 
 | ||||
|     def broadcast(self, endpoints: typing.Iterable[Client], msgs: typing.List[dict]): | ||||
|         msgs = self.dumper(msgs) | ||||
|         asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs)) | ||||
|         async_start(self.broadcast_send_encoded_msgs(endpoints, msgs)) | ||||
| 
 | ||||
|     async def disconnect(self, endpoint: Client): | ||||
|         if endpoint in self.endpoints: | ||||
|  | @ -302,18 +302,18 @@ class Context: | |||
|             return | ||||
|         logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text)) | ||||
|         if client.version >= print_command_compatability_threshold: | ||||
|             asyncio.create_task(self.send_msgs(client, [{"cmd": "PrintJSON", "data": [{ "text": text }]}])) | ||||
|             async_start(self.send_msgs(client, [{"cmd": "PrintJSON", "data": [{ "text": text }]}])) | ||||
|         else: | ||||
|             asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}])) | ||||
|             async_start(self.send_msgs(client, [{"cmd": "Print", "text": text}])) | ||||
| 
 | ||||
|     def notify_client_multiple(self, client: Client, texts: typing.List[str]): | ||||
|         if not client.auth: | ||||
|             return | ||||
|         if client.version >= print_command_compatability_threshold: | ||||
|             asyncio.create_task(self.send_msgs(client,  | ||||
|             async_start(self.send_msgs(client,  | ||||
|                 [{"cmd": "PrintJSON", "data": [{ "text": text }]} for text in texts])) | ||||
|         else: | ||||
|             asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts])) | ||||
|             async_start(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts])) | ||||
| 
 | ||||
|     # loading | ||||
| 
 | ||||
|  | @ -627,7 +627,7 @@ def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint], onl | |||
|             continue | ||||
|         client_hints = [datum[1] for datum in sorted(hint_data, key=lambda x: x[0].finding_player == slot)] | ||||
|         for client in clients: | ||||
|             asyncio.create_task(ctx.send_msgs(client, client_hints)) | ||||
|             async_start(ctx.send_msgs(client, client_hints)) | ||||
| 
 | ||||
| 
 | ||||
| def update_aliases(ctx: Context, team: int): | ||||
|  | @ -636,7 +636,7 @@ def update_aliases(ctx: Context, team: int): | |||
| 
 | ||||
|     for clients in ctx.clients[team].values(): | ||||
|         for client in clients: | ||||
|             asyncio.create_task(ctx.send_encoded_msgs(client, cmd)) | ||||
|             async_start(ctx.send_encoded_msgs(client, cmd)) | ||||
| 
 | ||||
| 
 | ||||
| async def server(websocket, path: str = "/", ctx: Context = None): | ||||
|  | @ -814,7 +814,7 @@ def send_new_items(ctx: Context): | |||
|                 items = get_received_items(ctx, team, slot, client.remote_items) | ||||
|                 if len(start_inventory) + len(items) > client.send_index: | ||||
|                     first_new_item = max(0, client.send_index - len(start_inventory)) | ||||
|                     asyncio.create_task(ctx.send_msgs(client, [{ | ||||
|                     async_start(ctx.send_msgs(client, [{ | ||||
|                         "cmd": "ReceivedItems", | ||||
|                         "index": client.send_index, | ||||
|                         "items": start_inventory[client.send_index:] + items[first_new_item:]}])) | ||||
|  | @ -1090,7 +1090,7 @@ class CommonCommandProcessor(CommandProcessor): | |||
|             timer = int(seconds, 10) | ||||
|         except ValueError: | ||||
|             timer = 10 | ||||
|         asyncio.create_task(countdown(self.ctx, timer)) | ||||
|         async_start(countdown(self.ctx, timer)) | ||||
|         return True | ||||
| 
 | ||||
|     def _cmd_options(self): | ||||
|  | @ -1771,7 +1771,7 @@ class ServerCommandProcessor(CommonCommandProcessor): | |||
| 
 | ||||
|     def _cmd_exit(self) -> bool: | ||||
|         """Shutdown the server""" | ||||
|         asyncio.create_task(self.ctx.server.ws_server._close()) | ||||
|         async_start(self.ctx.server.ws_server._close()) | ||||
|         if self.ctx.shutdown_task: | ||||
|             self.ctx.shutdown_task.cancel() | ||||
|         self.ctx.exit_event.set() | ||||
|  | @ -2084,7 +2084,7 @@ async def auto_shutdown(ctx, to_cancel=None): | |||
|     await asyncio.sleep(ctx.auto_shutdown) | ||||
|     while not ctx.exit_event.is_set(): | ||||
|         if not ctx.client_activity_timers.values(): | ||||
|             asyncio.create_task(ctx.server.ws_server._close()) | ||||
|             async_start(ctx.server.ws_server._close()) | ||||
|             ctx.exit_event.set() | ||||
|             if to_cancel: | ||||
|                 for task in to_cancel: | ||||
|  | @ -2095,7 +2095,7 @@ async def auto_shutdown(ctx, to_cancel=None): | |||
|             delta = datetime.datetime.now(datetime.timezone.utc) - newest_activity | ||||
|             seconds = ctx.auto_shutdown - delta.total_seconds() | ||||
|             if seconds < 0: | ||||
|                 asyncio.create_task(ctx.server.ws_server._close()) | ||||
|                 async_start(ctx.server.ws_server._close()) | ||||
|                 ctx.exit_event.set() | ||||
|                 if to_cancel: | ||||
|                     for task in to_cancel: | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ from asyncio import StreamReader, StreamWriter | |||
| from CommonClient import CommonContext, server_loop, gui_enabled, \ | ||||
|     ClientCommandProcessor, logger, get_base_parser | ||||
| import Utils | ||||
| from Utils import async_start | ||||
| from worlds import network_data_package | ||||
| from worlds.oot.Rom import Rom, compress_rom_file | ||||
| from worlds.oot.N64Patch import apply_patch_file | ||||
|  | @ -69,7 +70,7 @@ class OoTCommandProcessor(ClientCommandProcessor): | |||
|         if isinstance(self.ctx, OoTContext): | ||||
|             self.ctx.deathlink_client_override = True | ||||
|             self.ctx.deathlink_enabled = not self.ctx.deathlink_enabled | ||||
|             asyncio.create_task(self.ctx.update_death_link(self.ctx.deathlink_enabled), name="Update Deathlink") | ||||
|             async_start(self.ctx.update_death_link(self.ctx.deathlink_enabled), name="Update Deathlink") | ||||
| 
 | ||||
| 
 | ||||
| class OoTContext(CommonContext): | ||||
|  | @ -203,7 +204,7 @@ async def n64_sync_task(ctx: OoTContext): | |||
|                     if reported_version >= script_version: | ||||
|                         if ctx.game is not None and 'locations' in data_decoded: | ||||
|                             # Not just a keep alive ping, parse | ||||
|                             asyncio.create_task(parse_payload(data_decoded, ctx, False)) | ||||
|                             async_start(parse_payload(data_decoded, ctx, False)) | ||||
|                         if not ctx.auth: | ||||
|                             ctx.auth = data_decoded['playerName'] | ||||
|                             if ctx.awaiting_rom: | ||||
|  | @ -279,7 +280,7 @@ async def patch_and_run_game(apz5_file): | |||
|     os.chdir(data_path("Compress")) | ||||
|     compress_rom_file(decomp_path, comp_path) | ||||
|     os.remove(decomp_path) | ||||
|     asyncio.create_task(run_game(comp_path)) | ||||
|     async_start(run_game(comp_path)) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|  | @ -295,7 +296,7 @@ if __name__ == '__main__': | |||
| 
 | ||||
|         if args.apz5_file: | ||||
|             logger.info("APZ5 file supplied, beginning patching process...") | ||||
|             asyncio.create_task(patch_and_run_game(args.apz5_file)) | ||||
|             async_start(patch_and_run_game(args.apz5_file)) | ||||
| 
 | ||||
|         ctx = OoTContext(args.connect, args.password) | ||||
|         ctx.server_task = asyncio.create_task(server_loop(ctx), name="Server Loop") | ||||
|  |  | |||
|  | @ -11,6 +11,7 @@ from typing import List | |||
| 
 | ||||
| 
 | ||||
| import Utils | ||||
| from Utils import async_start | ||||
| from CommonClient import CommonContext, server_loop, gui_enabled, ClientCommandProcessor, logger, \ | ||||
|     get_base_parser | ||||
| 
 | ||||
|  | @ -185,7 +186,7 @@ async def gb_sync_task(ctx: GBContext): | |||
|                     if 'locations' in data_decoded and ctx.game and ctx.gb_status == CONNECTION_CONNECTED_STATUS \ | ||||
|                             and not error_status and ctx.auth: | ||||
|                         # Not just a keep alive ping, parse | ||||
|                         asyncio.create_task(parse_locations(data_decoded['locations'], ctx)) | ||||
|                         async_start(parse_locations(data_decoded['locations'], ctx)) | ||||
|                 except asyncio.TimeoutError: | ||||
|                     logger.debug("Read Timed Out, Reconnecting") | ||||
|                     error_status = CONNECTION_TIMING_OUT_STATUS | ||||
|  | @ -265,7 +266,7 @@ async def patch_and_run_game(game_version, patch_file, ctx): | |||
|         with open(comp_path, "wb") as patched_rom_file: | ||||
|             patched_rom_file.write(patched_rom_data) | ||||
| 
 | ||||
|         asyncio.create_task(run_game(comp_path)) | ||||
|         async_start(run_game(comp_path)) | ||||
|     else: | ||||
|         msg = "Patch supplied was not generated with the same base patch version as this client. Patching failed." | ||||
|         logger.warning(msg) | ||||
|  | @ -295,10 +296,10 @@ if __name__ == '__main__': | |||
|             ext = args.patch_file.split(".")[len(args.patch_file.split(".")) - 1].lower() | ||||
|             if ext == "apred": | ||||
|                 logger.info("APRED file supplied, beginning patching process...") | ||||
|                 asyncio.create_task(patch_and_run_game("red", args.patch_file, ctx)) | ||||
|                 async_start(patch_and_run_game("red", args.patch_file, ctx)) | ||||
|             elif ext == "apblue": | ||||
|                 logger.info("APBLUE file supplied, beginning patching process...") | ||||
|                 asyncio.create_task(patch_and_run_game("blue", args.patch_file, ctx)) | ||||
|                 async_start(patch_and_run_game("blue", args.patch_file, ctx)) | ||||
|             else: | ||||
|                 logger.warning(f"Unknown patch file extension {ext}") | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										18
									
								
								SNIClient.py
								
								
								
								
							
							
						
						
									
										18
									
								
								SNIClient.py
								
								
								
								
							|  | @ -18,7 +18,7 @@ from json import loads, dumps | |||
| from CommonClient import CommonContext, server_loop, ClientCommandProcessor, gui_enabled, get_base_parser | ||||
| 
 | ||||
| import Utils | ||||
| 
 | ||||
| from Utils import async_start | ||||
| from MultiServer import mark_raw | ||||
| if typing.TYPE_CHECKING: | ||||
|     from worlds.AutoSNIClient import SNIClient | ||||
|  | @ -84,7 +84,7 @@ class SNIClientCommandProcessor(ClientCommandProcessor): | |||
|         """Close connection to a currently connected snes""" | ||||
|         self.ctx.snes_reconnect_address = None | ||||
|         if self.ctx.snes_socket is not None and not self.ctx.snes_socket.closed: | ||||
|             asyncio.create_task(self.ctx.snes_socket.close()) | ||||
|             async_start(self.ctx.snes_socket.close()) | ||||
|             return True | ||||
|         else: | ||||
|             return False | ||||
|  | @ -96,7 +96,7 @@ class SNIClientCommandProcessor(ClientCommandProcessor): | |||
|     #         self.output("No attached SNES Device.") | ||||
|     #         return False | ||||
|     #     snes_buffered_write(self.ctx, int(address, 16), bytes([int(data)])) | ||||
|     #     asyncio.create_task(snes_flush_writes(self.ctx)) | ||||
|     #     async_start(snes_flush_writes(self.ctx)) | ||||
|     #     self.output("Data Sent") | ||||
|     #     return True | ||||
| 
 | ||||
|  | @ -167,7 +167,7 @@ class SNIContext(CommonContext): | |||
| 
 | ||||
|     def event_invalid_slot(self) -> typing.NoReturn: | ||||
|         if self.snes_socket is not None and not self.snes_socket.closed: | ||||
|             asyncio.create_task(self.snes_socket.close()) | ||||
|             async_start(self.snes_socket.close()) | ||||
|         raise Exception("Invalid ROM detected, " | ||||
|                         "please verify that you have loaded the correct rom and reconnect your snes (/snes)") | ||||
| 
 | ||||
|  | @ -230,7 +230,7 @@ class SNIContext(CommonContext): | |||
|                 # since the player will likely need that item. | ||||
|                 # Once the games handled by SNIClient gets made to be remote items, | ||||
|                 # this will no longer be needed. | ||||
|                 asyncio.create_task(self.send_msgs([{"cmd": "LocationScouts", "locations": list(new_locations)}])) | ||||
|                 async_start(self.send_msgs([{"cmd": "LocationScouts", "locations": list(new_locations)}])) | ||||
| 
 | ||||
|     def run_gui(self) -> None: | ||||
|         from kvui import GameManager | ||||
|  | @ -443,7 +443,7 @@ async def snes_connect(ctx: SNIContext, address: str, deviceIndex: int = -1) -> | |||
|             snes_logger.error("Error connecting to snes (%s)" % e) | ||||
|         else: | ||||
|             snes_logger.error(f"Error connecting to snes, attempt again in {_global_snes_reconnect_delay}s") | ||||
|             asyncio.create_task(snes_autoreconnect(ctx)) | ||||
|             async_start(snes_autoreconnect(ctx)) | ||||
|         _global_snes_reconnect_delay *= 2 | ||||
| 
 | ||||
|     else: | ||||
|  | @ -488,7 +488,7 @@ async def snes_recv_loop(ctx: SNIContext) -> None: | |||
| 
 | ||||
|         if ctx.snes_reconnect_address: | ||||
|             snes_logger.info(f"...reconnecting in {_global_snes_reconnect_delay}s") | ||||
|             asyncio.create_task(snes_autoreconnect(ctx)) | ||||
|             async_start(snes_autoreconnect(ctx)) | ||||
| 
 | ||||
| 
 | ||||
| async def snes_read(ctx: SNIContext, address: int, size: int) -> typing.Optional[bytes]: | ||||
|  | @ -674,9 +674,9 @@ async def main() -> None: | |||
|         elif args.diff_file.endswith(".aplttp"): | ||||
|             from worlds.alttp.Client import get_alttp_settings | ||||
|             adjustedromfile, adjusted = get_alttp_settings(romfile) | ||||
|             asyncio.create_task(run_game(adjustedromfile if adjusted else romfile)) | ||||
|             async_start(run_game(adjustedromfile if adjusted else romfile)) | ||||
|         else: | ||||
|             asyncio.create_task(run_game(romfile)) | ||||
|             async_start(run_game(romfile)) | ||||
| 
 | ||||
|     ctx = SNIContext(args.snes, args.connect, args.password) | ||||
|     if ctx.server_task is None: | ||||
|  |  | |||
							
								
								
									
										24
									
								
								Utils.py
								
								
								
								
							
							
						
						
									
										24
									
								
								Utils.py
								
								
								
								
							|  | @ -1,5 +1,6 @@ | |||
| from __future__ import annotations | ||||
| 
 | ||||
| import asyncio | ||||
| import typing | ||||
| import builtins | ||||
| import os | ||||
|  | @ -11,7 +12,7 @@ import io | |||
| import collections | ||||
| import importlib | ||||
| import logging | ||||
| from typing import BinaryIO | ||||
| from typing import BinaryIO, ClassVar, Coroutine, Optional, Set | ||||
| 
 | ||||
| from yaml import load, load_all, dump, SafeLoader | ||||
| 
 | ||||
|  | @ -650,3 +651,24 @@ def read_snes_rom(stream: BinaryIO, strip_header: bool = True) -> bytearray: | |||
|     if strip_header and len(buffer) % 0x400 == 0x200: | ||||
|         return buffer[0x200:] | ||||
|     return buffer | ||||
| 
 | ||||
| 
 | ||||
| _faf_tasks: "Set[asyncio.Task[None]]" = set() | ||||
| 
 | ||||
| 
 | ||||
| def async_start(co: Coroutine[None, None, None], name: Optional[str] = None) -> None: | ||||
|     """ | ||||
|     Use this to start a task when you don't keep a reference to it or immediately await it, | ||||
|     to prevent early garbage collection. "fire-and-forget" | ||||
|     """ | ||||
|     # https://docs.python.org/3.10/library/asyncio-task.html#asyncio.create_task | ||||
|     # Python docs: | ||||
|     # ``` | ||||
|     # Important: Save a reference to the result of [asyncio.create_task], | ||||
|     # to avoid a task disappearing mid-execution. | ||||
|     # ``` | ||||
|     # This implementation follows the pattern given in that documentation. | ||||
| 
 | ||||
|     task = asyncio.create_task(co, name=name) | ||||
|     _faf_tasks.add(task) | ||||
|     task.add_done_callback(_faf_tasks.discard) | ||||
|  |  | |||
|  | @ -8,6 +8,7 @@ from CommonClient import CommonContext, server_loop, gui_enabled, \ | |||
|     ClientCommandProcessor, logger, get_base_parser | ||||
| from NetUtils import ClientStatus | ||||
| import Utils | ||||
| from Utils import async_start | ||||
| 
 | ||||
| import colorama  # type: ignore | ||||
| 
 | ||||
|  | @ -263,7 +264,7 @@ class ZillionContext(CommonContext): | |||
|                 "cmd": "Get", | ||||
|                 "keys": [f"zillion-{self.auth}-doors"] | ||||
|             } | ||||
|             asyncio.create_task(self.send_msgs([payload])) | ||||
|             async_start(self.send_msgs([payload])) | ||||
|         elif cmd == "Retrieved": | ||||
|             if "keys" not in args: | ||||
|                 logger.warning(f"invalid Retrieved packet to ZillionClient: {args}") | ||||
|  | @ -304,7 +305,7 @@ class ZillionContext(CommonContext): | |||
|                     self.ap_local_count += 1 | ||||
|                     n_locations = len(self.missing_locations) + len(self.checked_locations) - 1  # -1 to ignore win | ||||
|                     logger.info(f'New Check: {loc_name} ({self.ap_local_count}/{n_locations})') | ||||
|                     asyncio.create_task(self.send_msgs([ | ||||
|                     async_start(self.send_msgs([ | ||||
|                         {"cmd": 'LocationChecks', "locations": [server_id]} | ||||
|                     ])) | ||||
|                 else: | ||||
|  | @ -312,10 +313,10 @@ class ZillionContext(CommonContext): | |||
|                     # because all the key words are local and unwatched by the server. | ||||
|                     logger.debug(f"DEBUG: {loc_name} not in missing") | ||||
|             elif isinstance(event_from_game, events.DeathEventFromGame): | ||||
|                 asyncio.create_task(self.send_death()) | ||||
|                 async_start(self.send_death()) | ||||
|             elif isinstance(event_from_game, events.WinEventFromGame): | ||||
|                 if not self.finished_game: | ||||
|                     asyncio.create_task(self.send_msgs([ | ||||
|                     async_start(self.send_msgs([ | ||||
|                         {"cmd": "StatusUpdate", "status": ClientStatus.CLIENT_GOAL} | ||||
|                     ])) | ||||
|                     self.finished_game = True | ||||
|  | @ -327,7 +328,7 @@ class ZillionContext(CommonContext): | |||
|                         "key": f"zillion-{self.auth}-doors", | ||||
|                         "operations": [{"operation": "replace", "value": doors_b64}] | ||||
|                     } | ||||
|                     asyncio.create_task(self.send_msgs([payload])) | ||||
|                     async_start(self.send_msgs([payload])) | ||||
|             else: | ||||
|                 logger.warning(f"WARNING: unhandled event from game {event_from_game}") | ||||
| 
 | ||||
|  | @ -410,7 +411,7 @@ async def zillion_sync_task(ctx: ZillionContext) -> None: | |||
|                                         ctx.next_item = 0 | ||||
|                                         ctx.ap_local_count = len(ctx.checked_locations) | ||||
|                                     else:  # no slot data yet | ||||
|                                         asyncio.create_task(ctx.send_connect()) | ||||
|                                         async_start(ctx.send_connect()) | ||||
|                                         log_no_spam("logging in to server...") | ||||
|                                         await asyncio.wait(( | ||||
|                                             ctx.got_slot_data.wait(), | ||||
|  | @ -434,7 +435,7 @@ async def zillion_sync_task(ctx: ZillionContext) -> None: | |||
|                     memory.reset_game_state() | ||||
| 
 | ||||
|                     ctx.auth = name | ||||
|                     asyncio.create_task(ctx.connect()) | ||||
|                     async_start(ctx.connect()) | ||||
|                     await asyncio.wait(( | ||||
|                         ctx.got_room_info.wait(), | ||||
|                         ctx.exit_event.wait(), | ||||
|  |  | |||
							
								
								
									
										5
									
								
								kvui.py
								
								
								
								
							
							
						
						
									
										5
									
								
								kvui.py
								
								
								
								
							|  | @ -49,6 +49,7 @@ fade_in_animation = Animation(opacity=0, duration=0) + Animation(opacity=1, dura | |||
| 
 | ||||
| 
 | ||||
| from NetUtils import JSONtoTextParser, JSONMessagePart, SlotType | ||||
| from Utils import async_start | ||||
| 
 | ||||
| if typing.TYPE_CHECKING: | ||||
|     import CommonClient | ||||
|  | @ -427,9 +428,9 @@ class GameManager(App): | |||
|         if self.ctx.server: | ||||
|             self.ctx.server_address = None | ||||
|             self.ctx.username = None | ||||
|             asyncio.create_task(self.ctx.disconnect()) | ||||
|             async_start(self.ctx.disconnect()) | ||||
|         else: | ||||
|             asyncio.create_task(self.ctx.connect(self.server_connect_bar.text.replace("/connect ", ""))) | ||||
|             async_start(self.ctx.connect(self.server_connect_bar.text.replace("/connect ", ""))) | ||||
| 
 | ||||
|     def on_stop(self): | ||||
|         # "kill" input tasks | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue