BizHawkClient: Add lock for communicating with lua script (#2369)
This commit is contained in:
parent
88d69dba97
commit
b16804102d
|
@ -13,7 +13,6 @@ import typing
|
|||
|
||||
|
||||
BIZHAWK_SOCKET_PORT = 43055
|
||||
EXPECTED_SCRIPT_VERSION = 1
|
||||
|
||||
|
||||
class ConnectionStatus(enum.IntEnum):
|
||||
|
@ -22,15 +21,6 @@ class ConnectionStatus(enum.IntEnum):
|
|||
CONNECTED = 3
|
||||
|
||||
|
||||
class BizHawkContext:
|
||||
streams: typing.Optional[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]]
|
||||
connection_status: ConnectionStatus
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.streams = None
|
||||
self.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||
|
||||
|
||||
class NotConnectedError(Exception):
|
||||
"""Raised when something tries to make a request to the connector script before a connection has been established"""
|
||||
pass
|
||||
|
@ -51,6 +41,50 @@ class SyncError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class BizHawkContext:
|
||||
streams: typing.Optional[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]]
|
||||
connection_status: ConnectionStatus
|
||||
_lock: asyncio.Lock
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.streams = None
|
||||
self.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _send_message(self, message: str):
|
||||
async with self._lock:
|
||||
if self.streams is None:
|
||||
raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")
|
||||
|
||||
try:
|
||||
reader, writer = self.streams
|
||||
writer.write(message.encode("utf-8") + b"\n")
|
||||
await asyncio.wait_for(writer.drain(), timeout=5)
|
||||
|
||||
res = await asyncio.wait_for(reader.readline(), timeout=5)
|
||||
|
||||
if res == b"":
|
||||
writer.close()
|
||||
self.streams = None
|
||||
self.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||
raise RequestFailedError("Connection closed")
|
||||
|
||||
if self.connection_status == ConnectionStatus.TENTATIVE:
|
||||
self.connection_status = ConnectionStatus.CONNECTED
|
||||
|
||||
return res.decode("utf-8")
|
||||
except asyncio.TimeoutError as exc:
|
||||
writer.close()
|
||||
self.streams = None
|
||||
self.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||
raise RequestFailedError("Connection timed out") from exc
|
||||
except ConnectionResetError as exc:
|
||||
writer.close()
|
||||
self.streams = None
|
||||
self.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||
raise RequestFailedError("Connection reset") from exc
|
||||
|
||||
|
||||
async def connect(ctx: BizHawkContext) -> bool:
|
||||
"""Attempts to establish a connection with the connector script. Returns True if successful."""
|
||||
try:
|
||||
|
@ -72,74 +106,14 @@ def disconnect(ctx: BizHawkContext) -> None:
|
|||
|
||||
|
||||
async def get_script_version(ctx: BizHawkContext) -> int:
|
||||
if ctx.streams is None:
|
||||
raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")
|
||||
|
||||
try:
|
||||
reader, writer = ctx.streams
|
||||
writer.write("VERSION".encode("ascii") + b"\n")
|
||||
await asyncio.wait_for(writer.drain(), timeout=5)
|
||||
|
||||
version = await asyncio.wait_for(reader.readline(), timeout=5)
|
||||
|
||||
if version == b"":
|
||||
writer.close()
|
||||
ctx.streams = None
|
||||
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||
raise RequestFailedError("Connection closed")
|
||||
|
||||
return int(version.decode("ascii"))
|
||||
except asyncio.TimeoutError as exc:
|
||||
writer.close()
|
||||
ctx.streams = None
|
||||
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||
raise RequestFailedError("Connection timed out") from exc
|
||||
except ConnectionResetError as exc:
|
||||
writer.close()
|
||||
ctx.streams = None
|
||||
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||
raise RequestFailedError("Connection reset") from exc
|
||||
return int(await ctx._send_message("VERSION"))
|
||||
|
||||
|
||||
async def send_requests(ctx: BizHawkContext, req_list: typing.List[typing.Dict[str, typing.Any]]) -> typing.List[typing.Dict[str, typing.Any]]:
|
||||
"""Sends a list of requests to the BizHawk connector and returns their responses.
|
||||
|
||||
It's likely you want to use the wrapper functions instead of this."""
|
||||
if ctx.streams is None:
|
||||
raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")
|
||||
|
||||
try:
|
||||
reader, writer = ctx.streams
|
||||
writer.write(json.dumps(req_list).encode("utf-8") + b"\n")
|
||||
await asyncio.wait_for(writer.drain(), timeout=5)
|
||||
|
||||
res = await asyncio.wait_for(reader.readline(), timeout=5)
|
||||
|
||||
if res == b"":
|
||||
writer.close()
|
||||
ctx.streams = None
|
||||
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||
raise RequestFailedError("Connection closed")
|
||||
|
||||
if ctx.connection_status == ConnectionStatus.TENTATIVE:
|
||||
ctx.connection_status = ConnectionStatus.CONNECTED
|
||||
|
||||
ret = json.loads(res.decode("utf-8"))
|
||||
for response in ret:
|
||||
if response["type"] == "ERROR":
|
||||
raise ConnectorError(response["err"])
|
||||
|
||||
return ret
|
||||
except asyncio.TimeoutError as exc:
|
||||
writer.close()
|
||||
ctx.streams = None
|
||||
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||
raise RequestFailedError("Connection timed out") from exc
|
||||
except ConnectionResetError as exc:
|
||||
writer.close()
|
||||
ctx.streams = None
|
||||
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||
raise RequestFailedError("Connection reset") from exc
|
||||
return json.loads(await ctx._send_message(json.dumps(req_list)))
|
||||
|
||||
|
||||
async def ping(ctx: BizHawkContext) -> None:
|
||||
|
|
|
@ -13,8 +13,8 @@ from CommonClient import CommonContext, ClientCommandProcessor, get_base_parser,
|
|||
import Patch
|
||||
import Utils
|
||||
|
||||
from . import BizHawkContext, ConnectionStatus, RequestFailedError, connect, disconnect, get_hash, get_script_version, \
|
||||
get_system, ping
|
||||
from . import BizHawkContext, ConnectionStatus, NotConnectedError, RequestFailedError, connect, disconnect, get_hash, \
|
||||
get_script_version, get_system, ping
|
||||
from .client import BizHawkClient, AutoBizHawkClientRegister
|
||||
|
||||
|
||||
|
@ -133,6 +133,8 @@ async def _game_watcher(ctx: BizHawkClientContext):
|
|||
except RequestFailedError as exc:
|
||||
logger.info(f"Lost connection to BizHawk: {exc.args[0]}")
|
||||
continue
|
||||
except NotConnectedError:
|
||||
continue
|
||||
|
||||
# Get slot name and send `Connect`
|
||||
if ctx.server is not None and ctx.username is None:
|
||||
|
|
Loading…
Reference in New Issue