diff --git a/CommonClient.py b/CommonClient.py index 3752a129..e157b31c 100644 --- a/CommonClient.py +++ b/CommonClient.py @@ -24,6 +24,9 @@ from Utils import Version, stream_input, async_start from worlds import network_data_package, AutoWorldRegister import os +if typing.TYPE_CHECKING: + import kvui + logger = logging.getLogger("Client") # without terminal, we have to use gui mode @@ -44,15 +47,17 @@ class ClientCommandProcessor(CommandProcessor): def _cmd_connect(self, address: str = "") -> bool: """Connect to a MultiWorld Server""" - self.ctx.server_address = None - self.ctx.username = None + if address: + self.ctx.server_address = None + self.ctx.username = None + elif not self.ctx.server_address: + self.output("Please specify an address.") + return False async_start(self.ctx.connect(address if address else None), name="connecting") return True def _cmd_disconnect(self) -> bool: """Disconnect from a MultiWorld Server""" - self.ctx.server_address = None - self.ctx.username = None async_start(self.ctx.disconnect(), name="disconnecting") return True @@ -144,6 +149,8 @@ class CommonContext: input_task: typing.Optional["asyncio.Task[None]"] = None keep_alive_task: typing.Optional["asyncio.Task[None]"] = None server_task: typing.Optional["asyncio.Task[None]"] = None + autoreconnect_task: typing.Optional["asyncio.Task[None]"] = None + disconnected_intentionally: bool = False server: typing.Optional[Endpoint] = None server_version: Version = Version(0, 0, 0) current_energy_link_value: int = 0 # to display in UI, gets set by server @@ -173,7 +180,9 @@ class CommonContext: # internals # current message box through kvui - _messagebox = None + _messagebox: typing.Optional["kvui.MessageBox"] = None + # message box reporting a loss of connection + _messagebox_connection_loss: typing.Optional["kvui.MessageBox"] = None def __init__(self, server_address: typing.Optional[str], password: typing.Optional[str]) -> None: # server state @@ -255,7 +264,11 @@ class CommonContext: "remaining": "disabled", } - async def disconnect(self): + async def disconnect(self, allow_autoreconnect: bool = False): + if not allow_autoreconnect: + self.disconnected_intentionally = True + if self.cancel_autoreconnect(): + logger.info("Cancelled auto-reconnect.") if self.server and not self.server.socket.closed: await self.server.socket.close() if self.server_task is not None: @@ -313,6 +326,13 @@ class CommonContext: await self.disconnect() self.server_task = asyncio.create_task(server_loop(self, address), name="server loop") + def cancel_autoreconnect(self) -> bool: + if self.autoreconnect_task: + self.autoreconnect_task.cancel() + self.autoreconnect_task = None + return True + return False + def slot_concerns_self(self, slot) -> bool: if slot == self.slot: return True @@ -357,6 +377,7 @@ class CommonContext: async def shutdown(self): self.server_address = "" self.username = None + self.cancel_autoreconnect() if self.server and not self.server.socket.closed: await self.server.socket.close() if self.server_task: @@ -450,10 +471,10 @@ class CommonContext: if old_tags != self.tags and self.server and not self.server.socket.closed: await self.send_msgs([{"cmd": "ConnectUpdate", "tags": self.tags}]) - def gui_error(self, title: str, text: typing.Union[Exception, str]): + def gui_error(self, title: str, text: typing.Union[Exception, str]) -> typing.Optional["kvui.MessageBox"]: """Displays an error messagebox""" if not self.ui: - return + return None title = title or "Error" from kvui import MessageBox if self._messagebox: @@ -470,6 +491,13 @@ class CommonContext: # display error self._messagebox = MessageBox(title, text, error=True) self._messagebox.open() + return self._messagebox + + def _handle_connection_loss(self, msg: str) -> None: + """Helper for logging and displaying a loss of connection. Must be called from an except block.""" + exc_info = sys.exc_info() + logger.exception(msg, exc_info=exc_info, extra={'compact_gui': True}) + self._messagebox_connection_loss = self.gui_error(msg, exc_info[1]) def run_gui(self): """Import kivy UI system and start running it as self.ui_task.""" @@ -519,6 +547,11 @@ async def server_loop(ctx: CommonContext, address: typing.Optional[str] = None) logger.info('Please connect to an Archipelago server.') return + ctx.cancel_autoreconnect() + if ctx._messagebox_connection_loss: + ctx._messagebox_connection_loss.dismiss() + ctx._messagebox_connection_loss = None + address = f"ws://{address}" if "://" not in address \ else address.replace("archipelago://", "ws://") @@ -529,6 +562,9 @@ async def server_loop(ctx: CommonContext, address: typing.Optional[str] = None) ctx.password = server_url.password port = server_url.port or 38281 + def reconnect_hint() -> str: + return ", type /connect to reconnect" if ctx.server_address else "" + logger.info(f'Connecting to Archipelago server at {address}') try: socket = await websockets.connect(address, port=port, ping_timeout=None, ping_interval=None) @@ -538,31 +574,25 @@ async def server_loop(ctx: CommonContext, address: typing.Optional[str] = None) logger.info('Connected') ctx.server_address = address ctx.current_reconnect_delay = ctx.starting_reconnect_delay + ctx.disconnected_intentionally = False async for data in ctx.server.socket: for msg in decode(data): await process_server_cmd(ctx, msg) - logger.warning('Disconnected from multiworld server, type /connect to reconnect') - except ConnectionRefusedError as e: - msg = 'Connection refused by the server. May not be running Archipelago on that address or port.' - logger.exception(msg, extra={'compact_gui': True}) - ctx.gui_error(msg, e) - except websockets.InvalidURI as e: - msg = 'Failed to connect to the multiworld server (invalid URI)' - logger.exception(msg, extra={'compact_gui': True}) - ctx.gui_error(msg, e) - except OSError as e: - msg = 'Failed to connect to the multiworld server' - logger.exception(msg, extra={'compact_gui': True}) - ctx.gui_error(msg, e) - except Exception as e: - msg = 'Lost connection to the multiworld server, type /connect to reconnect' - logger.exception(msg, extra={'compact_gui': True}) - ctx.gui_error(msg, e) + logger.warning(f"Disconnected from multiworld server{reconnect_hint()}") + except ConnectionRefusedError: + ctx._handle_connection_loss("Connection refused by the server. May not be running Archipelago on that address or port.") + except websockets.InvalidURI: + ctx._handle_connection_loss("Failed to connect to the multiworld server (invalid URI)") + except OSError: + ctx._handle_connection_loss("Failed to connect to the multiworld server") + except Exception: + ctx._handle_connection_loss(f"Lost connection to the multiworld server{reconnect_hint()}") finally: await ctx.connection_closed() - if ctx.server_address: - logger.info(f"... reconnecting in {ctx.current_reconnect_delay}s") - async_start(server_autoreconnect(ctx), name="server auto reconnect") + if ctx.server_address and ctx.username and not ctx.disconnected_intentionally: + logger.info(f"... automatically reconnecting in {ctx.current_reconnect_delay} seconds") + assert ctx.autoreconnect_task is None + ctx.autoreconnect_task = asyncio.create_task(server_autoreconnect(ctx), name="server auto reconnect") ctx.current_reconnect_delay *= 2 diff --git a/SNIClient.py b/SNIClient.py index b89c588d..623bc175 100644 --- a/SNIClient.py +++ b/SNIClient.py @@ -83,6 +83,7 @@ class SNIClientCommandProcessor(ClientCommandProcessor): def _cmd_snes_close(self) -> bool: """Close connection to a currently connected snes""" self.ctx.snes_reconnect_address = None + self.ctx.cancel_snes_autoreconnect() if self.ctx.snes_socket is not None and not self.ctx.snes_socket.closed: async_start(self.ctx.snes_socket.close()) return True @@ -115,6 +116,7 @@ class SNIContext(CommonContext): game = None # set in validate_rom items_handling = None # set in game_watcher snes_connect_task: "typing.Optional[asyncio.Task[None]]" = None + snes_autoreconnect_task: typing.Optional["asyncio.Task[None]"] = None snes_address: str snes_socket: typing.Optional[WebSocketClientProtocol] @@ -192,6 +194,13 @@ class SNIContext(CommonContext): auth = base64.b64encode(self.rom).decode() await self.send_connect(name=auth) + def cancel_snes_autoreconnect(self) -> bool: + if self.snes_autoreconnect_task: + self.snes_autoreconnect_task.cancel() + self.snes_autoreconnect_task = None + return True + return False + def on_deathlink(self, data: typing.Dict[str, typing.Any]) -> None: if not self.killing_player_task or self.killing_player_task.done(): self.killing_player_task = asyncio.create_task(deathlink_kill_player(self)) @@ -214,6 +223,7 @@ class SNIContext(CommonContext): async def shutdown(self) -> None: await super(SNIContext, self).shutdown() + self.cancel_snes_autoreconnect() if self.snes_connect_task: try: await asyncio.wait_for(self.snes_connect_task, 1) @@ -379,6 +389,8 @@ async def snes_connect(ctx: SNIContext, address: str, deviceIndex: int = -1) -> snes_logger.error('Already connected to SNI, likely awaiting a device.') return + ctx.cancel_snes_autoreconnect() + device = None recv_task = None ctx.snes_state = SNESState.SNES_CONNECTING @@ -442,8 +454,9 @@ async def snes_connect(ctx: SNIContext, address: str, deviceIndex: int = -1) -> if not ctx.snes_reconnect_address: snes_logger.error("Error connecting to snes (%s)" % e) else: - snes_logger.error(f"Error connecting to snes, attempt again in {_global_snes_reconnect_delay}s") - async_start(snes_autoreconnect(ctx)) + snes_logger.error(f"Error connecting to snes, retrying in {_global_snes_reconnect_delay} seconds") + assert ctx.snes_autoreconnect_task is None + ctx.snes_autoreconnect_task = asyncio.create_task(snes_autoreconnect(ctx), name="snes auto-reconnect") _global_snes_reconnect_delay *= 2 else: @@ -460,8 +473,8 @@ async def snes_disconnect(ctx: SNIContext) -> None: async def snes_autoreconnect(ctx: SNIContext) -> None: await asyncio.sleep(_global_snes_reconnect_delay) - if ctx.snes_reconnect_address and ctx.snes_socket is None: - await snes_connect(ctx, ctx.snes_reconnect_address) + if ctx.snes_reconnect_address and not ctx.snes_socket and not ctx.snes_connect_task: + ctx.snes_connect_task = asyncio.create_task(snes_connect(ctx, ctx.snes_reconnect_address), name="SNES Connect") async def snes_recv_loop(ctx: SNIContext) -> None: @@ -487,8 +500,9 @@ async def snes_recv_loop(ctx: SNIContext) -> None: ctx.rom = None if ctx.snes_reconnect_address: - snes_logger.info(f"...reconnecting in {_global_snes_reconnect_delay}s") - async_start(snes_autoreconnect(ctx)) + snes_logger.info(f"... automatically reconnecting to snes in {_global_snes_reconnect_delay} seconds") + assert ctx.snes_autoreconnect_task is None + ctx.snes_autoreconnect_task = asyncio.create_task(snes_autoreconnect(ctx), name="snes auto-reconnect") async def snes_read(ctx: SNIContext, address: int, size: int) -> typing.Optional[bytes]: @@ -619,7 +633,7 @@ async def game_watcher(ctx: SNIContext) -> None: if not rom_validated or (ctx.auth and ctx.auth != ctx.rom): snes_logger.warning("ROM change detected, please reconnect to the multiworld server") - await ctx.disconnect() + await ctx.disconnect(allow_autoreconnect=True) ctx.client_handler = None ctx.rom = None ctx.command_processor(ctx).connect_to_snes() diff --git a/kvui.py b/kvui.py index 071c07e6..d4cf09ca 100644 --- a/kvui.py +++ b/kvui.py @@ -426,7 +426,6 @@ class GameManager(App): def connect_button_action(self, button): if self.ctx.server: - self.ctx.server_address = None self.ctx.username = None async_start(self.ctx.disconnect()) else: