diff --git a/worlds/_bizhawk/__init__.py b/worlds/_bizhawk/__init__.py index cdf227ec..34039908 100644 --- a/worlds/_bizhawk/__init__.py +++ b/worlds/_bizhawk/__init__.py @@ -13,7 +13,6 @@ import typing BIZHAWK_SOCKET_PORT = 43055 -EXPECTED_SCRIPT_VERSION = 1 class ConnectionStatus(enum.IntEnum): @@ -22,15 +21,6 @@ class ConnectionStatus(enum.IntEnum): CONNECTED = 3 -class BizHawkContext: - streams: typing.Optional[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]] - connection_status: ConnectionStatus - - def __init__(self) -> None: - self.streams = None - self.connection_status = ConnectionStatus.NOT_CONNECTED - - class NotConnectedError(Exception): """Raised when something tries to make a request to the connector script before a connection has been established""" pass @@ -51,6 +41,50 @@ class SyncError(Exception): pass +class BizHawkContext: + streams: typing.Optional[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]] + connection_status: ConnectionStatus + _lock: asyncio.Lock + + def __init__(self) -> None: + self.streams = None + self.connection_status = ConnectionStatus.NOT_CONNECTED + self._lock = asyncio.Lock() + + async def _send_message(self, message: str): + async with self._lock: + if self.streams is None: + raise NotConnectedError("You tried to send a request before a connection to BizHawk was made") + + try: + reader, writer = self.streams + writer.write(message.encode("utf-8") + b"\n") + await asyncio.wait_for(writer.drain(), timeout=5) + + res = await asyncio.wait_for(reader.readline(), timeout=5) + + if res == b"": + writer.close() + self.streams = None + self.connection_status = ConnectionStatus.NOT_CONNECTED + raise RequestFailedError("Connection closed") + + if self.connection_status == ConnectionStatus.TENTATIVE: + self.connection_status = ConnectionStatus.CONNECTED + + return res.decode("utf-8") + except asyncio.TimeoutError as exc: + writer.close() + self.streams = None + self.connection_status = ConnectionStatus.NOT_CONNECTED + raise RequestFailedError("Connection timed out") from exc + except ConnectionResetError as exc: + writer.close() + self.streams = None + self.connection_status = ConnectionStatus.NOT_CONNECTED + raise RequestFailedError("Connection reset") from exc + + async def connect(ctx: BizHawkContext) -> bool: """Attempts to establish a connection with the connector script. Returns True if successful.""" try: @@ -72,74 +106,14 @@ def disconnect(ctx: BizHawkContext) -> None: async def get_script_version(ctx: BizHawkContext) -> int: - if ctx.streams is None: - raise NotConnectedError("You tried to send a request before a connection to BizHawk was made") - - try: - reader, writer = ctx.streams - writer.write("VERSION".encode("ascii") + b"\n") - await asyncio.wait_for(writer.drain(), timeout=5) - - version = await asyncio.wait_for(reader.readline(), timeout=5) - - if version == b"": - writer.close() - ctx.streams = None - ctx.connection_status = ConnectionStatus.NOT_CONNECTED - raise RequestFailedError("Connection closed") - - return int(version.decode("ascii")) - except asyncio.TimeoutError as exc: - writer.close() - ctx.streams = None - ctx.connection_status = ConnectionStatus.NOT_CONNECTED - raise RequestFailedError("Connection timed out") from exc - except ConnectionResetError as exc: - writer.close() - ctx.streams = None - ctx.connection_status = ConnectionStatus.NOT_CONNECTED - raise RequestFailedError("Connection reset") from exc + return int(await ctx._send_message("VERSION")) async def send_requests(ctx: BizHawkContext, req_list: typing.List[typing.Dict[str, typing.Any]]) -> typing.List[typing.Dict[str, typing.Any]]: """Sends a list of requests to the BizHawk connector and returns their responses. It's likely you want to use the wrapper functions instead of this.""" - if ctx.streams is None: - raise NotConnectedError("You tried to send a request before a connection to BizHawk was made") - - try: - reader, writer = ctx.streams - writer.write(json.dumps(req_list).encode("utf-8") + b"\n") - await asyncio.wait_for(writer.drain(), timeout=5) - - res = await asyncio.wait_for(reader.readline(), timeout=5) - - if res == b"": - writer.close() - ctx.streams = None - ctx.connection_status = ConnectionStatus.NOT_CONNECTED - raise RequestFailedError("Connection closed") - - if ctx.connection_status == ConnectionStatus.TENTATIVE: - ctx.connection_status = ConnectionStatus.CONNECTED - - ret = json.loads(res.decode("utf-8")) - for response in ret: - if response["type"] == "ERROR": - raise ConnectorError(response["err"]) - - return ret - except asyncio.TimeoutError as exc: - writer.close() - ctx.streams = None - ctx.connection_status = ConnectionStatus.NOT_CONNECTED - raise RequestFailedError("Connection timed out") from exc - except ConnectionResetError as exc: - writer.close() - ctx.streams = None - ctx.connection_status = ConnectionStatus.NOT_CONNECTED - raise RequestFailedError("Connection reset") from exc + return json.loads(await ctx._send_message(json.dumps(req_list))) async def ping(ctx: BizHawkContext) -> None: diff --git a/worlds/_bizhawk/context.py b/worlds/_bizhawk/context.py index 46533427..5d865f33 100644 --- a/worlds/_bizhawk/context.py +++ b/worlds/_bizhawk/context.py @@ -13,8 +13,8 @@ from CommonClient import CommonContext, ClientCommandProcessor, get_base_parser, import Patch import Utils -from . import BizHawkContext, ConnectionStatus, RequestFailedError, connect, disconnect, get_hash, get_script_version, \ - get_system, ping +from . import BizHawkContext, ConnectionStatus, NotConnectedError, RequestFailedError, connect, disconnect, get_hash, \ + get_script_version, get_system, ping from .client import BizHawkClient, AutoBizHawkClientRegister @@ -133,6 +133,8 @@ async def _game_watcher(ctx: BizHawkClientContext): except RequestFailedError as exc: logger.info(f"Lost connection to BizHawk: {exc.args[0]}") continue + except NotConnectedError: + continue # Get slot name and send `Connect` if ctx.server is not None and ctx.username is None: