# pylint: disable=W0212
from __future__ import annotations

import asyncio
import json
import platform
import signal
from contextlib import suppress
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import mpyq
import portpicker
from aiohttp import ClientSession, ClientWebSocketResponse
from worlds._sc2common.bot import logger
from s2clientprotocol import sc2api_pb2 as sc_pb

from .bot_ai import BotAI
from .client import Client
from .controller import Controller
from .data import CreateGameError, Result, Status
from .game_state import GameState
from .maps import Map
from .player import AbstractPlayer, Bot, BotProcess, Human
from .portconfig import Portconfig
from .protocol import ConnectionAlreadyClosed, ProtocolError
from .proxy import Proxy
from .sc2process import SC2Process, kill_switch

# Game loops per in-game second; used to convert game_loop counts into the seconds that game_time_limit is given in.
_GAME_LOOPS_PER_SECOND = 22.4


@dataclass
class GameMatch:
    """Dataclass for hosting a match of SC2.
    This contains all of the needed information for RequestCreateGame.

    :param sc2_config: dicts of arguments to unpack into sc2process's construction, one per player.
        The list is cycled/truncated to match the number of players;
        the second sc2_config will be ignored if only one sc2_instance is spawned,
        e.g. sc2_config=[{"fullscreen": True}, {}]: only player 1's sc2instance will be fullscreen
    :param game_time_limit: The time (in seconds) until a match is artificially declared a Tie
    """

    map_sc2: Map
    players: List[AbstractPlayer]
    realtime: bool = False
    random_seed: Optional[int] = None
    disable_fog: Optional[bool] = None
    sc2_config: Optional[List[Dict]] = None
    game_time_limit: Optional[int] = None

    def __post_init__(self):
        # avoid players sharing names
        if len(self.players) > 1 and self.players[0].name is not None and self.players[0].name == self.players[1].name:
            self.players[1].name += "2"

        if self.sc2_config is not None:
            if isinstance(self.sc2_config, dict):
                configs = [self.sc2_config]
            else:
                # Work on a copy: extending with `+=` below must not mutate the list the caller passed in.
                configs = list(self.sc2_config)
            if not configs:
                configs = [{}]
            # Repeat the configs until there is at least one per player, then trim to exactly one per player.
            while len(configs) < len(self.players):
                configs += configs
            self.sc2_config = configs[:len(self.players)]

    @property
    def needed_sc2_count(self) -> int:
        """Number of players that need their own SC2 instance."""
        return sum(player.needs_sc2 for player in self.players)

    @property
    def host_game_kwargs(self) -> Dict:
        """Keyword arguments for `_setup_host_game` (everything except the controller)."""
        return {
            "map_settings": self.map_sc2,
            "players": self.players,
            "realtime": self.realtime,
            "random_seed": self.random_seed,
            "disable_fog": self.disable_fog,
        }

    def __repr__(self):
        # Works for any number of players; output for two players matches "Map: X, p1 vs p2, ...".
        names = [p.name if p.name else str(p) for p in self.players]
        return f"Map: {self.map_sc2.name}, {' vs '.join(names)}, realtime={self.realtime}, seed={self.random_seed}"


def _apply_raw_affects_selection(client: Client, player: AbstractPlayer) -> None:
    """Copy a bot's optional `raw_affects_selection` preference onto the client.

    Humans are skipped; for bots the client setting is only overridden when `player.ai`
    defines `raw_affects_selection` with a non-None value.
    """
    # Bot can decide if it wants to launch with 'raw_affects_selection=True'
    if not isinstance(player, Human) and getattr(player.ai, "raw_affects_selection", None) is not None:
        client.raw_affects_selection = player.ai.raw_affects_selection


async def _play_game_human(client, player_id, realtime, game_time_limit):
    """Drive the game for a human player until a result is available.

    :param game_time_limit: seconds of game time after which the game is declared a Tie (falsy = no limit)
    :returns: the entry for `player_id` in `client._game_result`, or Result.Tie when the time limit is exceeded
    """
    while True:
        state = await client.observation()
        if client._game_result:
            return client._game_result[player_id]

        # game_loop lives on the nested Observation message, not on the ResponseObservation itself.
        game_loop = state.observation.observation.game_loop
        if game_time_limit and game_loop / _GAME_LOOPS_PER_SECOND > game_time_limit:
            logger.info(f"Game time limit reached at game loop {game_loop} ({game_loop / _GAME_LOOPS_PER_SECOND:.2f}s)")
            return Result.Tie

        if not realtime:
            await client.step()


