from __future__ import annotations import sys import threading import time import multiprocessing import os import subprocess import base64 import logging import asyncio import enum import typing from json import loads, dumps # CommonClient import first to trigger ModuleUpdater from CommonClient import CommonContext, server_loop, ClientCommandProcessor, gui_enabled, get_base_parser import Utils from Utils import async_start from MultiServer import mark_raw if typing.TYPE_CHECKING: from worlds.AutoSNIClient import SNIClient if __name__ == "__main__": Utils.init_logging("SNIClient", exception_logger="Client") import colorama from websockets.client import connect as websockets_connect, WebSocketClientProtocol from websockets.exceptions import WebSocketException, ConnectionClosed snes_logger = logging.getLogger("SNES") class DeathState(enum.IntEnum): killing_player = 1 alive = 2 dead = 3 class SNIClientCommandProcessor(ClientCommandProcessor): ctx: SNIContext def _cmd_slow_mode(self, toggle: str = "") -> None: """Toggle slow mode, which limits how fast you send / receive items.""" if toggle: self.ctx.slow_mode = toggle.lower() in {"1", "true", "on"} else: self.ctx.slow_mode = not self.ctx.slow_mode self.output(f"Setting slow mode to {self.ctx.slow_mode}") @mark_raw def _cmd_snes(self, snes_options: str = "") -> bool: """Connect to a snes. Optionally include network address of a snes to connect to, otherwise show available devices; and a SNES device number if more than one SNES is detected. Examples: "/snes", "/snes 1", "/snes localhost:23074 1" """ if self.ctx.snes_state in {SNESState.SNES_ATTACHED, SNESState.SNES_CONNECTED, SNESState.SNES_CONNECTING}: self.output("Already connected to SNES. Disconnecting first.") self._cmd_snes_close() return self.connect_to_snes(snes_options) def connect_to_snes(self, snes_options: str = "") -> bool: snes_address = self.ctx.snes_address snes_device_number = -1 options = snes_options.split() num_options = len(options) if num_options > 1: snes_address = options[0] snes_device_number = int(options[1]) elif num_options > 0: snes_device_number = int(options[0]) self.ctx.snes_reconnect_address = None if self.ctx.snes_connect_task: self.ctx.snes_connect_task.cancel() self.ctx.snes_connect_task = asyncio.create_task(snes_connect(self.ctx, snes_address, snes_device_number), name="SNES Connect") return True def _cmd_snes_close(self) -> bool: """Close connection to a currently connected snes""" self.ctx.snes_reconnect_address = None self.ctx.cancel_snes_autoreconnect() if self.ctx.snes_socket and not self.ctx.snes_socket.closed: async_start(self.ctx.snes_socket.close()) return True else: return False # Left here for quick re-addition for debugging. # def _cmd_snes_write(self, address, data): # """Write the specified byte (base10) to the SNES' memory address (base16).""" # if self.ctx.snes_state != SNESState.SNES_ATTACHED: # self.output("No attached SNES Device.") # return False # snes_buffered_write(self.ctx, int(address, 16), bytes([int(data)])) # async_start(snes_flush_writes(self.ctx)) # self.output("Data Sent") # return True # def _cmd_snes_read(self, address, size=1): # """Read the SNES' memory address (base16).""" # if self.ctx.snes_state != SNESState.SNES_ATTACHED: # self.output("No attached SNES Device.") # return False # data = await snes_read(self.ctx, int(address, 16), size) # self.output(f"Data Read: {data}") # return True class SNIContext(CommonContext): command_processor: typing.Type[SNIClientCommandProcessor] = SNIClientCommandProcessor game: typing.Optional[str] = None # set in validate_rom items_handling: typing.Optional[int] = None # set in game_watcher snes_connect_task: "typing.Optional[asyncio.Task[None]]" = None snes_autoreconnect_task: typing.Optional["asyncio.Task[None]"] = None snes_address: str snes_socket: typing.Optional[WebSocketClientProtocol] snes_state: SNESState snes_attached_device: typing.Optional[typing.Tuple[int, str]] snes_reconnect_address: typing.Optional[str] snes_recv_queue: "asyncio.Queue[bytes]" snes_request_lock: asyncio.Lock snes_write_buffer: typing.List[typing.Tuple[int, bytes]] snes_connector_lock: threading.Lock death_state: DeathState killing_player_task: "typing.Optional[asyncio.Task[None]]" allow_collect: bool slow_mode: bool client_handler: typing.Optional[SNIClient] awaiting_rom: bool rom: typing.Optional[bytes] prev_rom: typing.Optional[bytes] hud_message_queue: typing.List[str] # TODO: str is a guess, is this right? death_link_allow_survive: bool def __init__(self, snes_address: str, server_address: str, password: str) -> None: super(SNIContext, self).__init__(server_address, password) # snes stuff self.snes_address = snes_address self.snes_socket = None self.snes_state = SNESState.SNES_DISCONNECTED self.snes_attached_device = None self.snes_reconnect_address = None self.snes_recv_queue = asyncio.Queue() self.snes_request_lock = asyncio.Lock() self.snes_write_buffer = [] self.snes_connector_lock = threading.Lock() self.death_state = DeathState.alive # for death link flop behaviour self.killing_player_task = None self.allow_collect = False self.slow_mode = False self.client_handler = None self.awaiting_rom = False self.rom = None self.prev_rom = None async def connection_closed(self) -> None: await super(SNIContext, self).connection_closed() self.awaiting_rom = False def event_invalid_slot(self) -> typing.NoReturn: if self.snes_socket is not None and not self.snes_socket.closed: async_start(self.snes_socket.close()) raise Exception("Invalid ROM detected, " "please verify that you have loaded the correct rom and reconnect your snes (/snes)") async def server_auth(self, password_requested: bool = False) -> None: if password_requested and not self.password: await super(SNIContext, self).server_auth(password_requested) if self.rom is None: self.awaiting_rom = True snes_logger.info( "No ROM detected, awaiting snes connection to authenticate to the multiworld server (/snes)") return self.awaiting_rom = False # TODO: This looks kind of hacky... # Context.auth is meant to be the "name" parameter in send_connect, # which has to be a str (bytes is not json serializable). # But here, Context.auth is being used for something else # (where it has to be bytes because it is compared with rom elsewhere). # If we need to save something to compare with rom elsewhere, # it should probably be in a different variable, # and let auth be used for what it's meant for. self.auth = self.rom auth = base64.b64encode(self.rom).decode() await self.send_connect(name=auth) def cancel_snes_autoreconnect(self) -> bool: if self.snes_autoreconnect_task: self.snes_autoreconnect_task.cancel() self.snes_autoreconnect_task = None return True return False def on_deathlink(self, data: typing.Dict[str, typing.Any]) -> None: if not self.killing_player_task or self.killing_player_task.done(): self.killing_player_task = asyncio.create_task(deathlink_kill_player(self)) super(SNIContext, self).on_deathlink(data) async def handle_deathlink_state(self, currently_dead: bool, death_text: str = "") -> None: # in this state we only care about triggering a death send if self.death_state == DeathState.alive: if currently_dead: self.death_state = DeathState.dead await self.send_death(death_text) # in this state we care about confirming a kill, to move state to dead elif self.death_state == DeathState.killing_player: # this is being handled in deathlink_kill_player(ctx) already pass # in this state we wait until the player is alive again elif self.death_state == DeathState.dead: if not currently_dead: self.death_state = DeathState.alive async def shutdown(self) -> None: await super(SNIContext, self).shutdown() self.cancel_snes_autoreconnect() if self.snes_connect_task: try: await asyncio.wait_for(self.snes_connect_task, 1) except asyncio.TimeoutError: self.snes_connect_task.cancel() def on_package(self, cmd: str, args: typing.Dict[str, typing.Any]) -> None: if cmd in {"Connected", "RoomUpdate"}: if "checked_locations" in args and args["checked_locations"]: new_locations = set(args["checked_locations"]) self.checked_locations |= new_locations self.locations_scouted |= new_locations # Items belonging to the player should not be marked as checked in game, # since the player will likely need that item. # Once the games handled by SNIClient gets made to be remote items, # this will no longer be needed. async_start(self.send_msgs([{"cmd": "LocationScouts", "locations": list(new_locations)}])) def run_gui(self) -> None: from kvui import GameManager class SNIManager(GameManager): logging_pairs = [ ("Client", "Archipelago"), ("SNES", "SNES"), ] base_title = "Archipelago SNI Client" self.ui = SNIManager(self) self.ui_task = asyncio.create_task(self.ui.async_run(), name="UI") # type: ignore async def deathlink_kill_player(ctx: SNIContext) -> None: ctx.death_state = DeathState.killing_player while ctx.death_state == DeathState.killing_player and \ ctx.snes_state == SNESState.SNES_ATTACHED: if ctx.client_handler is None: continue await ctx.client_handler.deathlink_kill_player(ctx) ctx.last_death_link = time.time() _global_snes_reconnect_delay = 5 class SNESState(enum.IntEnum): SNES_DISCONNECTED = 0 SNES_CONNECTING = 1 SNES_CONNECTED = 2 SNES_ATTACHED = 3 def launch_sni() -> None: sni_path = Utils.get_options()["sni_options"]["sni_path"] if not os.path.isdir(sni_path): sni_path = Utils.local_path(sni_path) if os.path.isdir(sni_path): dir_entry: "os.DirEntry[str]" for dir_entry in os.scandir(sni_path): if dir_entry.is_file(): lower_file = dir_entry.name.lower() if (lower_file.startswith("sni.") and not lower_file.endswith(".proto")) or (lower_file == "sni"): sni_path = dir_entry.path break if os.path.isfile(sni_path): snes_logger.info(f"Attempting to start {sni_path}") import sys if not sys.stdout: # if it spawns a visible console, may as well populate it subprocess.Popen(os.path.abspath(sni_path), cwd=os.path.dirname(sni_path)) else: proc = subprocess.Popen(os.path.abspath(sni_path), cwd=os.path.dirname(sni_path), stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) try: proc.wait(.1) # wait a bit to see if startup fails (missing dependencies) snes_logger.info('Failed to start SNI. Try running it externally for error output.') except subprocess.TimeoutExpired: pass # seems to be running else: snes_logger.info( f"Attempt to start SNI was aborted as path {sni_path} was not found, " f"please start it yourself if it is not running") async def _snes_connect(ctx: SNIContext, address: str, retry: bool = True) -> WebSocketClientProtocol: address = f"ws://{address}" if "://" not in address else address snes_logger.info("Connecting to SNI at %s ..." % address) seen_problems: typing.Set[str] = set() while True: try: snes_socket = await websockets_connect(address, ping_timeout=None, ping_interval=None) except Exception as e: problem = "%s" % e # only tell the user about new problems, otherwise silently lay in wait for a working connection if problem not in seen_problems: seen_problems.add(problem) snes_logger.error(f"Error connecting to SNI ({problem})") if len(seen_problems) == 1: # this is the first problem. Let's try launching SNI if it isn't already running launch_sni() await asyncio.sleep(1) else: return snes_socket if not retry: break class SNESRequest(typing.TypedDict): Opcode: str Space: str Operands: typing.List[str] # TODO: When Python 3.11 is the lowest version supported, `Operands` can use `typing.NotRequired` (pep-0655) # Then the `Operands` key doesn't need to be given for opcodes that don't use it. async def get_snes_devices(ctx: SNIContext) -> typing.List[str]: socket = await _snes_connect(ctx, ctx.snes_address) # establish new connection to poll DeviceList_Request: SNESRequest = { "Opcode": "DeviceList", "Space": "SNES", "Operands": [] } await socket.send(dumps(DeviceList_Request)) reply: typing.Dict[str, typing.Any] = loads(await socket.recv()) devices: typing.List[str] = reply['Results'] if 'Results' in reply and len(reply['Results']) > 0 else [] if not devices: snes_logger.info('No SNES device found. Please connect a SNES device to SNI.') while not devices and not ctx.exit_event.is_set(): await asyncio.sleep(0.1) await socket.send(dumps(DeviceList_Request)) reply = loads(await socket.recv()) devices = reply['Results'] if 'Results' in reply and len(reply['Results']) > 0 else [] if devices: await verify_snes_app(socket) await socket.close() return sorted(devices) async def verify_snes_app(socket: WebSocketClientProtocol) -> None: AppVersion_Request = { "Opcode": "AppVersion", } await socket.send(dumps(AppVersion_Request)) app: str = loads(await socket.recv())["Results"][0] if "SNI" not in app: snes_logger.warning(f"Warning: Did not find SNI as the endpoint, instead {app} was found.") async def snes_connect(ctx: SNIContext, address: str, deviceIndex: int = -1) -> None: global _global_snes_reconnect_delay if ctx.snes_socket is not None and ctx.snes_state == SNESState.SNES_CONNECTED: if ctx.rom: snes_logger.error('Already connected to SNES, with rom loaded.') else: snes_logger.error('Already connected to SNI, likely awaiting a device.') return ctx.cancel_snes_autoreconnect() device = None recv_task = None ctx.snes_state = SNESState.SNES_CONNECTING socket = await _snes_connect(ctx, address) ctx.snes_socket = socket ctx.snes_state = SNESState.SNES_CONNECTED try: devices = await get_snes_devices(ctx) device_count = len(devices) if device_count == 1: device = devices[0] elif ctx.snes_reconnect_address: assert ctx.snes_attached_device if ctx.snes_attached_device[1] in devices: device = ctx.snes_attached_device[1] else: device = devices[ctx.snes_attached_device[0]] elif device_count > 1: if deviceIndex == -1: snes_logger.info(f"Found {device_count} SNES devices. " f"Connect to one with /snes