"""
A module for interacting with BizHawk through `connector_bizhawk_generic.lua`.

Any mention of `domain` in this module refers to the names BizHawk gives to memory domains in its own lua api. They are
naively passed to BizHawk without validation or modification.
"""

import asyncio
import base64
import enum
import json
import sys
import typing


BIZHAWK_SOCKET_PORT_RANGE_START = 43055
BIZHAWK_SOCKET_PORT_RANGE_SIZE = 5


class ConnectionStatus(enum.IntEnum):
    NOT_CONNECTED = 1
    # A socket is open but no request has completed over it yet
    TENTATIVE = 2
    # At least one request has received a response over the current socket
    CONNECTED = 3


class NotConnectedError(Exception):
    """Raised when something tries to make a request to the connector script before a connection has been established"""
    pass


class RequestFailedError(Exception):
    """Raised when the connector script did not respond to a request"""
    pass


class ConnectorError(Exception):
    """Raised when the connector script encounters an error while processing a request"""
    pass


class SyncError(Exception):
    """Raised when the connector script responded with a mismatched response type"""
    pass


class BizHawkContext:
    """Connection state shared by every function in this module that talks to the connector script."""

    # (reader, writer) for the open socket, or None while not connected
    streams: typing.Optional[typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter]]
    connection_status: ConnectionStatus
    # Held for a full write-then-read round trip so concurrent callers cannot receive each other's responses
    _lock: asyncio.Lock
    # Port of the most recent successful connection; `connect` starts its search there
    _port: typing.Optional[int]

    def __init__(self) -> None:
        self.streams = None
        self.connection_status = ConnectionStatus.NOT_CONNECTED
        self._lock = asyncio.Lock()
        self._port = None

    def _abandon_connection(self, writer: asyncio.StreamWriter) -> None:
        """Closes `writer` and marks this context as not connected."""
        writer.close()
        self.streams = None
        self.connection_status = ConnectionStatus.NOT_CONNECTED

    async def _send_message(self, message: str) -> str:
        """Sends one newline-terminated message and returns the single line the connector sends back.

        A TENTATIVE connection is promoted to CONNECTED once a response arrives. Any transport failure closes the
        connection before raising, so the next request cannot read a stale or partial response.

        Raises:
            NotConnectedError: if there is no open connection.
            RequestFailedError: if the request timed out, the connection was lost or closed, or the response line was
                longer than the stream reader's buffer limit.
        """
        async with self._lock:
            if self.streams is None:
                raise NotConnectedError("You tried to send a request before a connection to BizHawk was made")

            reader, writer = self.streams
            try:
                writer.write(message.encode("utf-8") + b"\n")
                await asyncio.wait_for(writer.drain(), timeout=5)

                res = await asyncio.wait_for(reader.readline(), timeout=5)
            except (asyncio.TimeoutError, TimeoutError) as exc:
                self._abandon_connection(writer)
                raise RequestFailedError("Connection timed out") from exc
            except ConnectionResetError as exc:
                self._abandon_connection(writer)
                raise RequestFailedError("Connection reset") from exc
            except ConnectionError as exc:
                # BrokenPipeError / ConnectionAbortedError: the socket is unusable just like after a reset
                self._abandon_connection(writer)
                raise RequestFailedError("Connection lost") from exc
            except ValueError as exc:
                # StreamReader.readline raises ValueError when a line exceeds the reader's buffer limit. Part of the
                # line has already been discarded, so the stream is out of sync and must not be reused.
                self._abandon_connection(writer)
                raise RequestFailedError("Response exceeded the stream buffer limit") from exc

            if res == b"":
                # readline returns b"" only at EOF: the connector closed the socket
                self._abandon_connection(writer)
                raise RequestFailedError("Connection closed")

            if self.connection_status == ConnectionStatus.TENTATIVE:
                self.connection_status = ConnectionStatus.CONNECTED

            return res.decode("utf-8")


async def connect(ctx: BizHawkContext) -> bool:
    """Attempts to establish a connection with a connector script. Returns True if successful.

    Every port in the configured range is tried once. The search starts at the port of the last successful connection
    (if any) and wraps around. On success the context is left TENTATIVE until the first request completes.
    """
    rotation_steps = 0 if ctx._port is None else ctx._port - BIZHAWK_SOCKET_PORT_RANGE_START
    ports = [*range(BIZHAWK_SOCKET_PORT_RANGE_START, BIZHAWK_SOCKET_PORT_RANGE_START + BIZHAWK_SOCKET_PORT_RANGE_SIZE)]
    ports = ports[rotation_steps:] + ports[:rotation_steps]

    for port in ports:
        try:
            ctx.streams = await asyncio.open_connection("", port)
        except OSError:
            # ConnectionRefusedError and TimeoutError are OSError subclasses. A failed attempt can also surface as a
            # plain OSError (e.g. asyncio's "Multiple exceptions" error when every resolved address fails).
            continue

        ctx.connection_status = ConnectionStatus.TENTATIVE
        ctx._port = port
        return True

    # No ports worked
    ctx.streams = None
    ctx.connection_status = ConnectionStatus.NOT_CONNECTED
    return False


def disconnect(ctx: BizHawkContext) -> None:
    """Closes the connection to the connector script."""
    if ctx.streams is not None:
        ctx.streams[1].close()
    ctx.streams = None
    ctx.connection_status = ConnectionStatus.NOT_CONNECTED


async def get_script_version(ctx: BizHawkContext) -> int:
    """Sends the raw (non-JSON) `VERSION` message and returns the connector script's reply as an int."""
    return int(await ctx._send_message("VERSION"))


async def send_requests(ctx: BizHawkContext, req_list: typing.List[typing.Dict[str, typing.Any]]) -> typing.List[typing.Dict[str, typing.Any]]:
    """Sends a list of requests to the BizHawk connector and returns their responses.

    It's likely you want to use the wrapper functions instead of this.

    Raises ConnectorError (or, on Python 3.11+, an ExceptionGroup of ConnectorErrors) if any response is of type ERROR."""
    responses = json.loads(await ctx._send_message(json.dumps(req_list)))
    errors: typing.List[ConnectorError] = []

    for response in responses:
        if response["type"] == "ERROR":
            errors.append(ConnectorError(response["err"]))

    if errors:
        if sys.version_info >= (3, 11, 0):
            raise ExceptionGroup("Connector script returned errors", errors)  # noqa
        else:
            raise errors[0]

    return responses


async def _send_single(ctx: BizHawkContext, request: typing.Dict[str, typing.Any],
                       expected_type: str) -> typing.Dict[str, typing.Any]:
    """Sends one request and returns its response, raising SyncError if the response type is not `expected_type`."""
    res = (await send_requests(ctx, [request]))[0]
    if res["type"] != expected_type:
        raise SyncError(f"Expected response of type {expected_type} but got {res['type']}")
    return res


async def ping(ctx: BizHawkContext) -> None:
    """Sends a PING request and receives a PONG response."""
    await _send_single(ctx, {"type": "PING"}, "PONG")


async def get_hash(ctx: BizHawkContext) -> str:
    """Gets the hash value of the currently loaded ROM"""
    return (await _send_single(ctx, {"type": "HASH"}, "HASH_RESPONSE"))["value"]


async def get_memory_size(ctx: BizHawkContext, domain: str) -> int:
    """Gets the size in bytes of the specified memory domain"""
    return (await _send_single(ctx, {"type": "MEMORY_SIZE", "domain": domain}, "MEMORY_SIZE_RESPONSE"))["value"]


async def get_system(ctx: BizHawkContext) -> str:
    """Gets the system name for the currently loaded ROM"""
    return (await _send_single(ctx, {"type": "SYSTEM"}, "SYSTEM_RESPONSE"))["value"]


async def get_cores(ctx: BizHawkContext) -> typing.Dict[str, str]:
    """Gets the preferred cores for systems with multiple cores. Only systems with multiple available cores have
    entries."""
    return (await _send_single(ctx, {"type": "PREFERRED_CORES"}, "PREFERRED_CORES_RESPONSE"))["value"]


async def lock(ctx: BizHawkContext) -> None:
    """Locks BizHawk in anticipation of receiving more requests this frame.

    Consider using guarded reads and writes instead of locks if possible.

    While locked, emulation will halt and the connector will block on incoming requests until an `UNLOCK` request is
    sent. Remember to unlock when you're done, or the emulator will appear to freeze.

    Sending multiple lock commands is the same as sending one."""
    await _send_single(ctx, {"type": "LOCK"}, "LOCKED")


async def unlock(ctx: BizHawkContext) -> None:
    """Unlocks BizHawk to allow it to resume emulation. See `lock` for more info.

    Sending multiple unlock commands is the same as sending one."""
    await _send_single(ctx, {"type": "UNLOCK"}, "UNLOCKED")


async def display_message(ctx: BizHawkContext, message: str) -> None:
    """Displays the provided message in BizHawk's message queue."""
    await _send_single(ctx, {"type": "DISPLAY_MESSAGE", "message": message}, "DISPLAY_MESSAGE_RESPONSE")


