BizHawkClient: Add lock for communicating with lua script (#2369)

This commit is contained in:
Bryce Wilson 2023-10-26 18:55:46 -07:00 committed by GitHub
parent 88d69dba97
commit b16804102d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 74 deletions

View File

@ -13,7 +13,6 @@ import typing
BIZHAWK_SOCKET_PORT = 43055
EXPECTED_SCRIPT_VERSION = 1
class ConnectionStatus(enum.IntEnum):
@ -22,15 +21,6 @@ class ConnectionStatus(enum.IntEnum):
CONNECTED = 3
class BizHawkContext:
streams: typing.Optional[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]]
connection_status: ConnectionStatus
def __init__(self) -> None:
self.streams = None
self.connection_status = ConnectionStatus.NOT_CONNECTED
class NotConnectedError(Exception):
"""Raised when something tries to make a request to the connector script before a connection has been established"""
pass
@ -51,6 +41,50 @@ class SyncError(Exception):
pass
class BizHawkContext:
streams: typing.Optional[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]]
connection_status: ConnectionStatus
_lock: asyncio.Lock
def __init__(self) -> None:
self.streams = None
self.connection_status = ConnectionStatus.NOT_CONNECTED
self._lock = asyncio.Lock()
async def _send_message(self, message: str):
async with self._lock:
if self.streams is None:
raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")
try:
reader, writer = self.streams
writer.write(message.encode("utf-8") + b"\n")
await asyncio.wait_for(writer.drain(), timeout=5)
res = await asyncio.wait_for(reader.readline(), timeout=5)
if res == b"":
writer.close()
self.streams = None
self.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection closed")
if self.connection_status == ConnectionStatus.TENTATIVE:
self.connection_status = ConnectionStatus.CONNECTED
return res.decode("utf-8")
except asyncio.TimeoutError as exc:
writer.close()
self.streams = None
self.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection timed out") from exc
except ConnectionResetError as exc:
writer.close()
self.streams = None
self.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection reset") from exc
async def connect(ctx: BizHawkContext) -> bool:
"""Attempts to establish a connection with the connector script. Returns True if successful."""
try:
@ -72,74 +106,14 @@ def disconnect(ctx: BizHawkContext) -> None:
async def get_script_version(ctx: BizHawkContext) -> int:
if ctx.streams is None:
raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")
try:
reader, writer = ctx.streams
writer.write("VERSION".encode("ascii") + b"\n")
await asyncio.wait_for(writer.drain(), timeout=5)
version = await asyncio.wait_for(reader.readline(), timeout=5)
if version == b"":
writer.close()
ctx.streams = None
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection closed")
return int(version.decode("ascii"))
except asyncio.TimeoutError as exc:
writer.close()
ctx.streams = None
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection timed out") from exc
except ConnectionResetError as exc:
writer.close()
ctx.streams = None
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection reset") from exc
return int(await ctx._send_message("VERSION"))
async def send_requests(ctx: BizHawkContext, req_list: typing.List[typing.Dict[str, typing.Any]]) -> typing.List[typing.Dict[str, typing.Any]]:
"""Sends a list of requests to the BizHawk connector and returns their responses.
It's likely you want to use the wrapper functions instead of this."""
if ctx.streams is None:
raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")
try:
reader, writer = ctx.streams
writer.write(json.dumps(req_list).encode("utf-8") + b"\n")
await asyncio.wait_for(writer.drain(), timeout=5)
res = await asyncio.wait_for(reader.readline(), timeout=5)
if res == b"":
writer.close()
ctx.streams = None
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection closed")
if ctx.connection_status == ConnectionStatus.TENTATIVE:
ctx.connection_status = ConnectionStatus.CONNECTED
ret = json.loads(res.decode("utf-8"))
for response in ret:
if response["type"] == "ERROR":
raise ConnectorError(response["err"])
return ret
except asyncio.TimeoutError as exc:
writer.close()
ctx.streams = None
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection timed out") from exc
except ConnectionResetError as exc:
writer.close()
ctx.streams = None
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
raise RequestFailedError("Connection reset") from exc
return json.loads(await ctx._send_message(json.dumps(req_list)))
async def ping(ctx: BizHawkContext) -> None:

View File

@ -13,8 +13,8 @@ from CommonClient import CommonContext, ClientCommandProcessor, get_base_parser,
import Patch
import Utils
from . import BizHawkContext, ConnectionStatus, RequestFailedError, connect, disconnect, get_hash, get_script_version, \
get_system, ping
from . import BizHawkContext, ConnectionStatus, NotConnectedError, RequestFailedError, connect, disconnect, get_hash, \
get_script_version, get_system, ping
from .client import BizHawkClient, AutoBizHawkClientRegister
@ -133,6 +133,8 @@ async def _game_watcher(ctx: BizHawkClientContext):
except RequestFailedError as exc:
logger.info(f"Lost connection to BizHawk: {exc.args[0]}")
continue
except NotConnectedError:
continue
# Get slot name and send `Connect`
if ctx.server is not None and ctx.username is None: