MultiServer and clients: async coroutine starter in Utils.py (#1143)

* async coroutine starter in Utils.py

* refactor from static class to function

* async_start docstring
This commit is contained in:
Doug Hoskisson 2022-11-02 07:51:35 -07:00 committed by GitHub
parent a6e1e14fee
commit da392239a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 82 additions and 54 deletions

View File

@ -20,7 +20,7 @@ if __name__ == "__main__":
from MultiServer import CommandProcessor from MultiServer import CommandProcessor
from NetUtils import Endpoint, decode, NetworkItem, encode, JSONtoTextParser, \ from NetUtils import Endpoint, decode, NetworkItem, encode, JSONtoTextParser, \
ClientStatus, Permission, NetworkSlot, RawJSONtoTextParser ClientStatus, Permission, NetworkSlot, RawJSONtoTextParser
from Utils import Version, stream_input from Utils import Version, stream_input, async_start
from worlds import network_data_package, AutoWorldRegister from worlds import network_data_package, AutoWorldRegister
import os import os
@ -46,14 +46,14 @@ class ClientCommandProcessor(CommandProcessor):
"""Connect to a MultiWorld Server""" """Connect to a MultiWorld Server"""
self.ctx.server_address = None self.ctx.server_address = None
self.ctx.username = None self.ctx.username = None
asyncio.create_task(self.ctx.connect(address if address else None), name="connecting") async_start(self.ctx.connect(address if address else None), name="connecting")
return True return True
def _cmd_disconnect(self) -> bool: def _cmd_disconnect(self) -> bool:
"""Disconnect from a MultiWorld Server""" """Disconnect from a MultiWorld Server"""
self.ctx.server_address = None self.ctx.server_address = None
self.ctx.username = None self.ctx.username = None
asyncio.create_task(self.ctx.disconnect(), name="disconnecting") async_start(self.ctx.disconnect(), name="disconnecting")
return True return True
def _cmd_received(self) -> bool: def _cmd_received(self) -> bool:
@ -116,12 +116,12 @@ class ClientCommandProcessor(CommandProcessor):
else: else:
state = ClientStatus.CLIENT_CONNECTED state = ClientStatus.CLIENT_CONNECTED
self.output("Unreadied.") self.output("Unreadied.")
asyncio.create_task(self.ctx.send_msgs([{"cmd": "StatusUpdate", "status": state}]), name="send StatusUpdate") async_start(self.ctx.send_msgs([{"cmd": "StatusUpdate", "status": state}]), name="send StatusUpdate")
def default(self, raw: str): def default(self, raw: str):
raw = self.ctx.on_user_say(raw) raw = self.ctx.on_user_say(raw)
if raw: if raw:
asyncio.create_task(self.ctx.send_msgs([{"cmd": "Say", "text": raw}]), name="send Say") async_start(self.ctx.send_msgs([{"cmd": "Say", "text": raw}]), name="send Say")
class CommonContext: class CommonContext:
@ -562,7 +562,7 @@ async def server_loop(ctx: CommonContext, address: typing.Optional[str] = None)
await ctx.connection_closed() await ctx.connection_closed()
if ctx.server_address: if ctx.server_address:
logger.info(f"... reconnecting in {ctx.current_reconnect_delay}s") logger.info(f"... reconnecting in {ctx.current_reconnect_delay}s")
asyncio.create_task(server_autoreconnect(ctx), name="server auto reconnect") async_start(server_autoreconnect(ctx), name="server auto reconnect")
ctx.current_reconnect_delay *= 2 ctx.current_reconnect_delay *= 2

View File

@ -7,6 +7,7 @@ from typing import List
import Utils import Utils
from Utils import async_start
from CommonClient import CommonContext, server_loop, gui_enabled, ClientCommandProcessor, logger, \ from CommonClient import CommonContext, server_loop, gui_enabled, ClientCommandProcessor, logger, \
get_base_parser get_base_parser
@ -69,7 +70,7 @@ class FF1Context(CommonContext):
def on_package(self, cmd: str, args: dict): def on_package(self, cmd: str, args: dict):
if cmd == 'Connected': if cmd == 'Connected':
asyncio.create_task(parse_locations(self.locations_array, self, True)) async_start(parse_locations(self.locations_array, self, True))
elif cmd == 'Print': elif cmd == 'Print':
msg = args['text'] msg = args['text']
if ': !' not in msg: if ': !' not in msg:
@ -180,7 +181,7 @@ async def nes_sync_task(ctx: FF1Context):
# print(data_decoded) # print(data_decoded)
if ctx.game is not None and 'locations' in data_decoded: if ctx.game is not None and 'locations' in data_decoded:
# Not just a keep alive ping, parse # Not just a keep alive ping, parse
asyncio.create_task(parse_locations(data_decoded['locations'], ctx, False)) async_start(parse_locations(data_decoded['locations'], ctx, False))
if not ctx.auth: if not ctx.auth:
ctx.auth = ''.join([chr(i) for i in data_decoded['playerName'] if i != 0]) ctx.auth = ''.join([chr(i) for i in data_decoded['playerName'] if i != 0])
if ctx.auth == '': if ctx.auth == '':

View File

@ -25,6 +25,7 @@ if __name__ == "__main__":
from CommonClient import CommonContext, server_loop, ClientCommandProcessor, logger, gui_enabled, get_base_parser from CommonClient import CommonContext, server_loop, ClientCommandProcessor, logger, gui_enabled, get_base_parser
from MultiServer import mark_raw from MultiServer import mark_raw
from NetUtils import NetworkItem, ClientStatus, JSONtoTextParser, JSONMessagePart from NetUtils import NetworkItem, ClientStatus, JSONtoTextParser, JSONMessagePart
from Utils import async_start
from worlds.factorio import Factorio from worlds.factorio import Factorio
@ -124,7 +125,7 @@ class FactorioContext(CommonContext):
self.rcon_client.send_commands({item_name: f'/ap-get-technology ap-{item_name}-\t-1' for self.rcon_client.send_commands({item_name: f'/ap-get-technology ap-{item_name}-\t-1' for
item_name in args["checked_locations"]}) item_name in args["checked_locations"]})
if cmd == "Connected" and self.energy_link_increment: if cmd == "Connected" and self.energy_link_increment:
asyncio.create_task(self.send_msgs([{ async_start(self.send_msgs([{
"cmd": "SetNotify", "keys": ["EnergyLink"] "cmd": "SetNotify", "keys": ["EnergyLink"]
}])) }]))
elif cmd == "SetReply": elif cmd == "SetReply":
@ -232,7 +233,7 @@ async def game_watcher(ctx: FactorioContext):
if death_link_tick != ctx.death_link_tick: if death_link_tick != ctx.death_link_tick:
ctx.death_link_tick = death_link_tick ctx.death_link_tick = death_link_tick
if "DeathLink" in ctx.tags: if "DeathLink" in ctx.tags:
asyncio.create_task(ctx.send_death()) async_start(ctx.send_death())
if ctx.energy_link_increment: if ctx.energy_link_increment:
in_world_bridges = data["energy_bridges"] in_world_bridges = data["energy_bridges"]
if in_world_bridges: if in_world_bridges:
@ -240,7 +241,7 @@ async def game_watcher(ctx: FactorioContext):
if in_world_energy < (ctx.energy_link_increment * in_world_bridges): if in_world_energy < (ctx.energy_link_increment * in_world_bridges):
# attempt to refill # attempt to refill
ctx.last_deplete = time.time() ctx.last_deplete = time.time()
asyncio.create_task(ctx.send_msgs([{ async_start(ctx.send_msgs([{
"cmd": "Set", "key": "EnergyLink", "operations": "cmd": "Set", "key": "EnergyLink", "operations":
[{"operation": "add", "value": -ctx.energy_link_increment * in_world_bridges}, [{"operation": "add", "value": -ctx.energy_link_increment * in_world_bridges},
{"operation": "max", "value": 0}], {"operation": "max", "value": 0}],
@ -250,7 +251,7 @@ async def game_watcher(ctx: FactorioContext):
elif in_world_energy > (in_world_bridges * ctx.energy_link_increment * 5) - \ elif in_world_energy > (in_world_bridges * ctx.energy_link_increment * 5) - \
ctx.energy_link_increment*in_world_bridges: ctx.energy_link_increment*in_world_bridges:
value = ctx.energy_link_increment * in_world_bridges value = ctx.energy_link_increment * in_world_bridges
asyncio.create_task(ctx.send_msgs([{ async_start(ctx.send_msgs([{
"cmd": "Set", "key": "EnergyLink", "operations": "cmd": "Set", "key": "EnergyLink", "operations":
[{"operation": "add", "value": value}] [{"operation": "add", "value": value}]
}])) }]))

View File

@ -31,7 +31,7 @@ except ImportError:
import NetUtils import NetUtils
import Utils import Utils
from Utils import version_tuple, restricted_loads, Version from Utils import version_tuple, restricted_loads, Version, async_start
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \ from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \
SlotType SlotType
@ -273,16 +273,16 @@ class Context:
def broadcast_all(self, msgs: typing.List[dict]): def broadcast_all(self, msgs: typing.List[dict]):
msgs = self.dumper(msgs) msgs = self.dumper(msgs)
endpoints = (endpoint for endpoint in self.endpoints if endpoint.auth) endpoints = (endpoint for endpoint in self.endpoints if endpoint.auth)
asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs)) async_start(self.broadcast_send_encoded_msgs(endpoints, msgs))
def broadcast_team(self, team: int, msgs: typing.List[dict]): def broadcast_team(self, team: int, msgs: typing.List[dict]):
msgs = self.dumper(msgs) msgs = self.dumper(msgs)
endpoints = (endpoint for endpoint in itertools.chain.from_iterable(self.clients[team].values())) endpoints = (endpoint for endpoint in itertools.chain.from_iterable(self.clients[team].values()))
asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs)) async_start(self.broadcast_send_encoded_msgs(endpoints, msgs))
def broadcast(self, endpoints: typing.Iterable[Client], msgs: typing.List[dict]): def broadcast(self, endpoints: typing.Iterable[Client], msgs: typing.List[dict]):
msgs = self.dumper(msgs) msgs = self.dumper(msgs)
asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs)) async_start(self.broadcast_send_encoded_msgs(endpoints, msgs))
async def disconnect(self, endpoint: Client): async def disconnect(self, endpoint: Client):
if endpoint in self.endpoints: if endpoint in self.endpoints:
@ -302,18 +302,18 @@ class Context:
return return
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text)) logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
if client.version >= print_command_compatability_threshold: if client.version >= print_command_compatability_threshold:
asyncio.create_task(self.send_msgs(client, [{"cmd": "PrintJSON", "data": [{ "text": text }]}])) async_start(self.send_msgs(client, [{"cmd": "PrintJSON", "data": [{ "text": text }]}]))
else: else:
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}])) async_start(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
def notify_client_multiple(self, client: Client, texts: typing.List[str]): def notify_client_multiple(self, client: Client, texts: typing.List[str]):
if not client.auth: if not client.auth:
return return
if client.version >= print_command_compatability_threshold: if client.version >= print_command_compatability_threshold:
asyncio.create_task(self.send_msgs(client, async_start(self.send_msgs(client,
[{"cmd": "PrintJSON", "data": [{ "text": text }]} for text in texts])) [{"cmd": "PrintJSON", "data": [{ "text": text }]} for text in texts]))
else: else:
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts])) async_start(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
# loading # loading
@ -627,7 +627,7 @@ def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint], onl
continue continue
client_hints = [datum[1] for datum in sorted(hint_data, key=lambda x: x[0].finding_player == slot)] client_hints = [datum[1] for datum in sorted(hint_data, key=lambda x: x[0].finding_player == slot)]
for client in clients: for client in clients:
asyncio.create_task(ctx.send_msgs(client, client_hints)) async_start(ctx.send_msgs(client, client_hints))
def update_aliases(ctx: Context, team: int): def update_aliases(ctx: Context, team: int):
@ -636,7 +636,7 @@ def update_aliases(ctx: Context, team: int):
for clients in ctx.clients[team].values(): for clients in ctx.clients[team].values():
for client in clients: for client in clients:
asyncio.create_task(ctx.send_encoded_msgs(client, cmd)) async_start(ctx.send_encoded_msgs(client, cmd))
async def server(websocket, path: str = "/", ctx: Context = None): async def server(websocket, path: str = "/", ctx: Context = None):
@ -814,7 +814,7 @@ def send_new_items(ctx: Context):
items = get_received_items(ctx, team, slot, client.remote_items) items = get_received_items(ctx, team, slot, client.remote_items)
if len(start_inventory) + len(items) > client.send_index: if len(start_inventory) + len(items) > client.send_index:
first_new_item = max(0, client.send_index - len(start_inventory)) first_new_item = max(0, client.send_index - len(start_inventory))
asyncio.create_task(ctx.send_msgs(client, [{ async_start(ctx.send_msgs(client, [{
"cmd": "ReceivedItems", "cmd": "ReceivedItems",
"index": client.send_index, "index": client.send_index,
"items": start_inventory[client.send_index:] + items[first_new_item:]}])) "items": start_inventory[client.send_index:] + items[first_new_item:]}]))
@ -1090,7 +1090,7 @@ class CommonCommandProcessor(CommandProcessor):
timer = int(seconds, 10) timer = int(seconds, 10)
except ValueError: except ValueError:
timer = 10 timer = 10
asyncio.create_task(countdown(self.ctx, timer)) async_start(countdown(self.ctx, timer))
return True return True
def _cmd_options(self): def _cmd_options(self):
@ -1771,7 +1771,7 @@ class ServerCommandProcessor(CommonCommandProcessor):
def _cmd_exit(self) -> bool: def _cmd_exit(self) -> bool:
"""Shutdown the server""" """Shutdown the server"""
asyncio.create_task(self.ctx.server.ws_server._close()) async_start(self.ctx.server.ws_server._close())
if self.ctx.shutdown_task: if self.ctx.shutdown_task:
self.ctx.shutdown_task.cancel() self.ctx.shutdown_task.cancel()
self.ctx.exit_event.set() self.ctx.exit_event.set()
@ -2084,7 +2084,7 @@ async def auto_shutdown(ctx, to_cancel=None):
await asyncio.sleep(ctx.auto_shutdown) await asyncio.sleep(ctx.auto_shutdown)
while not ctx.exit_event.is_set(): while not ctx.exit_event.is_set():
if not ctx.client_activity_timers.values(): if not ctx.client_activity_timers.values():
asyncio.create_task(ctx.server.ws_server._close()) async_start(ctx.server.ws_server._close())
ctx.exit_event.set() ctx.exit_event.set()
if to_cancel: if to_cancel:
for task in to_cancel: for task in to_cancel:
@ -2095,7 +2095,7 @@ async def auto_shutdown(ctx, to_cancel=None):
delta = datetime.datetime.now(datetime.timezone.utc) - newest_activity delta = datetime.datetime.now(datetime.timezone.utc) - newest_activity
seconds = ctx.auto_shutdown - delta.total_seconds() seconds = ctx.auto_shutdown - delta.total_seconds()
if seconds < 0: if seconds < 0:
asyncio.create_task(ctx.server.ws_server._close()) async_start(ctx.server.ws_server._close())
ctx.exit_event.set() ctx.exit_event.set()
if to_cancel: if to_cancel:
for task in to_cancel: for task in to_cancel:

View File

@ -9,6 +9,7 @@ from asyncio import StreamReader, StreamWriter
from CommonClient import CommonContext, server_loop, gui_enabled, \ from CommonClient import CommonContext, server_loop, gui_enabled, \
ClientCommandProcessor, logger, get_base_parser ClientCommandProcessor, logger, get_base_parser
import Utils import Utils
from Utils import async_start
from worlds import network_data_package from worlds import network_data_package
from worlds.oot.Rom import Rom, compress_rom_file from worlds.oot.Rom import Rom, compress_rom_file
from worlds.oot.N64Patch import apply_patch_file from worlds.oot.N64Patch import apply_patch_file
@ -69,7 +70,7 @@ class OoTCommandProcessor(ClientCommandProcessor):
if isinstance(self.ctx, OoTContext): if isinstance(self.ctx, OoTContext):
self.ctx.deathlink_client_override = True self.ctx.deathlink_client_override = True
self.ctx.deathlink_enabled = not self.ctx.deathlink_enabled self.ctx.deathlink_enabled = not self.ctx.deathlink_enabled
asyncio.create_task(self.ctx.update_death_link(self.ctx.deathlink_enabled), name="Update Deathlink") async_start(self.ctx.update_death_link(self.ctx.deathlink_enabled), name="Update Deathlink")
class OoTContext(CommonContext): class OoTContext(CommonContext):
@ -203,7 +204,7 @@ async def n64_sync_task(ctx: OoTContext):
if reported_version >= script_version: if reported_version >= script_version:
if ctx.game is not None and 'locations' in data_decoded: if ctx.game is not None and 'locations' in data_decoded:
# Not just a keep alive ping, parse # Not just a keep alive ping, parse
asyncio.create_task(parse_payload(data_decoded, ctx, False)) async_start(parse_payload(data_decoded, ctx, False))
if not ctx.auth: if not ctx.auth:
ctx.auth = data_decoded['playerName'] ctx.auth = data_decoded['playerName']
if ctx.awaiting_rom: if ctx.awaiting_rom:
@ -279,7 +280,7 @@ async def patch_and_run_game(apz5_file):
os.chdir(data_path("Compress")) os.chdir(data_path("Compress"))
compress_rom_file(decomp_path, comp_path) compress_rom_file(decomp_path, comp_path)
os.remove(decomp_path) os.remove(decomp_path)
asyncio.create_task(run_game(comp_path)) async_start(run_game(comp_path))
if __name__ == '__main__': if __name__ == '__main__':
@ -295,7 +296,7 @@ if __name__ == '__main__':
if args.apz5_file: if args.apz5_file:
logger.info("APZ5 file supplied, beginning patching process...") logger.info("APZ5 file supplied, beginning patching process...")
asyncio.create_task(patch_and_run_game(args.apz5_file)) async_start(patch_and_run_game(args.apz5_file))
ctx = OoTContext(args.connect, args.password) ctx = OoTContext(args.connect, args.password)
ctx.server_task = asyncio.create_task(server_loop(ctx), name="Server Loop") ctx.server_task = asyncio.create_task(server_loop(ctx), name="Server Loop")

View File

@ -11,6 +11,7 @@ from typing import List
import Utils import Utils
from Utils import async_start
from CommonClient import CommonContext, server_loop, gui_enabled, ClientCommandProcessor, logger, \ from CommonClient import CommonContext, server_loop, gui_enabled, ClientCommandProcessor, logger, \
get_base_parser get_base_parser
@ -185,7 +186,7 @@ async def gb_sync_task(ctx: GBContext):
if 'locations' in data_decoded and ctx.game and ctx.gb_status == CONNECTION_CONNECTED_STATUS \ if 'locations' in data_decoded and ctx.game and ctx.gb_status == CONNECTION_CONNECTED_STATUS \
and not error_status and ctx.auth: and not error_status and ctx.auth:
# Not just a keep alive ping, parse # Not just a keep alive ping, parse
asyncio.create_task(parse_locations(data_decoded['locations'], ctx)) async_start(parse_locations(data_decoded['locations'], ctx))
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.debug("Read Timed Out, Reconnecting") logger.debug("Read Timed Out, Reconnecting")
error_status = CONNECTION_TIMING_OUT_STATUS error_status = CONNECTION_TIMING_OUT_STATUS
@ -265,7 +266,7 @@ async def patch_and_run_game(game_version, patch_file, ctx):
with open(comp_path, "wb") as patched_rom_file: with open(comp_path, "wb") as patched_rom_file:
patched_rom_file.write(patched_rom_data) patched_rom_file.write(patched_rom_data)
asyncio.create_task(run_game(comp_path)) async_start(run_game(comp_path))
else: else:
msg = "Patch supplied was not generated with the same base patch version as this client. Patching failed." msg = "Patch supplied was not generated with the same base patch version as this client. Patching failed."
logger.warning(msg) logger.warning(msg)
@ -295,10 +296,10 @@ if __name__ == '__main__':
ext = args.patch_file.split(".")[len(args.patch_file.split(".")) - 1].lower() ext = args.patch_file.split(".")[len(args.patch_file.split(".")) - 1].lower()
if ext == "apred": if ext == "apred":
logger.info("APRED file supplied, beginning patching process...") logger.info("APRED file supplied, beginning patching process...")
asyncio.create_task(patch_and_run_game("red", args.patch_file, ctx)) async_start(patch_and_run_game("red", args.patch_file, ctx))
elif ext == "apblue": elif ext == "apblue":
logger.info("APBLUE file supplied, beginning patching process...") logger.info("APBLUE file supplied, beginning patching process...")
asyncio.create_task(patch_and_run_game("blue", args.patch_file, ctx)) async_start(patch_and_run_game("blue", args.patch_file, ctx))
else: else:
logger.warning(f"Unknown patch file extension {ext}") logger.warning(f"Unknown patch file extension {ext}")

View File

@ -18,7 +18,7 @@ from json import loads, dumps
from CommonClient import CommonContext, server_loop, ClientCommandProcessor, gui_enabled, get_base_parser from CommonClient import CommonContext, server_loop, ClientCommandProcessor, gui_enabled, get_base_parser
import Utils import Utils
from Utils import async_start
from MultiServer import mark_raw from MultiServer import mark_raw
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
from worlds.AutoSNIClient import SNIClient from worlds.AutoSNIClient import SNIClient
@ -84,7 +84,7 @@ class SNIClientCommandProcessor(ClientCommandProcessor):
"""Close connection to a currently connected snes""" """Close connection to a currently connected snes"""
self.ctx.snes_reconnect_address = None self.ctx.snes_reconnect_address = None
if self.ctx.snes_socket is not None and not self.ctx.snes_socket.closed: if self.ctx.snes_socket is not None and not self.ctx.snes_socket.closed:
asyncio.create_task(self.ctx.snes_socket.close()) async_start(self.ctx.snes_socket.close())
return True return True
else: else:
return False return False
@ -96,7 +96,7 @@ class SNIClientCommandProcessor(ClientCommandProcessor):
# self.output("No attached SNES Device.") # self.output("No attached SNES Device.")
# return False # return False
# snes_buffered_write(self.ctx, int(address, 16), bytes([int(data)])) # snes_buffered_write(self.ctx, int(address, 16), bytes([int(data)]))
# asyncio.create_task(snes_flush_writes(self.ctx)) # async_start(snes_flush_writes(self.ctx))
# self.output("Data Sent") # self.output("Data Sent")
# return True # return True
@ -167,7 +167,7 @@ class SNIContext(CommonContext):
def event_invalid_slot(self) -> typing.NoReturn: def event_invalid_slot(self) -> typing.NoReturn:
if self.snes_socket is not None and not self.snes_socket.closed: if self.snes_socket is not None and not self.snes_socket.closed:
asyncio.create_task(self.snes_socket.close()) async_start(self.snes_socket.close())
raise Exception("Invalid ROM detected, " raise Exception("Invalid ROM detected, "
"please verify that you have loaded the correct rom and reconnect your snes (/snes)") "please verify that you have loaded the correct rom and reconnect your snes (/snes)")
@ -230,7 +230,7 @@ class SNIContext(CommonContext):
# since the player will likely need that item. # since the player will likely need that item.
# Once the games handled by SNIClient gets made to be remote items, # Once the games handled by SNIClient gets made to be remote items,
# this will no longer be needed. # this will no longer be needed.
asyncio.create_task(self.send_msgs([{"cmd": "LocationScouts", "locations": list(new_locations)}])) async_start(self.send_msgs([{"cmd": "LocationScouts", "locations": list(new_locations)}]))
def run_gui(self) -> None: def run_gui(self) -> None:
from kvui import GameManager from kvui import GameManager
@ -443,7 +443,7 @@ async def snes_connect(ctx: SNIContext, address: str, deviceIndex: int = -1) ->
snes_logger.error("Error connecting to snes (%s)" % e) snes_logger.error("Error connecting to snes (%s)" % e)
else: else:
snes_logger.error(f"Error connecting to snes, attempt again in {_global_snes_reconnect_delay}s") snes_logger.error(f"Error connecting to snes, attempt again in {_global_snes_reconnect_delay}s")
asyncio.create_task(snes_autoreconnect(ctx)) async_start(snes_autoreconnect(ctx))
_global_snes_reconnect_delay *= 2 _global_snes_reconnect_delay *= 2
else: else:
@ -488,7 +488,7 @@ async def snes_recv_loop(ctx: SNIContext) -> None:
if ctx.snes_reconnect_address: if ctx.snes_reconnect_address:
snes_logger.info(f"...reconnecting in {_global_snes_reconnect_delay}s") snes_logger.info(f"...reconnecting in {_global_snes_reconnect_delay}s")
asyncio.create_task(snes_autoreconnect(ctx)) async_start(snes_autoreconnect(ctx))
async def snes_read(ctx: SNIContext, address: int, size: int) -> typing.Optional[bytes]: async def snes_read(ctx: SNIContext, address: int, size: int) -> typing.Optional[bytes]:
@ -674,9 +674,9 @@ async def main() -> None:
elif args.diff_file.endswith(".aplttp"): elif args.diff_file.endswith(".aplttp"):
from worlds.alttp.Client import get_alttp_settings from worlds.alttp.Client import get_alttp_settings
adjustedromfile, adjusted = get_alttp_settings(romfile) adjustedromfile, adjusted = get_alttp_settings(romfile)
asyncio.create_task(run_game(adjustedromfile if adjusted else romfile)) async_start(run_game(adjustedromfile if adjusted else romfile))
else: else:
asyncio.create_task(run_game(romfile)) async_start(run_game(romfile))
ctx = SNIContext(args.snes, args.connect, args.password) ctx = SNIContext(args.snes, args.connect, args.password)
if ctx.server_task is None: if ctx.server_task is None:

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import typing import typing
import builtins import builtins
import os import os
@ -11,7 +12,7 @@ import io
import collections import collections
import importlib import importlib
import logging import logging
from typing import BinaryIO from typing import BinaryIO, ClassVar, Coroutine, Optional, Set
from yaml import load, load_all, dump, SafeLoader from yaml import load, load_all, dump, SafeLoader
@ -650,3 +651,24 @@ def read_snes_rom(stream: BinaryIO, strip_header: bool = True) -> bytearray:
if strip_header and len(buffer) % 0x400 == 0x200: if strip_header and len(buffer) % 0x400 == 0x200:
return buffer[0x200:] return buffer[0x200:]
return buffer return buffer
_faf_tasks: "Set[asyncio.Task[None]]" = set()
def async_start(co: Coroutine[None, None, None], name: Optional[str] = None) -> None:
"""
Use this to start a task when you don't keep a reference to it or immediately await it,
to prevent early garbage collection. "fire-and-forget"
"""
# https://docs.python.org/3.10/library/asyncio-task.html#asyncio.create_task
# Python docs:
# ```
# Important: Save a reference to the result of [asyncio.create_task],
# to avoid a task disappearing mid-execution.
# ```
# This implementation follows the pattern given in that documentation.
task = asyncio.create_task(co, name=name)
_faf_tasks.add(task)
task.add_done_callback(_faf_tasks.discard)