# pylint: disable=R0912,R0911,R0914
async def _play_game_ai(
    client: Client, player_id: int, ai: BotAI, realtime: bool, game_time_limit: Optional[int]
) -> Result:
    """Run a BotAI through a whole game on the given client.

    :param game_time_limit: seconds of game time after which the game is declared a Tie (falsy = no limit)
    :returns: the game result for `player_id`; Result.Defeat if the bot's start hooks raise,
        Result.Tie if the time limit is exceeded, Result.Undecided if the step loop is exhausted.
    :raises: re-raises any exception thrown by `ai.on_step` after logging it.
    """
    gs: Optional[GameState] = None

    async def initialize_first_step() -> Optional[Result]:
        """Prepare the bot and call its start hooks; returns a Result only if the game must end immediately."""
        nonlocal gs
        ai._initialize_variables()

        game_data = await client.get_game_data()
        game_info = await client.get_game_info()
        ping_response = await client.ping()

        # This game_data will become self.game_data in botAI
        ai._prepare_start(
            client, player_id, game_info, game_data, realtime=realtime, base_build=ping_response.ping.base_build
        )
        state = await client.observation()
        # check game result every time we get the observation
        if client._game_result:
            await ai.on_end(client._game_result[player_id])
            return client._game_result[player_id]
        gs = GameState(state.observation)
        proto_game_info = await client._execute(game_info=sc_pb.RequestGameInfo())
        try:
            ai._prepare_step(gs, proto_game_info)
            await ai.on_before_start()
            ai._prepare_first_step()
            await ai.on_start()
        # TODO Catching too general exception Exception (broad-except)
        # pylint: disable=W0703
        except Exception as e:
            logger.exception(f"Caught unknown exception in AI on_start: {e}")
            logger.error("Resigning due to previous error")
            await ai.on_end(Result.Defeat)
            return Result.Defeat
        return None

    result = await initialize_first_step()
    if result is not None:
        return result

    async def run_bot_iteration(iteration: int):
        """Issue events, run the bot's on_step for this iteration and finish the step."""
        nonlocal gs
        logger.debug(f"Running AI step, it={iteration} {gs.game_loop / _GAME_LOOPS_PER_SECOND:.2f}s")
        # Issue event like unit created or unit destroyed
        await ai.issue_events()
        # In on_step various errors can occur - log properly
        try:
            await ai.on_step(iteration)
        except (AttributeError, ) as e:
            logger.exception(f"Caught exception: {e}")
            raise
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
            raise
        await ai._after_step()
        logger.debug("Running AI step: done")

    # Only used in realtime=True
    previous_state_observation = None
    for iteration in range(10**10):
        if realtime and gs:
            # On realtime=True, might get an error here: sc2.protocol.ProtocolError: ['Not in a game']
            # NOTE(review): if the ProtocolError is suppressed, `state` keeps its previous value
            # (or is unbound on the very first iteration) - confirm this is intended.
            with suppress(ProtocolError):
                requested_step = gs.game_loop + client.game_step
                state = await client.observation(requested_step)
                # If the bot took too long in the previous observation, request another observation one frame after
                if state.observation.observation.game_loop > requested_step:
                    logger.debug("Skipped a step in realtime=True")
                    previous_state_observation = state.observation
                    state = await client.observation(state.observation.observation.game_loop + 1)
        else:
            state = await client.observation()

        # check game result every time we get the observation
        if client._game_result:
            await ai.on_end(client._game_result[player_id])
            return client._game_result[player_id]
        gs = GameState(state.observation, previous_state_observation)
        previous_state_observation = None
        logger.debug(f"Score: {gs.score.score}")

        if game_time_limit and gs.game_loop / _GAME_LOOPS_PER_SECOND > game_time_limit:
            await ai.on_end(Result.Tie)
            return Result.Tie
        proto_game_info = await client._execute(game_info=sc_pb.RequestGameInfo())
        ai._prepare_step(gs, proto_game_info)

        await run_bot_iteration(iteration)  # Main bot loop

        if not realtime:
            if not client.in_game:  # Client left (resigned) the game
                await ai.on_end(client._game_result[player_id])
                return client._game_result[player_id]

            # TODO: In bot vs bot, if the other bot ends the game, this bot gets stuck in requesting an observation
            # when using main.py:run_multiple_games
            await client.step()
    return Result.Undecided


async def _play_game(
    player: AbstractPlayer, client: Client, realtime, portconfig, game_time_limit=None, rgb_render_config=None
) -> Result:
    """Join the game with `player` and play it to the end, as a human or through the player's AI.

    :raises AssertionError: if `realtime` is not a bool
    :returns: the player's Result
    """
    assert isinstance(realtime, bool), repr(realtime)

    player_id = await client.join_game(
        player.name, player.race, portconfig=portconfig, rgb_render_config=rgb_render_config
    )
    display_name = player.name if player.name else str(player)
    logger.info(f"Player {player_id} - {display_name}")

    if isinstance(player, Human):
        result = await _play_game_human(client, player_id, realtime, game_time_limit)
    else:
        result = await _play_game_ai(client, player_id, player.ai, realtime, game_time_limit)

    logger.info(
        f"Result for player {player_id} - {display_name}: "
        f"{result._name_ if isinstance(result, Result) else result}"
    )

    return result


async def _setup_host_game(
    server: Controller, map_settings, players, realtime, random_seed=None, disable_fog=None, save_replay_as=None
):
    """Create a game on `server` and return a Client bound to the server's websocket.

    :raises RuntimeError: if the create_game response contains an error
    """
    r = await server.create_game(map_settings, players, realtime, random_seed, disable_fog)
    if r.create_game.HasField("error"):
        err = f"Could not create game: {CreateGameError(r.create_game.error)}"
        if r.create_game.HasField("error_details"):
            err += f": {r.create_game.error_details}"
        logger.critical(err)
        raise RuntimeError(err)

    return Client(server._ws, save_replay_as)


async def _host_game(
    map_settings,
    players,
    realtime=False,
    portconfig=None,
    save_replay_as=None,
    game_time_limit=None,
    rgb_render_config=None,
    random_seed=None,
    sc2_version=None,
    disable_fog=None,
):
    """Launch an SC2 process, create a game on it and play it as `players[0]`.

    :returns: the Result for `players[0]`
    """
    assert players, "Can't create a game without players"

    assert any(isinstance(p, (Human, Bot)) for p in players)

    async with SC2Process(
        fullscreen=players[0].fullscreen, render=rgb_render_config is not None, sc2_version=sc2_version
    ) as server:
        await server.ping()

        client = await _setup_host_game(
            server, map_settings, players, realtime, random_seed, disable_fog, save_replay_as
        )
        _apply_raw_affects_selection(client, players[0])

        result = await _play_game(players[0], client, realtime, portconfig, game_time_limit, rgb_render_config)
        if client.save_replay_path is not None:
            await client.save_replay(client.save_replay_path)
        try:
            await client.leave()
        except ConnectionAlreadyClosed:
            logger.error("Connection was closed before the game ended")
        await client.quit()

        return result


async def _host_game_aiter(
    map_settings,
    players,
    realtime,
    portconfig=None,
    save_replay_as=None,
    game_time_limit=None,
):
    """Async generator hosting consecutive games on one SC2 process.

    Yields each game's Result for `players[0]`; a non-None value sent back replaces `players`
    for the next game. Stops if the connection closes before a game ends.
    """
    assert players, "Can't create a game without players"

    assert any(isinstance(p, (Human, Bot)) for p in players)

    async with SC2Process() as server:
        while True:
            await server.ping()

            client = await _setup_host_game(server, map_settings, players, realtime)
            _apply_raw_affects_selection(client, players[0])

            try:
                result = await _play_game(players[0], client, realtime, portconfig, game_time_limit)

                if save_replay_as is not None:
                    await client.save_replay(save_replay_as)
                await client.leave()
            except ConnectionAlreadyClosed:
                logger.error("Connection was closed before the game ended")
                return

            new_players = yield result
            if new_players is not None:
                players = new_players


def _host_game_iter(*args, **kwargs):
    """Synchronous generator wrapper around `_host_game_aiter`.

    Each yielded value is one game's result; a value sent in is forwarded as the new player list.
    """
    game = _host_game_aiter(*args, **kwargs)
    new_playerconfig = None
    while True:
        new_playerconfig = yield asyncio.get_event_loop().run_until_complete(game.asend(new_playerconfig))


