diff --git a/worlds/_bizhawk/context.py b/worlds/_bizhawk/context.py index 5d865f33..ccf747f1 100644 --- a/worlds/_bizhawk/context.py +++ b/worlds/_bizhawk/context.py @@ -5,6 +5,7 @@ checking or launching the client, otherwise it will probably cause circular impo import asyncio +import enum import subprocess import traceback from typing import Any, Dict, Optional @@ -21,6 +22,13 @@ from .client import BizHawkClient, AutoBizHawkClientRegister EXPECTED_SCRIPT_VERSION = 1 +class AuthStatus(enum.IntEnum): + NOT_AUTHENTICATED = 0 + NEED_INFO = 1 + PENDING = 2 + AUTHENTICATED = 3 + + class BizHawkClientCommandProcessor(ClientCommandProcessor): def _cmd_bh(self): """Shows the current status of the client's connection to BizHawk""" @@ -35,6 +43,8 @@ class BizHawkClientCommandProcessor(ClientCommandProcessor): class BizHawkClientContext(CommonContext): command_processor = BizHawkClientCommandProcessor + auth_status: AuthStatus + password_requested: bool client_handler: Optional[BizHawkClient] slot_data: Optional[Dict[str, Any]] = None rom_hash: Optional[str] = None @@ -45,6 +55,8 @@ class BizHawkClientContext(CommonContext): def __init__(self, server_address: Optional[str], password: Optional[str]): super().__init__(server_address, password) + self.auth_status = AuthStatus.NOT_AUTHENTICATED + self.password_requested = False self.client_handler = None self.bizhawk_ctx = BizHawkContext() self.watcher_timeout = 0.5 @@ -61,10 +73,41 @@ class BizHawkClientContext(CommonContext): def on_package(self, cmd, args): if cmd == "Connected": self.slot_data = args.get("slot_data", None) + self.auth_status = AuthStatus.AUTHENTICATED if self.client_handler is not None: self.client_handler.on_package(self, cmd, args) + async def server_auth(self, password_requested: bool = False): + self.password_requested = password_requested + + if self.bizhawk_ctx.connection_status != ConnectionStatus.CONNECTED: + logger.info("Awaiting connection to BizHawk before authenticating") + return + + if self.client_handler is None: + return + + # Ask handler to set auth + if self.auth is None: + self.auth_status = AuthStatus.NEED_INFO + await self.client_handler.set_auth(self) + + # Handler didn't set auth, ask user for slot name + if self.auth is None: + await self.get_username() + + if password_requested and not self.password: + self.auth_status = AuthStatus.NEED_INFO + await super(BizHawkClientContext, self).server_auth(password_requested) + + await self.send_connect() + self.auth_status = AuthStatus.PENDING + + async def disconnect(self, allow_autoreconnect: bool = False): + self.auth_status = AuthStatus.NOT_AUTHENTICATED + await super().disconnect(allow_autoreconnect) + async def _game_watcher(ctx: BizHawkClientContext): showed_connecting_message = False @@ -109,12 +152,13 @@ async def _game_watcher(ctx: BizHawkClientContext): rom_hash = await get_hash(ctx.bizhawk_ctx) if ctx.rom_hash is not None and ctx.rom_hash != rom_hash: - if ctx.server is not None: + if ctx.server is not None and not ctx.server.socket.closed: logger.info(f"ROM changed. Disconnecting from server.") - await ctx.disconnect(True) ctx.auth = None ctx.username = None + ctx.client_handler = None + await ctx.disconnect(False) ctx.rom_hash = rom_hash if ctx.client_handler is None: @@ -136,15 +180,14 @@ async def _game_watcher(ctx: BizHawkClientContext): except NotConnectedError: continue - # Get slot name and send `Connect` - if ctx.server is not None and ctx.username is None: - await ctx.client_handler.set_auth(ctx) - - if ctx.auth is None: - await ctx.get_username() - - await ctx.send_connect() + # Server auth + if ctx.server is not None and not ctx.server.socket.closed: + if ctx.auth_status == AuthStatus.NOT_AUTHENTICATED: + Utils.async_start(ctx.server_auth(ctx.password_requested)) + else: + ctx.auth_status = AuthStatus.NOT_AUTHENTICATED + # Call the handler's game watcher await ctx.client_handler.game_watcher(ctx)