View File

@ -8,6 +8,7 @@ from CommonClient import CommonContext, server_loop, gui_enabled, \
ClientCommandProcessor, logger, get_base_parser ClientCommandProcessor, logger, get_base_parser
from NetUtils import ClientStatus from NetUtils import ClientStatus
import Utils import Utils
from Utils import async_start
import colorama # type: ignore import colorama # type: ignore
@ -263,7 +264,7 @@ class ZillionContext(CommonContext):
"cmd": "Get", "cmd": "Get",
"keys": [f"zillion-{self.auth}-doors"] "keys": [f"zillion-{self.auth}-doors"]
} }
asyncio.create_task(self.send_msgs([payload])) async_start(self.send_msgs([payload]))
elif cmd == "Retrieved": elif cmd == "Retrieved":
if "keys" not in args: if "keys" not in args:
logger.warning(f"invalid Retrieved packet to ZillionClient: {args}") logger.warning(f"invalid Retrieved packet to ZillionClient: {args}")
@ -304,7 +305,7 @@ class ZillionContext(CommonContext):
self.ap_local_count += 1 self.ap_local_count += 1
n_locations = len(self.missing_locations) + len(self.checked_locations) - 1 # -1 to ignore win n_locations = len(self.missing_locations) + len(self.checked_locations) - 1 # -1 to ignore win
logger.info(f'New Check: {loc_name} ({self.ap_local_count}/{n_locations})') logger.info(f'New Check: {loc_name} ({self.ap_local_count}/{n_locations})')
asyncio.create_task(self.send_msgs([ async_start(self.send_msgs([
{"cmd": 'LocationChecks', "locations": [server_id]} {"cmd": 'LocationChecks', "locations": [server_id]}
])) ]))
else: else:
@ -312,10 +313,10 @@ class ZillionContext(CommonContext):
# because all the key words are local and unwatched by the server. # because all the key words are local and unwatched by the server.
logger.debug(f"DEBUG: {loc_name} not in missing") logger.debug(f"DEBUG: {loc_name} not in missing")
elif isinstance(event_from_game, events.DeathEventFromGame): elif isinstance(event_from_game, events.DeathEventFromGame):
asyncio.create_task(self.send_death()) async_start(self.send_death())
elif isinstance(event_from_game, events.WinEventFromGame): elif isinstance(event_from_game, events.WinEventFromGame):
if not self.finished_game: if not self.finished_game:
asyncio.create_task(self.send_msgs([ async_start(self.send_msgs([
{"cmd": "StatusUpdate", "status": ClientStatus.CLIENT_GOAL} {"cmd": "StatusUpdate", "status": ClientStatus.CLIENT_GOAL}
])) ]))
self.finished_game = True self.finished_game = True
@ -327,7 +328,7 @@ class ZillionContext(CommonContext):
"key": f"zillion-{self.auth}-doors", "key": f"zillion-{self.auth}-doors",
"operations": [{"operation": "replace", "value": doors_b64}] "operations": [{"operation": "replace", "value": doors_b64}]
} }
asyncio.create_task(self.send_msgs([payload])) async_start(self.send_msgs([payload]))
else: else:
logger.warning(f"WARNING: unhandled event from game {event_from_game}") logger.warning(f"WARNING: unhandled event from game {event_from_game}")
@ -410,7 +411,7 @@ async def zillion_sync_task(ctx: ZillionContext) -> None:
ctx.next_item = 0 ctx.next_item = 0
ctx.ap_local_count = len(ctx.checked_locations) ctx.ap_local_count = len(ctx.checked_locations)
else: # no slot data yet else: # no slot data yet
asyncio.create_task(ctx.send_connect()) async_start(ctx.send_connect())
log_no_spam("logging in to server...") log_no_spam("logging in to server...")
await asyncio.wait(( await asyncio.wait((
ctx.got_slot_data.wait(), ctx.got_slot_data.wait(),
@ -434,7 +435,7 @@ async def zillion_sync_task(ctx: ZillionContext) -> None:
memory.reset_game_state() memory.reset_game_state()
ctx.auth = name ctx.auth = name
asyncio.create_task(ctx.connect()) async_start(ctx.connect())
await asyncio.wait(( await asyncio.wait((
ctx.got_room_info.wait(), ctx.got_room_info.wait(),
ctx.exit_event.wait(), ctx.exit_event.wait(),

View File

@ -49,6 +49,7 @@ fade_in_animation = Animation(opacity=0, duration=0) + Animation(opacity=1, dura
from NetUtils import JSONtoTextParser, JSONMessagePart, SlotType from NetUtils import JSONtoTextParser, JSONMessagePart, SlotType
from Utils import async_start
if typing.TYPE_CHECKING: if typing.TYPE_CHECKING:
import CommonClient import CommonClient
@ -427,9 +428,9 @@ class GameManager(App):
if self.ctx.server: if self.ctx.server:
self.ctx.server_address = None self.ctx.server_address = None
self.ctx.username = None self.ctx.username = None
asyncio.create_task(self.ctx.disconnect()) async_start(self.ctx.disconnect())
else: else:
asyncio.create_task(self.ctx.connect(self.server_connect_bar.text.replace("/connect ", ""))) async_start(self.ctx.connect(self.server_connect_bar.text.replace("/connect ", "")))
def on_stop(self): def on_stop(self):
# "kill" input tasks # "kill" input tasks