async def _join_game(
    players,
    realtime,
    portconfig,
    save_replay_as=None,
    game_time_limit=None,
):
    """Launch a second SC2 process and join the hosted game as `players[1]`.

    :returns: the Result for `players[1]`
    """
    async with SC2Process(fullscreen=players[1].fullscreen) as server:
        await server.ping()

        client = Client(server._ws)
        _apply_raw_affects_selection(client, players[1])

        result = await _play_game(players[1], client, realtime, portconfig, game_time_limit)
        if save_replay_as is not None:
            await client.save_replay(save_replay_as)
        try:
            await client.leave()
        except ConnectionAlreadyClosed:
            logger.error("Connection was closed before the game ended")
        await client.quit()

        return result


def get_replay_version(replay_path: Union[str, Path]) -> Tuple[str, str]:
    """Read the replay's game metadata and return its (BaseBuild, DataVersion) entries."""
    with open(replay_path, 'rb') as f:
        replay_io = BytesIO(f.read())
    archive = mpyq.MPQArchive(replay_io).extract()
    metadata = json.loads(archive[b"replay.gamemetadata.json"].decode("utf-8"))
    return metadata["BaseBuild"], metadata["DataVersion"]


# TODO Deprecate run_game function in favor of run_multiple_games
def run_game(map_settings, players, **kwargs) -> Union[Result, List[Optional[Result]]]:
    """
    Returns a single Result enum if the game was against the built-in computer.
    Returns a list of two Result enums if the game was "Human vs Bot" or "Bot vs Bot".
    """
    if sum(isinstance(p, (Human, Bot)) for p in players) > 1:
        host_only_args = {"save_replay_as", "rgb_render_config", "random_seed", "sc2_version", "disable_fog"}
        # _join_game has no default for `realtime`; mirror _host_game's default (False) so that
        # omitting it does not make the joining side fail with a TypeError.
        join_kwargs = {"realtime": False, **{k: v for k, v in kwargs.items() if k not in host_only_args}}

        portconfig = Portconfig()

        async def run_host_and_join():
            return await asyncio.gather(
                _host_game(map_settings, players, **kwargs, portconfig=portconfig),
                _join_game(players, **join_kwargs, portconfig=portconfig),
                return_exceptions=True
            )

        result: List[Result] = asyncio.run(run_host_and_join())
        assert isinstance(result, list)
        assert all(isinstance(r, Result) for r in result)
    else:
        result: Result = asyncio.run(_host_game(map_settings, players, **kwargs))
        assert isinstance(result, Result)
    return result


async def play_from_websocket(
    ws_connection: Union[str, ClientWebSocketResponse],
    player: AbstractPlayer,
    realtime: bool = False,
    portconfig: Portconfig = None,
    save_replay_as=None,
    game_time_limit: int = None,
    should_close=True,
):
    """Use this to play when the match is handled externally e.g. for bot ladder games.
    Portconfig MUST be specified if not playing vs Computer.
    :param ws_connection: either a string("ws://{address}:{port}/sc2api") or a ClientWebSocketResponse object
    :param should_close: closes the connection if True. Use False if something else will reuse the connection

    e.g. ladder usage:
    play_from_websocket("ws://127.0.0.1:5162/sc2api", MyBot, False, portconfig=my_PC)
    """
    session = None
    try:
        if isinstance(ws_connection, str):
            session = ClientSession()
            ws_connection = await session.ws_connect(ws_connection, timeout=120)
            should_close = True
        client = Client(ws_connection)
        result = await _play_game(player, client, realtime, portconfig, game_time_limit=game_time_limit)
        if save_replay_as is not None:
            await client.save_replay(save_replay_as)
    except ConnectionAlreadyClosed:
        logger.error("Connection was closed before the game ended")
        return None
    finally:
        # If ws_connect itself failed, ws_connection is still the URL string and there is nothing to close;
        # calling .close() on it would mask the original error.
        if should_close and not isinstance(ws_connection, str):
            await ws_connection.close()
        if session:
            await session.close()

    return result