async def set_message_interval(ctx: BizHawkContext, value: float) -> None:
    """Sets the minimum amount of time in seconds to wait between queued messages. The default value of 0 will allow one
    new message to display per frame."""
    await _send_single(ctx, {"type": "SET_MESSAGE_INTERVAL", "value": value}, "SET_MESSAGE_INTERVAL_RESPONSE")


def _guard_requests(guard_list: typing.Sequence[typing.Tuple[int, typing.Sequence[int], str]]) -> typing.List[typing.Dict[str, typing.Any]]:
    """Builds the GUARD request dicts for `guard_list`, base64-encoding each expected byte sequence."""
    return [{
        "type": "GUARD",
        "address": address,
        "expected_data": base64.b64encode(bytes(expected_data)).decode("ascii"),
        "domain": domain
    } for address, expected_data, domain in guard_list]


async def guarded_read(ctx: BizHawkContext, read_list: typing.Sequence[typing.Tuple[int, int, str]],
                       guard_list: typing.Sequence[typing.Tuple[int, typing.Sequence[int], str]]) -> typing.Optional[typing.List[bytes]]:
    """Reads an array of bytes at 1 or more addresses if and only if every byte in guard_list matches its expected
    value.

    Items in read_list should be organized (address, size, domain) where
    - `address` is the address of the first byte of data
    - `size` is the number of bytes to read
    - `domain` is the name of the region of memory the address corresponds to

    Items in `guard_list` should be organized `(address, expected_data, domain)` where
    - `address` is the address of the first byte of data
    - `expected_data` is the bytes that the data starting at this address is expected to match
    - `domain` is the name of the region of memory the address corresponds to

    Returns None if any item in guard_list failed to validate. Otherwise returns a list of bytes in the order they were
    requested."""
    res = await send_requests(ctx, _guard_requests(guard_list) + [{
        "type": "READ",
        "address": address,
        "size": size,
        "domain": domain
    } for address, size, domain in read_list])

    ret: typing.List[bytes] = []
    for item in res:
        if item["type"] == "GUARD_RESPONSE":
            if not item["value"]:
                return None
        else:
            if item["type"] != "READ_RESPONSE":
                raise SyncError(f"Expected response of type READ_RESPONSE or GUARD_RESPONSE but got {item['type']}")

            ret.append(base64.b64decode(item["value"]))

    return ret


async def read(ctx: BizHawkContext, read_list: typing.Sequence[typing.Tuple[int, int, str]]) -> typing.List[bytes]:
    """Reads data at 1 or more addresses.

    Items in `read_list` should be organized `(address, size, domain)` where
    - `address` is the address of the first byte of data
    - `size` is the number of bytes to read
    - `domain` is the name of the region of memory the address corresponds to

    Returns a list of bytes in the order they were requested."""
    # With no guards, guarded_read cannot return None
    return await guarded_read(ctx, read_list, [])


async def guarded_write(ctx: BizHawkContext, write_list: typing.Sequence[typing.Tuple[int, typing.Sequence[int], str]],
                        guard_list: typing.Sequence[typing.Tuple[int, typing.Sequence[int], str]]) -> bool:
    """Writes data to 1 or more addresses if and only if every byte in guard_list matches its expected value.

    Items in `write_list` should be organized `(address, value, domain)` where
    - `address` is the address of the first byte of data
    - `value` is a list of bytes to write, in order, starting at `address`
    - `domain` is the name of the region of memory the address corresponds to

    Items in `guard_list` should be organized `(address, expected_data, domain)` where
    - `address` is the address of the first byte of data
    - `expected_data` is the bytes that the data starting at this address is expected to match
    - `domain` is the name of the region of memory the address corresponds to

    Returns False if any item in guard_list failed to validate. Otherwise returns True."""
    res = await send_requests(ctx, _guard_requests(guard_list) + [{
        "type": "WRITE",
        "address": address,
        "value": base64.b64encode(bytes(value)).decode("ascii"),
        "domain": domain
    } for address, value, domain in write_list])

    for item in res:
        if item["type"] == "GUARD_RESPONSE":
            if not item["value"]:
                return False
        else:
            if item["type"] != "WRITE_RESPONSE":
                raise SyncError(f"Expected response of type WRITE_RESPONSE or GUARD_RESPONSE but got {item['type']}")

    return True


async def write(ctx: BizHawkContext, write_list: typing.Sequence[typing.Tuple[int, typing.Sequence[int], str]]) -> None:
    """Writes data to 1 or more addresses.

    Items in write_list should be organized `(address, value, domain)` where
    - `address` is the address of the first byte of data
    - `value` is a list of bytes to write, in order, starting at `address`
    - `domain` is the name of the region of memory the address corresponds to"""
    await guarded_write(ctx, write_list, [])