BizHawkClient: Add lock for communicating with lua script (#2369)
This commit is contained in:
parent
88d69dba97
commit
b16804102d
|
@ -13,7 +13,6 @@ import typing
|
||||||
|
|
||||||
|
|
||||||
BIZHAWK_SOCKET_PORT = 43055
|
BIZHAWK_SOCKET_PORT = 43055
|
||||||
EXPECTED_SCRIPT_VERSION = 1
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionStatus(enum.IntEnum):
|
class ConnectionStatus(enum.IntEnum):
|
||||||
|
@ -22,15 +21,6 @@ class ConnectionStatus(enum.IntEnum):
|
||||||
CONNECTED = 3
|
CONNECTED = 3
|
||||||
|
|
||||||
|
|
||||||
class BizHawkContext:
|
|
||||||
streams: typing.Optional[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]]
|
|
||||||
connection_status: ConnectionStatus
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.streams = None
|
|
||||||
self.connection_status = ConnectionStatus.NOT_CONNECTED
|
|
||||||
|
|
||||||
|
|
||||||
class NotConnectedError(Exception):
|
class NotConnectedError(Exception):
|
||||||
"""Raised when something tries to make a request to the connector script before a connection has been established"""
|
"""Raised when something tries to make a request to the connector script before a connection has been established"""
|
||||||
pass
|
pass
|
||||||
|
@ -51,6 +41,50 @@ class SyncError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BizHawkContext:
|
||||||
|
streams: typing.Optional[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]]
|
||||||
|
connection_status: ConnectionStatus
|
||||||
|
_lock: asyncio.Lock
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.streams = None
|
||||||
|
self.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def _send_message(self, message: str):
|
||||||
|
async with self._lock:
|
||||||
|
if self.streams is None:
|
||||||
|
raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")
|
||||||
|
|
||||||
|
try:
|
||||||
|
reader, writer = self.streams
|
||||||
|
writer.write(message.encode("utf-8") + b"\n")
|
||||||
|
await asyncio.wait_for(writer.drain(), timeout=5)
|
||||||
|
|
||||||
|
res = await asyncio.wait_for(reader.readline(), timeout=5)
|
||||||
|
|
||||||
|
if res == b"":
|
||||||
|
writer.close()
|
||||||
|
self.streams = None
|
||||||
|
self.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||||
|
raise RequestFailedError("Connection closed")
|
||||||
|
|
||||||
|
if self.connection_status == ConnectionStatus.TENTATIVE:
|
||||||
|
self.connection_status = ConnectionStatus.CONNECTED
|
||||||
|
|
||||||
|
return res.decode("utf-8")
|
||||||
|
except asyncio.TimeoutError as exc:
|
||||||
|
writer.close()
|
||||||
|
self.streams = None
|
||||||
|
self.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||||
|
raise RequestFailedError("Connection timed out") from exc
|
||||||
|
except ConnectionResetError as exc:
|
||||||
|
writer.close()
|
||||||
|
self.streams = None
|
||||||
|
self.connection_status = ConnectionStatus.NOT_CONNECTED
|
||||||
|
raise RequestFailedError("Connection reset") from exc
|
||||||
|
|
||||||
|
|
||||||
async def connect(ctx: BizHawkContext) -> bool:
|
async def connect(ctx: BizHawkContext) -> bool:
|
||||||
"""Attempts to establish a connection with the connector script. Returns True if successful."""
|
"""Attempts to establish a connection with the connector script. Returns True if successful."""
|
||||||
try:
|
try:
|
||||||
|
@ -72,74 +106,14 @@ def disconnect(ctx: BizHawkContext) -> None:
|
||||||
|
|
||||||
|
|
||||||
async def get_script_version(ctx: BizHawkContext) -> int:
|
async def get_script_version(ctx: BizHawkContext) -> int:
|
||||||
if ctx.streams is None:
|
return int(await ctx._send_message("VERSION"))
|
||||||
raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")
|
|
||||||
|
|
||||||
try:
|
|
||||||
reader, writer = ctx.streams
|
|
||||||
writer.write("VERSION".encode("ascii") + b"\n")
|
|
||||||
await asyncio.wait_for(writer.drain(), timeout=5)
|
|
||||||
|
|
||||||
version = await asyncio.wait_for(reader.readline(), timeout=5)
|
|
||||||
|
|
||||||
if version == b"":
|
|
||||||
writer.close()
|
|
||||||
ctx.streams = None
|
|
||||||
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
|
|
||||||
raise RequestFailedError("Connection closed")
|
|
||||||
|
|
||||||
return int(version.decode("ascii"))
|
|
||||||
except asyncio.TimeoutError as exc:
|
|
||||||
writer.close()
|
|
||||||
ctx.streams = None
|
|
||||||
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
|
|
||||||
raise RequestFailedError("Connection timed out") from exc
|
|
||||||
except ConnectionResetError as exc:
|
|
||||||
writer.close()
|
|
||||||
ctx.streams = None
|
|
||||||
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
|
|
||||||
raise RequestFailedError("Connection reset") from exc
|
|
||||||
|
|
||||||
|
|
||||||
async def send_requests(ctx: BizHawkContext, req_list: typing.List[typing.Dict[str, typing.Any]]) -> typing.List[typing.Dict[str, typing.Any]]:
|
async def send_requests(ctx: BizHawkContext, req_list: typing.List[typing.Dict[str, typing.Any]]) -> typing.List[typing.Dict[str, typing.Any]]:
|
||||||
"""Sends a list of requests to the BizHawk connector and returns their responses.
|
"""Sends a list of requests to the BizHawk connector and returns their responses.
|
||||||
|
|
||||||
It's likely you want to use the wrapper functions instead of this."""
|
It's likely you want to use the wrapper functions instead of this."""
|
||||||
if ctx.streams is None:
|
return json.loads(await ctx._send_message(json.dumps(req_list)))
|
||||||
raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")
|
|
||||||
|
|
||||||
try:
|
|
||||||
reader, writer = ctx.streams
|
|
||||||
writer.write(json.dumps(req_list).encode("utf-8") + b"\n")
|
|
||||||
await asyncio.wait_for(writer.drain(), timeout=5)
|
|
||||||
|
|
||||||
res = await asyncio.wait_for(reader.readline(), timeout=5)
|
|
||||||
|
|
||||||
if res == b"":
|
|
||||||
writer.close()
|
|
||||||
ctx.streams = None
|
|
||||||
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
|
|
||||||
raise RequestFailedError("Connection closed")
|
|
||||||
|
|
||||||
if ctx.connection_status == ConnectionStatus.TENTATIVE:
|
|
||||||
ctx.connection_status = ConnectionStatus.CONNECTED
|
|
||||||
|
|
||||||
ret = json.loads(res.decode("utf-8"))
|
|
||||||
for response in ret:
|
|
||||||
if response["type"] == "ERROR":
|
|
||||||
raise ConnectorError(response["err"])
|
|
||||||
|
|
||||||
return ret
|
|
||||||
except asyncio.TimeoutError as exc:
|
|
||||||
writer.close()
|
|
||||||
ctx.streams = None
|
|
||||||
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
|
|
||||||
raise RequestFailedError("Connection timed out") from exc
|
|
||||||
except ConnectionResetError as exc:
|
|
||||||
writer.close()
|
|
||||||
ctx.streams = None
|
|
||||||
ctx.connection_status = ConnectionStatus.NOT_CONNECTED
|
|
||||||
raise RequestFailedError("Connection reset") from exc
|
|
||||||
|
|
||||||
|
|
||||||
async def ping(ctx: BizHawkContext) -> None:
|
async def ping(ctx: BizHawkContext) -> None:
|
||||||
|
|
|
@ -13,8 +13,8 @@ from CommonClient import CommonContext, ClientCommandProcessor, get_base_parser,
|
||||||
import Patch
|
import Patch
|
||||||
import Utils
|
import Utils
|
||||||
|
|
||||||
from . import BizHawkContext, ConnectionStatus, RequestFailedError, connect, disconnect, get_hash, get_script_version, \
|
from . import BizHawkContext, ConnectionStatus, NotConnectedError, RequestFailedError, connect, disconnect, get_hash, \
|
||||||
get_system, ping
|
get_script_version, get_system, ping
|
||||||
from .client import BizHawkClient, AutoBizHawkClientRegister
|
from .client import BizHawkClient, AutoBizHawkClientRegister
|
||||||
|
|
||||||
|
|
||||||
|
@ -133,6 +133,8 @@ async def _game_watcher(ctx: BizHawkClientContext):
|
||||||
except RequestFailedError as exc:
|
except RequestFailedError as exc:
|
||||||
logger.info(f"Lost connection to BizHawk: {exc.args[0]}")
|
logger.info(f"Lost connection to BizHawk: {exc.args[0]}")
|
||||||
continue
|
continue
|
||||||
|
except NotConnectedError:
|
||||||
|
continue
|
||||||
|
|
||||||
# Get slot name and send `Connect`
|
# Get slot name and send `Connect`
|
||||||
if ctx.server is not None and ctx.username is None:
|
if ctx.server is not None and ctx.username is None:
|
||||||
|
|
Loading…
Reference in New Issue