async def run_match(controllers: List[Controller], match: GameMatch, close_ws=True):
    """Host `match` on `controllers[0]` and play it with every player that needs an SC2 instance.

    :param controllers: one controller per player that needs SC2, in player order
    :param close_ws: forwarded as `should_close` to play_from_websocket for in-process bots
    :returns: mapping of every player to its Result (see process_results)
    """
    await _setup_host_game(controllers[0], **match.host_game_kwargs)

    # Setup portconfig beforehand, so all players use the same ports
    startport = None
    portconfig = None
    if match.needed_sc2_count > 1:
        if any(isinstance(player, BotProcess) for player in match.players):
            portconfig = Portconfig.contiguous_ports()
            # Most ladder bots generate their server and client ports as [s+2, s+3], [s+4, s+5]
            startport = portconfig.server[0] - 2
        else:
            portconfig = Portconfig()

    players_that_need_sc2 = [player for player in match.players if player.needs_sc2]
    coros = []
    for i, player in enumerate(players_that_need_sc2):
        if isinstance(player, BotProcess):
            # External bot processes talk to SC2 through a proxy on a free local port.
            pport = portpicker.pick_unused_port()
            p = Proxy(controllers[i], player, pport, match.game_time_limit, match.realtime)
            coros.append(p.play_with_proxy(startport))
        else:
            coros.append(
                play_from_websocket(
                    controllers[i]._ws,
                    player,
                    match.realtime,
                    portconfig,
                    should_close=close_ws,
                    game_time_limit=match.game_time_limit,
                )
            )

    async_results = await asyncio.gather(*coros, return_exceptions=True)

    if not isinstance(async_results, list):
        async_results = [async_results]
    for i, a in enumerate(async_results):
        if isinstance(a, Exception):
            logger.error(f"Exception[{a}] thrown by {players_that_need_sc2[i]}")

    return process_results(match.players, async_results)


def process_results(players: List[AbstractPlayer], async_results: List[Result]) -> Dict[AbstractPlayer, Result]:
    """Map each player to its result.

    Players needing SC2 take their results from `async_results` in order; if more than one of
    those results is a Victory they all become Result.Undecided. A computer player gets the
    opposite of the first result (None when that result is not Victory/Defeat/Tie).
    """
    opp_res = {Result.Victory: Result.Defeat, Result.Defeat: Result.Victory, Result.Tie: Result.Tie}
    # Loop-invariant: more than one Victory is contradictory.
    results_consistent = sum(r == Result.Victory for r in async_results) <= 1
    result: Dict[AbstractPlayer, Result] = {}
    i = 0
    for player in players:
        if player.needs_sc2:
            result[player] = async_results[i] if results_consistent else Result.Undecided
            i += 1
        else:  # computer
            result[player] = opp_res.get(async_results[0])

    return result


# pylint: disable=R0912
async def maintain_SCII_count(count: int, controllers: List[Controller], proc_args: List[Dict] = None):
    """Modifies the given list of controllers in place to reflect the desired amount of SCII processes.

    Unresponsive or closed controllers are removed first, then new SC2Process instances are started
    (up to 3 attempts) or surplus ones are closed until exactly `count` remain.

    :param proc_args: keyword-argument dicts for SC2Process, cycled over the newly started processes
    :raises RuntimeError: if enough processes could not be started after 3 attempts
    """
    # kill unhealthy ones.
    if controllers:
        to_remove = []
        alive = await asyncio.wait_for(
            asyncio.gather(*(c.ping() for c in controllers if not c._ws.closed), return_exceptions=True), timeout=20
        )
        i = 0  # index into `alive`, which only has entries for controllers with an open websocket
        for controller in controllers:
            if controller._ws.closed:
                if not controller._process._session.closed:
                    await controller._process._session.close()
                to_remove.append(controller)
            else:
                if not isinstance(alive[i], sc_pb.Response):
                    try:
                        await controller._process._close_connection()
                    finally:
                        to_remove.append(controller)
                i += 1
        for c in to_remove:
            c._process._clean(verbose=False)
            if c._process in kill_switch._to_kill:
                kill_switch._to_kill.remove(c._process)
            controllers.remove(c)

    # spawn more
    if len(controllers) < count:
        needed = count - len(controllers)
        if proc_args:
            index = len(controllers) % len(proc_args)
        else:
            proc_args = [{} for _ in range(needed)]
            index = 0
        extra = [SC2Process(**proc_args[(index + k) % len(proc_args)]) for k in range(needed)]
        logger.info(f"Creating {needed} more SC2 Processes")
        for _attempt in range(3):
            if platform.system() == "Linux":
                # Works on linux: start one client after the other
                # pylint: disable=C2801
                new_controllers = [await asyncio.wait_for(sc.__aenter__(), timeout=50) for sc in extra]
            else:
                # Doesnt seem to work on linux: starting 2 clients nearly at the same time
                new_controllers = await asyncio.wait_for(
                    # pylint: disable=C2801
                    asyncio.gather(*[sc.__aenter__() for sc in extra], return_exceptions=True),
                    timeout=50
                )

            controllers.extend(c for c in new_controllers if isinstance(c, Controller))
            if len(controllers) == count:
                await asyncio.wait_for(asyncio.gather(*(c.ping() for c in controllers)), timeout=20)
                break
            # Retry only the processes whose start failed; successful ones are already in `controllers`.
            extra = [sc for sc, started in zip(extra, new_controllers) if not isinstance(started, Controller)]
        else:
            logger.critical("Could not launch sufficient SC2")
            raise RuntimeError("Could not launch sufficient SC2")

    # kill excess
    while len(controllers) > count:
        proc = controllers.pop()._process
        logger.info(f"Removing SCII listening to {proc._port}")
        await proc._close_connection()
        proc._clean(verbose=False)
        if proc in kill_switch._to_kill:
            kill_switch._to_kill.remove(proc)


def run_multiple_games(matches: List[GameMatch]):
    """Synchronous entry point for `a_run_multiple_games`."""
    return asyncio.get_event_loop().run_until_complete(a_run_multiple_games(matches))


# TODO Catching too general exception Exception (broad-except)
# pylint: disable=W0703
async def a_run_multiple_games(matches: List[GameMatch]) -> List[Dict[AbstractPlayer, Result]]:
    """Run multiple matches.
    Non-python bots are supported.
    When playing bot vs bot, this is less likely to fatally crash than repeating run_game()

    :returns: one entry per match, in order; None for matches that raised
    """
    if not matches:
        return []

    results = []
    controllers = []
    for m in matches:
        result = None
        # Matches needing two SC2 instances get fresh processes each time (websockets are closed afterwards).
        restart_after_match = m.needed_sc2_count == 2
        try:
            await maintain_SCII_count(m.needed_sc2_count, controllers, m.sc2_config)
            result = await run_match(controllers, m, close_ws=restart_after_match)
        except SystemExit as e:
            logger.info(f"Game exit'ed as {e} during match {m}")
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
            logger.info(f"Exception {e} thrown in match {m}")
        finally:
            if restart_after_match:  # Keeping them alive after a non-computer match can cause crashes
                await maintain_SCII_count(0, controllers, m.sc2_config)
            results.append(result)
    kill_switch.kill_all()
    return results


# TODO Catching too general exception Exception (broad-except)
# pylint: disable=W0703
async def a_run_multiple_games_nokill(matches: List[GameMatch]) -> List[Dict[AbstractPlayer, Result]]:
    """Run multiple matches while reusing SCII processes.
    Prone to crashes and stalls

    :returns: one entry per match, in order; None for matches that raised
    """
    # FIXME: check whether crashes between bot-vs-bot are avoidable or not
    if not matches:
        return []

    # Start the matches
    results = []
    controllers = []
    for m in matches:
        logger.info(f"Starting match {1 + len(results)} / {len(matches)}: {m}")
        result = None
        try:
            await maintain_SCII_count(m.needed_sc2_count, controllers, m.sc2_config)
            result = await run_match(controllers, m, close_ws=False)
        except SystemExit as e:
            logger.critical(f"Game sys.exit'ed as {e} during match {m}")
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
            logger.info(f"Exception {e} thrown in match {m}")
        finally:
            # Make every still-running controller leave its game so it can be reused for the next match.
            for c in controllers:
                try:
                    await c.ping()
                    if c._status != Status.launched:
                        await c._execute(leave_game=sc_pb.RequestLeaveGame())
                except Exception as e:
                    logger.exception(f"Caught unknown exception: {e}")
                    if not (isinstance(e, ProtocolError) and e.is_game_over_error):
                        logger.info(f"controller {c.__dict__} threw {e}")

            results.append(result)

    # Fire the killswitch manually, instead of letting the winning player fire it.
    await asyncio.wait_for(asyncio.gather(*(c._process._close_connection() for c in controllers)), timeout=50)
    kill_switch.kill_all()
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    return results