MultiServer and clients: async coroutine starter in Utils.py (#1143)

* async coroutine starter in Utils.py

* refactor from static class to function

* async_start docstring
This commit is contained in:
Doug Hoskisson 2022-11-02 07:51:35 -07:00 committed by GitHub
parent a6e1e14fee
commit da392239a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 82 additions and 54 deletions

View File

@ -20,7 +20,7 @@ if __name__ == "__main__":
from MultiServer import CommandProcessor
from NetUtils import Endpoint, decode, NetworkItem, encode, JSONtoTextParser, \
ClientStatus, Permission, NetworkSlot, RawJSONtoTextParser
from Utils import Version, stream_input
from Utils import Version, stream_input, async_start
from worlds import network_data_package, AutoWorldRegister
import os
@ -46,14 +46,14 @@ class ClientCommandProcessor(CommandProcessor):
"""Connect to a MultiWorld Server"""
self.ctx.server_address = None
self.ctx.username = None
asyncio.create_task(self.ctx.connect(address if address else None), name="connecting")
async_start(self.ctx.connect(address if address else None), name="connecting")
return True
def _cmd_disconnect(self) -> bool:
"""Disconnect from a MultiWorld Server"""
self.ctx.server_address = None
self.ctx.username = None
asyncio.create_task(self.ctx.disconnect(), name="disconnecting")
async_start(self.ctx.disconnect(), name="disconnecting")
return True
def _cmd_received(self) -> bool:
@ -116,12 +116,12 @@ class ClientCommandProcessor(CommandProcessor):
else:
state = ClientStatus.CLIENT_CONNECTED
self.output("Unreadied.")
asyncio.create_task(self.ctx.send_msgs([{"cmd": "StatusUpdate", "status": state}]), name="send StatusUpdate")
async_start(self.ctx.send_msgs([{"cmd": "StatusUpdate", "status": state}]), name="send StatusUpdate")
def default(self, raw: str):
raw = self.ctx.on_user_say(raw)
if raw:
asyncio.create_task(self.ctx.send_msgs([{"cmd": "Say", "text": raw}]), name="send Say")
async_start(self.ctx.send_msgs([{"cmd": "Say", "text": raw}]), name="send Say")
class CommonContext:
@ -562,7 +562,7 @@ async def server_loop(ctx: CommonContext, address: typing.Optional[str] = None)
await ctx.connection_closed()
if ctx.server_address:
logger.info(f"... reconnecting in {ctx.current_reconnect_delay}s")
asyncio.create_task(server_autoreconnect(ctx), name="server auto reconnect")
async_start(server_autoreconnect(ctx), name="server auto reconnect")
ctx.current_reconnect_delay *= 2

View File

@ -7,6 +7,7 @@ from typing import List
import Utils
from Utils import async_start
from CommonClient import CommonContext, server_loop, gui_enabled, ClientCommandProcessor, logger, \
get_base_parser
@ -69,7 +70,7 @@ class FF1Context(CommonContext):
def on_package(self, cmd: str, args: dict):
if cmd == 'Connected':
asyncio.create_task(parse_locations(self.locations_array, self, True))
async_start(parse_locations(self.locations_array, self, True))
elif cmd == 'Print':
msg = args['text']
if ': !' not in msg:
@ -180,7 +181,7 @@ async def nes_sync_task(ctx: FF1Context):
# print(data_decoded)
if ctx.game is not None and 'locations' in data_decoded:
# Not just a keep alive ping, parse
asyncio.create_task(parse_locations(data_decoded['locations'], ctx, False))
async_start(parse_locations(data_decoded['locations'], ctx, False))
if not ctx.auth:
ctx.auth = ''.join([chr(i) for i in data_decoded['playerName'] if i != 0])
if ctx.auth == '':

View File

@ -25,6 +25,7 @@ if __name__ == "__main__":
from CommonClient import CommonContext, server_loop, ClientCommandProcessor, logger, gui_enabled, get_base_parser
from MultiServer import mark_raw
from NetUtils import NetworkItem, ClientStatus, JSONtoTextParser, JSONMessagePart
from Utils import async_start
from worlds.factorio import Factorio
@ -124,7 +125,7 @@ class FactorioContext(CommonContext):
self.rcon_client.send_commands({item_name: f'/ap-get-technology ap-{item_name}-\t-1' for
item_name in args["checked_locations"]})
if cmd == "Connected" and self.energy_link_increment:
asyncio.create_task(self.send_msgs([{
async_start(self.send_msgs([{
"cmd": "SetNotify", "keys": ["EnergyLink"]
}]))
elif cmd == "SetReply":
@ -232,7 +233,7 @@ async def game_watcher(ctx: FactorioContext):
if death_link_tick != ctx.death_link_tick:
ctx.death_link_tick = death_link_tick
if "DeathLink" in ctx.tags:
asyncio.create_task(ctx.send_death())
async_start(ctx.send_death())
if ctx.energy_link_increment:
in_world_bridges = data["energy_bridges"]
if in_world_bridges:
@ -240,7 +241,7 @@ async def game_watcher(ctx: FactorioContext):
if in_world_energy < (ctx.energy_link_increment * in_world_bridges):
# attempt to refill
ctx.last_deplete = time.time()
asyncio.create_task(ctx.send_msgs([{
async_start(ctx.send_msgs([{
"cmd": "Set", "key": "EnergyLink", "operations":
[{"operation": "add", "value": -ctx.energy_link_increment * in_world_bridges},
{"operation": "max", "value": 0}],
@ -250,7 +251,7 @@ async def game_watcher(ctx: FactorioContext):
elif in_world_energy > (in_world_bridges * ctx.energy_link_increment * 5) - \
ctx.energy_link_increment*in_world_bridges:
value = ctx.energy_link_increment * in_world_bridges
asyncio.create_task(ctx.send_msgs([{
async_start(ctx.send_msgs([{
"cmd": "Set", "key": "EnergyLink", "operations":
[{"operation": "add", "value": value}]
}]))

View File

@ -31,7 +31,7 @@ except ImportError:
import NetUtils
import Utils
from Utils import version_tuple, restricted_loads, Version
from Utils import version_tuple, restricted_loads, Version, async_start
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \
SlotType
@ -273,16 +273,16 @@ class Context:
def broadcast_all(self, msgs: typing.List[dict]):
msgs = self.dumper(msgs)
endpoints = (endpoint for endpoint in self.endpoints if endpoint.auth)
asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs))
async_start(self.broadcast_send_encoded_msgs(endpoints, msgs))
def broadcast_team(self, team: int, msgs: typing.List[dict]):
msgs = self.dumper(msgs)
endpoints = (endpoint for endpoint in itertools.chain.from_iterable(self.clients[team].values()))
asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs))
async_start(self.broadcast_send_encoded_msgs(endpoints, msgs))
def broadcast(self, endpoints: typing.Iterable[Client], msgs: typing.List[dict]):
msgs = self.dumper(msgs)
asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs))
async_start(self.broadcast_send_encoded_msgs(endpoints, msgs))
async def disconnect(self, endpoint: Client):
if endpoint in self.endpoints:
@ -302,18 +302,18 @@ class Context:
return
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
if client.version >= print_command_compatability_threshold:
asyncio.create_task(self.send_msgs(client, [{"cmd": "PrintJSON", "data": [{ "text": text }]}]))
async_start(self.send_msgs(client, [{"cmd": "PrintJSON", "data": [{ "text": text }]}]))
else:
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
async_start(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
def notify_client_multiple(self, client: Client, texts: typing.List[str]):
if not client.auth:
return
if client.version >= print_command_compatability_threshold:
asyncio.create_task(self.send_msgs(client,
async_start(self.send_msgs(client,
[{"cmd": "PrintJSON", "data": [{ "text": text }]} for text in texts]))
else:
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
async_start(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
# loading
@ -627,7 +627,7 @@ def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint], onl
continue
client_hints = [datum[1] for datum in sorted(hint_data, key=lambda x: x[0].finding_player == slot)]
for client in clients:
asyncio.create_task(ctx.send_msgs(client, client_hints))
async_start(ctx.send_msgs(client, client_hints))
def update_aliases(ctx: Context, team: int):
@ -636,7 +636,7 @@ def update_aliases(ctx: Context, team: int):
for clients in ctx.clients[team].values():
for client in clients:
asyncio.create_task(ctx.send_encoded_msgs(client, cmd))
async_start(ctx.send_encoded_msgs(client, cmd))
async def server(websocket, path: str = "/", ctx: Context = None):
@ -814,7 +814,7 @@ def send_new_items(ctx: Context):
items = get_received_items(ctx, team, slot, client.remote_items)
if len(start_inventory) + len(items) > client.send_index:
first_new_item = max(0, client.send_index - len(start_inventory))
asyncio.create_task(ctx.send_msgs(client, [{
async_start(ctx.send_msgs(client, [{
"cmd": "ReceivedItems",
"index": client.send_index,
"items": start_inventory[client.send_index:] + items[first_new_item:]}]))
@ -1090,7 +1090,7 @@ class CommonCommandProcessor(CommandProcessor):
timer = int(seconds, 10)
except ValueError:
timer = 10
asyncio.create_task(countdown(self.ctx, timer))
async_start(countdown(self.ctx, timer))
return True
def _cmd_options(self):
@ -1771,7 +1771,7 @@ class ServerCommandProcessor(CommonCommandProcessor):
def _cmd_exit(self) -> bool:
"""Shutdown the server"""
asyncio.create_task(self.ctx.server.ws_server._close())
async_start(self.ctx.server.ws_server._close())
if self.ctx.shutdown_task:
self.ctx.shutdown_task.cancel()
self.ctx.exit_event.set()
@ -2084,7 +2084,7 @@ async def auto_shutdown(ctx, to_cancel=None):
await asyncio.sleep(ctx.auto_shutdown)
while not ctx.exit_event.is_set():
if not ctx.client_activity_timers.values():
asyncio.create_task(ctx.server.ws_server._close())
async_start(ctx.server.ws_server._close())
ctx.exit_event.set()
if to_cancel:
for task in to_cancel:
@ -2095,7 +2095,7 @@ async def auto_shutdown(ctx, to_cancel=None):
delta = datetime.datetime.now(datetime.timezone.utc) - newest_activity
seconds = ctx.auto_shutdown - delta.total_seconds()
if seconds < 0:
asyncio.create_task(ctx.server.ws_server._close())
async_start(ctx.server.ws_server._close())
ctx.exit_event.set()
if to_cancel:
for task in to_cancel:

View File

@ -9,6 +9,7 @@ from asyncio import StreamReader, StreamWriter
from CommonClient import CommonContext, server_loop, gui_enabled, \
ClientCommandProcessor, logger, get_base_parser
import Utils
from Utils import async_start
from worlds import network_data_package
from worlds.oot.Rom import Rom, compress_rom_file
from worlds.oot.N64Patch import apply_patch_file
@ -69,7 +70,7 @@ class OoTCommandProcessor(ClientCommandProcessor):
if isinstance(self.ctx, OoTContext):
self.ctx.deathlink_client_override = True
self.ctx.deathlink_enabled = not self.ctx.deathlink_enabled
asyncio.create_task(self.ctx.update_death_link(self.ctx.deathlink_enabled), name="Update Deathlink")
async_start(self.ctx.update_death_link(self.ctx.deathlink_enabled), name="Update Deathlink")
class OoTContext(CommonContext):
@ -203,7 +204,7 @@ async def n64_sync_task(ctx: OoTContext):
if reported_version >= script_version:
if ctx.game is not None and 'locations' in data_decoded:
# Not just a keep alive ping, parse
asyncio.create_task(parse_payload(data_decoded, ctx, False))
async_start(parse_payload(data_decoded, ctx, False))
if not ctx.auth:
ctx.auth = data_decoded['playerName']
if ctx.awaiting_rom:
@ -279,7 +280,7 @@ async def patch_and_run_game(apz5_file):
os.chdir(data_path("Compress"))
compress_rom_file(decomp_path, comp_path)
os.remove(decomp_path)
asyncio.create_task(run_game(comp_path))
async_start(run_game(comp_path))
if __name__ == '__main__':
@ -295,7 +296,7 @@ if __name__ == '__main__':
if args.apz5_file:
logger.info("APZ5 file supplied, beginning patching process...")
asyncio.create_task(patch_and_run_game(args.apz5_file))
async_start(patch_and_run_game(args.apz5_file))
ctx = OoTContext(args.connect, args.password)
ctx.server_task = asyncio.create_task(server_loop(ctx), name="Server Loop")

View File

@ -11,6 +11,7 @@ from typing import List
import Utils
from Utils import async_start
from CommonClient import CommonContext, server_loop, gui_enabled, ClientCommandProcessor, logger, \
get_base_parser
@ -185,7 +186,7 @@ async def gb_sync_task(ctx: GBContext):
if 'locations' in data_decoded and ctx.game and ctx.gb_status == CONNECTION_CONNECTED_STATUS \
and not error_status and ctx.auth:
# Not just a keep alive ping, parse
asyncio.create_task(parse_locations(data_decoded['locations'], ctx))
async_start(parse_locations(data_decoded['locations'], ctx))
except asyncio.TimeoutError:
logger.debug("Read Timed Out, Reconnecting")
error_status = CONNECTION_TIMING_OUT_STATUS
@ -265,7 +266,7 @@ async def patch_and_run_game(game_version, patch_file, ctx):
with open(comp_path, "wb") as patched_rom_file:
patched_rom_file.write(patched_rom_data)
asyncio.create_task(run_game(comp_path))
async_start(run_game(comp_path))
else:
msg = "Patch supplied was not generated with the same base patch version as this client. Patching failed."
logger.warning(msg)
@ -295,10 +296,10 @@ if __name__ == '__main__':
ext = args.patch_file.split(".")[len(args.patch_file.split(".")) - 1].lower()
if ext == "apred":
logger.info("APRED file supplied, beginning patching process...")
asyncio.create_task(patch_and_run_game("red", args.patch_file, ctx))
async_start(patch_and_run_game("red", args.patch_file, ctx))
elif ext == "apblue":
logger.info("APBLUE file supplied, beginning patching process...")
asyncio.create_task(patch_and_run_game("blue", args.patch_file, ctx))
async_start(patch_and_run_game("blue", args.patch_file, ctx))
else:
logger.warning(f"Unknown patch file extension {ext}")

View File

@ -18,7 +18,7 @@ from json import loads, dumps
from CommonClient import CommonContext, server_loop, ClientCommandProcessor, gui_enabled, get_base_parser
import Utils
from Utils import async_start
from MultiServer import mark_raw
if typing.TYPE_CHECKING:
from worlds.AutoSNIClient import SNIClient
@ -84,7 +84,7 @@ class SNIClientCommandProcessor(ClientCommandProcessor):
"""Close connection to a currently connected snes"""
self.ctx.snes_reconnect_address = None
if self.ctx.snes_socket is not None and not self.ctx.snes_socket.closed:
asyncio.create_task(self.ctx.snes_socket.close())
async_start(self.ctx.snes_socket.close())
return True
else:
return False
@ -96,7 +96,7 @@ class SNIClientCommandProcessor(ClientCommandProcessor):
# self.output("No attached SNES Device.")
# return False
# snes_buffered_write(self.ctx, int(address, 16), bytes([int(data)]))
# asyncio.create_task(snes_flush_writes(self.ctx))
# async_start(snes_flush_writes(self.ctx))
# self.output("Data Sent")
# return True
@ -167,7 +167,7 @@ class SNIContext(CommonContext):
def event_invalid_slot(self) -> typing.NoReturn:
if self.snes_socket is not None and not self.snes_socket.closed:
asyncio.create_task(self.snes_socket.close())
async_start(self.snes_socket.close())
raise Exception("Invalid ROM detected, "
"please verify that you have loaded the correct rom and reconnect your snes (/snes)")
@ -230,7 +230,7 @@ class SNIContext(CommonContext):
# since the player will likely need that item.
# Once the games handled by SNIClient gets made to be remote items,
# this will no longer be needed.
asyncio.create_task(self.send_msgs([{"cmd": "LocationScouts", "locations": list(new_locations)}]))
async_start(self.send_msgs([{"cmd": "LocationScouts", "locations": list(new_locations)}]))
def run_gui(self) -> None:
from kvui import GameManager
@ -443,7 +443,7 @@ async def snes_connect(ctx: SNIContext, address: str, deviceIndex: int = -1) ->
snes_logger.error("Error connecting to snes (%s)" % e)
else:
snes_logger.error(f"Error connecting to snes, attempt again in {_global_snes_reconnect_delay}s")
asyncio.create_task(snes_autoreconnect(ctx))
async_start(snes_autoreconnect(ctx))
_global_snes_reconnect_delay *= 2
else:
@ -488,7 +488,7 @@ async def snes_recv_loop(ctx: SNIContext) -> None:
if ctx.snes_reconnect_address:
snes_logger.info(f"...reconnecting in {_global_snes_reconnect_delay}s")
asyncio.create_task(snes_autoreconnect(ctx))
async_start(snes_autoreconnect(ctx))
async def snes_read(ctx: SNIContext, address: int, size: int) -> typing.Optional[bytes]:
@ -674,9 +674,9 @@ async def main() -> None:
elif args.diff_file.endswith(".aplttp"):
from worlds.alttp.Client import get_alttp_settings
adjustedromfile, adjusted = get_alttp_settings(romfile)
asyncio.create_task(run_game(adjustedromfile if adjusted else romfile))
async_start(run_game(adjustedromfile if adjusted else romfile))
else:
asyncio.create_task(run_game(romfile))
async_start(run_game(romfile))
ctx = SNIContext(args.snes, args.connect, args.password)
if ctx.server_task is None:

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import typing
import builtins
import os
@ -11,7 +12,7 @@ import io
import collections
import importlib
import logging
from typing import BinaryIO
from typing import BinaryIO, ClassVar, Coroutine, Optional, Set
from yaml import load, load_all, dump, SafeLoader
@ -650,3 +651,24 @@ def read_snes_rom(stream: BinaryIO, strip_header: bool = True) -> bytearray:
if strip_header and len(buffer) % 0x400 == 0x200:
return buffer[0x200:]
return buffer
_faf_tasks: "Set[asyncio.Task[None]]" = set()
def async_start(co: Coroutine[None, None, None], name: Optional[str] = None) -> None:
"""
Use this to start a task when you don't keep a reference to it or immediately await it,
to prevent early garbage collection. "fire-and-forget"
"""
# https://docs.python.org/3.10/library/asyncio-task.html#asyncio.create_task
# Python docs:
# ```
# Important: Save a reference to the result of [asyncio.create_task],
# to avoid a task disappearing mid-execution.
# ```
# This implementation follows the pattern given in that documentation.
task = asyncio.create_task(co, name=name)
_faf_tasks.add(task)
task.add_done_callback(_faf_tasks.discard)

View File

@ -8,6 +8,7 @@ from CommonClient import CommonContext, server_loop, gui_enabled, \
ClientCommandProcessor, logger, get_base_parser
from NetUtils import ClientStatus
import Utils
from Utils import async_start
import colorama # type: ignore
@ -263,7 +264,7 @@ class ZillionContext(CommonContext):
"cmd": "Get",
"keys": [f"zillion-{self.auth}-doors"]
}
asyncio.create_task(self.send_msgs([payload]))
async_start(self.send_msgs([payload]))
elif cmd == "Retrieved":
if "keys" not in args:
logger.warning(f"invalid Retrieved packet to ZillionClient: {args}")
@ -304,7 +305,7 @@ class ZillionContext(CommonContext):
self.ap_local_count += 1
n_locations = len(self.missing_locations) + len(self.checked_locations) - 1 # -1 to ignore win
logger.info(f'New Check: {loc_name} ({self.ap_local_count}/{n_locations})')
asyncio.create_task(self.send_msgs([
async_start(self.send_msgs([
{"cmd": 'LocationChecks', "locations": [server_id]}
]))
else:
@ -312,10 +313,10 @@ class ZillionContext(CommonContext):
# because all the key words are local and unwatched by the server.
logger.debug(f"DEBUG: {loc_name} not in missing")
elif isinstance(event_from_game, events.DeathEventFromGame):
asyncio.create_task(self.send_death())
async_start(self.send_death())
elif isinstance(event_from_game, events.WinEventFromGame):
if not self.finished_game:
asyncio.create_task(self.send_msgs([
async_start(self.send_msgs([
{"cmd": "StatusUpdate", "status": ClientStatus.CLIENT_GOAL}
]))
self.finished_game = True
@ -327,7 +328,7 @@ class ZillionContext(CommonContext):
"key": f"zillion-{self.auth}-doors",
"operations": [{"operation": "replace", "value": doors_b64}]
}
asyncio.create_task(self.send_msgs([payload]))
async_start(self.send_msgs([payload]))
else:
logger.warning(f"WARNING: unhandled event from game {event_from_game}")
@ -410,7 +411,7 @@ async def zillion_sync_task(ctx: ZillionContext) -> None:
ctx.next_item = 0
ctx.ap_local_count = len(ctx.checked_locations)
else: # no slot data yet
asyncio.create_task(ctx.send_connect())
async_start(ctx.send_connect())
log_no_spam("logging in to server...")
await asyncio.wait((
ctx.got_slot_data.wait(),
@ -434,7 +435,7 @@ async def zillion_sync_task(ctx: ZillionContext) -> None:
memory.reset_game_state()
ctx.auth = name
asyncio.create_task(ctx.connect())
async_start(ctx.connect())
await asyncio.wait((
ctx.got_room_info.wait(),
ctx.exit_event.wait(),

View File

@ -49,6 +49,7 @@ fade_in_animation = Animation(opacity=0, duration=0) + Animation(opacity=1, dura
from NetUtils import JSONtoTextParser, JSONMessagePart, SlotType
from Utils import async_start
if typing.TYPE_CHECKING:
import CommonClient
@ -427,9 +428,9 @@ class GameManager(App):
if self.ctx.server:
self.ctx.server_address = None
self.ctx.username = None
asyncio.create_task(self.ctx.disconnect())
async_start(self.ctx.disconnect())
else:
asyncio.create_task(self.ctx.connect(self.server_connect_bar.text.replace("/connect ", "")))
async_start(self.ctx.connect(self.server_connect_bar.text.replace("/connect ", "")))
def on_stop(self):
# "kill" input tasks