BizHawkClient: Add lock for communicating with lua script (#2369)

This commit is contained in:
Bryce Wilson 2023-10-26 18:55:46 -07:00 committed by GitHub
parent 88d69dba97
commit b16804102d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 74 deletions

View File

@ -13,7 +13,6 @@ import typing
BIZHAWK_SOCKET_PORT = 43055 BIZHAWK_SOCKET_PORT = 43055
EXPECTED_SCRIPT_VERSION = 1
class ConnectionStatus(enum.IntEnum): class ConnectionStatus(enum.IntEnum):
@ -22,15 +21,6 @@ class ConnectionStatus(enum.IntEnum):
CONNECTED = 3 CONNECTED = 3
class BizHawkContext:
streams: typing.Optional[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]]
connection_status: ConnectionStatus
def __init__(self) -> None:
self.streams = None
self.connection_status = ConnectionStatus.NOT_CONNECTED
class NotConnectedError(Exception): class NotConnectedError(Exception):
"""Raised when something tries to make a request to the connector script before a connection has been established""" """Raised when something tries to make a request to the connector script before a connection has been established"""
pass pass
@ -51,6 +41,50 @@ class SyncError(Exception):
pass pass
class BizHawkContext:
streams: typing.Optional[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]]
connection_status: ConnectionStatus
_lock: asyncio.Lock
def __init__(self) -> None:
self.streams = None
self.connection_status = ConnectionStatus.NOT_CONNECTED
self._lock = asyncio.Lock()
async def _send_message(self, message: str):
async with self._lock:
if self.streams is None:
raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")
try:
reader, writer = self.streams
writer.write(message.encode("utf-8") + b"\n")
await asyncio.wait_for(writer.drain(), timeout=5)
res = await asyncio.wait_for(reader.readline(), timeout=5)
if res == b"":
writer.close()
self.streams = None
self.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection closed")
if self.connection_status == ConnectionStatus.TENTATIVE:
self.connection_status = ConnectionStatus.CONNECTED
return res.decode("utf-8")
except asyncio.TimeoutError as exc:
writer.close()
self.streams = None
self.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection timed out") from exc
except ConnectionResetError as exc:
writer.close()
self.streams = None
self.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection reset") from exc
async def connect(ctx: BizHawkContext) -> bool: async def connect(ctx: BizHawkContext) -> bool:
"""Attempts to establish a connection with the connector script. Returns True if successful.""" """Attempts to establish a connection with the connector script. Returns True if successful."""
try: try:
@ -72,74 +106,14 @@ def disconnect(ctx: BizHawkContext) -> None:
async def get_script_version(ctx: BizHawkContext) -> int: async def get_script_version(ctx: BizHawkContext) -> int:
if ctx.streams is None: return int(await ctx._send_message("VERSION"))
raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")
try:
reader, writer = ctx.streams
writer.write("VERSION".encode("ascii") + b"\n")
await asyncio.wait_for(writer.drain(), timeout=5)
version = await asyncio.wait_for(reader.readline(), timeout=5)
if version == b"":
writer.close()
ctx.streams = None
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection closed")
return int(version.decode("ascii"))
except asyncio.TimeoutError as exc:
writer.close()
ctx.streams = None
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection timed out") from exc
except ConnectionResetError as exc:
writer.close()
ctx.streams = None
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection reset") from exc
async def send_requests(ctx: BizHawkContext, req_list: typing.List[typing.Dict[str, typing.Any]]) -> typing.List[typing.Dict[str, typing.Any]]: async def send_requests(ctx: BizHawkContext, req_list: typing.List[typing.Dict[str, typing.Any]]) -> typing.List[typing.Dict[str, typing.Any]]:
"""Sends a list of requests to the BizHawk connector and returns their responses. """Sends a list of requests to the BizHawk connector and returns their responses.
It's likely you want to use the wrapper functions instead of this.""" It's likely you want to use the wrapper functions instead of this."""
if ctx.streams is None: return json.loads(await ctx._send_message(json.dumps(req_list)))
raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")
try:
reader, writer = ctx.streams
writer.write(json.dumps(req_list).encode("utf-8") + b"\n")
await asyncio.wait_for(writer.drain(), timeout=5)
res = await asyncio.wait_for(reader.readline(), timeout=5)
if res == b"":
writer.close()
ctx.streams = None
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection closed")
if ctx.connection_status == ConnectionStatus.TENTATIVE:
ctx.connection_status = ConnectionStatus.CONNECTED
ret = json.loads(res.decode("utf-8"))
for response in ret:
if response["type"] == "ERROR":
raise ConnectorError(response["err"])
return ret
except asyncio.TimeoutError as exc:
writer.close()
ctx.streams = None
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection timed out") from exc
except ConnectionResetError as exc:
writer.close()
ctx.streams = None
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection reset") from exc
async def ping(ctx: BizHawkContext) -> None: async def ping(ctx: BizHawkContext) -> None:

View File

@ -13,8 +13,8 @@ from CommonClient import CommonContext, ClientCommandProcessor, get_base_parser,
import Patch import Patch
import Utils import Utils
from . import BizHawkContext, ConnectionStatus, RequestFailedError, connect, disconnect, get_hash, get_script_version, \ from . import BizHawkContext, ConnectionStatus, NotConnectedError, RequestFailedError, connect, disconnect, get_hash, \
get_system, ping get_script_version, get_system, ping
from .client import BizHawkClient, AutoBizHawkClientRegister from .client import BizHawkClient, AutoBizHawkClientRegister
@ -133,6 +133,8 @@ async def _game_watcher(ctx: BizHawkClientContext):
except RequestFailedError as exc: except RequestFailedError as exc:
logger.info(f"Lost connection to BizHawk: {exc.args[0]}") logger.info(f"Lost connection to BizHawk: {exc.args[0]}")
continue continue
except NotConnectedError:
continue
# Get slot name and send `Connect` # Get slot name and send `Connect`
if ctx.server is not None and ctx.username is None: if ctx.server is not None and ctx.username is None: