647 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			647 lines
		
	
	
		
			24 KiB
		
	
	
	
		
			Python
		
	
	
	
# pylint: disable=W0212
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
import asyncio
 | 
						|
import json
 | 
						|
import platform
 | 
						|
import signal
 | 
						|
from contextlib import suppress
 | 
						|
from dataclasses import dataclass
 | 
						|
from io import BytesIO
 | 
						|
from pathlib import Path
 | 
						|
from typing import Dict, List, Optional, Tuple, Union
 | 
						|
 | 
						|
import mpyq
 | 
						|
import portpicker
 | 
						|
from aiohttp import ClientSession, ClientWebSocketResponse
 | 
						|
from worlds._sc2common.bot import logger
 | 
						|
from s2clientprotocol import sc2api_pb2 as sc_pb
 | 
						|
 | 
						|
from .bot_ai import BotAI
 | 
						|
from .client import Client
 | 
						|
from .controller import Controller
 | 
						|
from .data import CreateGameError, Result, Status
 | 
						|
from .game_state import GameState
 | 
						|
from .maps import Map
 | 
						|
from .player import AbstractPlayer, Bot, BotProcess, Human
 | 
						|
from .portconfig import Portconfig
 | 
						|
from .protocol import ConnectionAlreadyClosed, ProtocolError
 | 
						|
from .proxy import Proxy
 | 
						|
from .sc2process import SC2Process, kill_switch
 | 
						|
 | 
						|
 | 
						|
@dataclass
class GameMatch:
    """Dataclass for hosting a match of SC2.

    This contains all of the needed information for RequestCreateGame.

    :param map_sc2: the map the match is played on
    :param players: the participants of the match
    :param realtime: forwarded unchanged to game creation
    :param random_seed: forwarded unchanged to game creation
    :param disable_fog: forwarded unchanged to game creation
    :param sc2_config: dicts of arguments to unpack into sc2process's construction, one per player
        second sc2_config will be ignored if only one sc2_instance is spawned
        e.g. sc2_args=[{"fullscreen": True}, {}]: only player 1's sc2instance will be fullscreen
        A single dict is accepted as well; configs are repeated cyclically until there is one per player.
    :param game_time_limit: The time (in seconds) until a match is artificially declared a Tie
    """

    map_sc2: Map
    players: List[AbstractPlayer]
    realtime: bool = False
    random_seed: Optional[int] = None
    disable_fog: Optional[bool] = None
    sc2_config: Optional[List[Dict]] = None
    game_time_limit: Optional[int] = None

    def __post_init__(self):
        # avoid players sharing names
        if len(self.players) > 1 and self.players[0].name is not None and self.players[0].name == self.players[1].name:
            self.players[1].name += "2"

        if self.sc2_config is not None:
            # Work on a fresh list so the list object passed in by the caller is never mutated.
            if isinstance(self.sc2_config, dict):
                base_configs = [self.sc2_config]
            else:
                base_configs = list(self.sc2_config)
            if not base_configs:
                base_configs = [{}]
            # Repeat the given configs cyclically and cut to exactly one entry per player.
            self.sc2_config = [base_configs[i % len(base_configs)] for i in range(len(self.players))]

    @property
    def needed_sc2_count(self) -> int:
        """Number of players whose ``needs_sc2`` flag is set."""
        return sum(player.needs_sc2 for player in self.players)

    @property
    def host_game_kwargs(self) -> Dict:
        """Keyword arguments describing this match, in the form ``_setup_host_game`` accepts."""
        return {
            "map_settings": self.map_sc2,
            "players": self.players,
            "realtime": self.realtime,
            "random_seed": self.random_seed,
            "disable_fog": self.disable_fog,
        }

    def __repr__(self):
        # Show the name when a player has one, otherwise fall back to the player object itself.
        # Joining over all players keeps this safe for matches without exactly two players.
        labels = [f"{player.name if player.name else player}" for player in self.players]
        versus = " vs ".join(labels)
        return f"Map: {self.map_sc2.name}, {versus}, realtime={self.realtime}, seed={self.random_seed}"
 | 
						|
 | 
						|
 | 
						|
async def _play_game_human(client, player_id, realtime, game_time_limit):
    """Observe a human-controlled game until it ends or the time limit is hit.

    :param client: connected game client; must provide ``observation()``, ``step()`` and ``_game_result``
    :param player_id: key used to look up this player's entry in ``client._game_result``
    :param realtime: when falsy, the game is advanced with ``client.step()`` after every observation
    :param game_time_limit: limit in seconds (game loops / 22.4); falsy disables the limit
    :return: this player's entry of ``client._game_result``, or ``Result.Tie`` once the limit is exceeded
    """
    while True:
        state = await client.observation()
        if client._game_result:
            return client._game_result[player_id]

        if game_time_limit:
            # Read the loop counter from the same field the limit check uses, and log it as one
            # formatted message instead of passing an int as the log "message".
            game_loop = state.observation.observation.game_loop
            if game_loop / 22.4 > game_time_limit:
                logger.info(f"Game time limit reached: game_loop={game_loop}, time={game_loop / 22.4:.2f}s")
                return Result.Tie

        if not realtime:
            await client.step()
 | 
						|
 | 
						|
 | 
						|
# pylint: disable=R0912,R0911,R0914
 | 
						|
async def _play_game_ai(
    client: Client, player_id: int, ai: BotAI, realtime: bool, game_time_limit: Optional[int]
) -> Result:
    """Run a bot through a whole game and return its result.

    The bot is first initialised (``on_before_start`` / ``on_start``); afterwards every loop
    iteration fetches an observation, builds a new ``GameState`` and calls ``on_step``.

    :param client: connected game client
    :param player_id: key used to look up this player's entry in ``client._game_result``
    :param ai: the bot instance to drive
    :param realtime: if True the game is not stepped manually; observations are requested for a target loop
    :param game_time_limit: limit in seconds (game loops / 22.4); falsy disables the limit
    :return: the game result for this player, ``Result.Tie`` on time limit, ``Result.Defeat`` if start-up raised
    """
    game_state: GameState = None

    async def initialize_first_step() -> Optional[Result]:
        """Prepare the bot for the first step; returns a Result only if the game is already over."""
        nonlocal game_state
        ai._initialize_variables()

        game_data = await client.get_game_data()
        game_info = await client.get_game_info()
        ping_response = await client.ping()

        # This game_data will become self.game_data in botAI
        ai._prepare_start(
            client,
            player_id,
            game_info,
            game_data,
            realtime=realtime,
            base_build=ping_response.ping.base_build,
        )
        first_state = await client.observation()
        # check game result every time we get the observation
        if client._game_result:
            await ai.on_end(client._game_result[player_id])
            return client._game_result[player_id]
        game_state = GameState(first_state.observation)
        first_proto_info = await client._execute(game_info=sc_pb.RequestGameInfo())
        try:
            ai._prepare_step(game_state, first_proto_info)
            await ai.on_before_start()
            ai._prepare_first_step()
            await ai.on_start()
        # TODO Catching too general exception Exception (broad-except)
        # pylint: disable=W0703
        except Exception as error:
            logger.exception(f"Caught unknown exception in AI on_start: {error}")
            logger.error("Resigning due to previous error")
            await ai.on_end(Result.Defeat)
            return Result.Defeat
        return None

    early_result = await initialize_first_step()
    if early_result is not None:
        return early_result

    async def run_bot_iteration(iteration: int):
        """Issue pending events, run one ``on_step`` and the post-step hook; bot errors are logged and re-raised."""
        nonlocal game_state
        logger.debug(f"Running AI step, it={iteration} {game_state.game_loop / 22.4:.2f}s")
        # Issue event like unit created or unit destroyed
        await ai.issue_events()
        # In on_step various errors can occur - log properly
        try:
            await ai.on_step(iteration)
        except (AttributeError, ) as error:
            logger.exception(f"Caught exception: {error}")
            raise
        except Exception as error:
            logger.exception(f"Caught unknown exception: {error}")
            raise
        await ai._after_step()
        logger.debug("Running AI step: done")

    # Only used in realtime=True
    skipped_observation = None
    for iteration in range(10**10):
        if not (realtime and game_state):
            state = await client.observation()
        else:
            # On realtime=True, might get an error here: sc2.protocol.ProtocolError: ['Not in a game']
            with suppress(ProtocolError):
                target_loop = game_state.game_loop + client.game_step
                state = await client.observation(target_loop)
                # If the bot took too long in the previous observation, request another observation one frame after
                if state.observation.observation.game_loop > target_loop:
                    logger.debug("Skipped a step in realtime=True")
                    skipped_observation = state.observation
                    state = await client.observation(state.observation.observation.game_loop + 1)

        # check game result every time we get the observation
        if client._game_result:
            await ai.on_end(client._game_result[player_id])
            return client._game_result[player_id]
        game_state = GameState(state.observation, skipped_observation)
        skipped_observation = None
        logger.debug(f"Score: {game_state.score.score}")

        if game_time_limit and game_state.game_loop / 22.4 > game_time_limit:
            await ai.on_end(Result.Tie)
            return Result.Tie
        proto_game_info = await client._execute(game_info=sc_pb.RequestGameInfo())
        ai._prepare_step(game_state, proto_game_info)

        await run_bot_iteration(iteration)  # Main bot loop

        if realtime:
            continue

        if not client.in_game:  # Client left (resigned) the game
            await ai.on_end(client._game_result[player_id])
            return client._game_result[player_id]

        # TODO: In bot vs bot, if the other bot ends the game, this bot gets stuck in requesting an observation when using main.py:run_multiple_games
        await client.step()
    return Result.Undecided
 | 
						|
 | 
						|
 | 
						|
async def _play_game(
    player: AbstractPlayer,
    client: Client,
    realtime,
    portconfig,
    game_time_limit=None,
    rgb_render_config=None
) -> Result:
    """Join the game with ``client`` and play it to the end as ``player``.

    Human players are only observed, every other player is driven through its ``ai``.

    :param player: the participant to play as
    :param client: connected game client used to join and play
    :param realtime: must be a bool
    :param portconfig: forwarded to ``client.join_game``
    :param game_time_limit: forwarded to the play loop
    :param rgb_render_config: forwarded to ``client.join_game``
    :return: the result of the game for this player
    """
    assert isinstance(realtime, bool), repr(realtime)

    def label():
        # Prefer the configured name, fall back to the player's string form.
        return player.name if player.name else str(player)

    player_id = await client.join_game(
        player.name,
        player.race,
        portconfig=portconfig,
        rgb_render_config=rgb_render_config,
    )
    logger.info(f"Player {player_id} - {label()}")

    if not isinstance(player, Human):
        result = await _play_game_ai(client, player_id, player.ai, realtime, game_time_limit)
    else:
        result = await _play_game_human(client, player_id, realtime, game_time_limit)

    shown_result = result._name_ if isinstance(result, Result) else result
    logger.info(f"Result for player {player_id} - {label()}: {shown_result}")

    return result
 | 
						|
 | 
						|
async def _setup_host_game(
    server: Controller, map_settings, players, realtime, random_seed=None, disable_fog=None, save_replay_as=None
):
    """Create a game on ``server`` and return a ``Client`` bound to the server's websocket.

    :raises RuntimeError: if the create-game response carries an error (the message is logged as critical first)
    """
    response = await server.create_game(map_settings, players, realtime, random_seed, disable_fog)

    # Happy path first: no error field means the game was created.
    if not response.create_game.HasField("error"):
        return Client(server._ws, save_replay_as)

    message = f"Could not create game: {CreateGameError(response.create_game.error)}"
    if response.create_game.HasField("error_details"):
        message += f": {response.create_game.error_details}"
    logger.critical(message)
    raise RuntimeError(message)
 | 
						|
 | 
						|
 | 
						|
async def _host_game(
    map_settings,
    players,
    realtime=False,
    portconfig=None,
    save_replay_as=None,
    game_time_limit=None,
    rgb_render_config=None,
    random_seed=None,
    sc2_version=None,
    disable_fog=None,
):
    """Launch an SC2 process, create the game on it and play it as ``players[0]``.

    After the game a replay is saved when the client has a replay path; the client then
    leaves the game and quits.

    :return: the result of the game for ``players[0]``
    """
    assert players, "Can't create a game without players"

    assert any(isinstance(p, (Human, Bot)) for p in players)

    host_player = players[0]
    wants_render = rgb_render_config is not None
    async with SC2Process(fullscreen=host_player.fullscreen, render=wants_render, sc2_version=sc2_version) as server:
        await server.ping()

        client = await _setup_host_game(
            server, map_settings, players, realtime, random_seed, disable_fog, save_replay_as
        )
        # Bot can decide if it wants to launch with 'raw_affects_selection=True'
        if not isinstance(players[0], Human):
            if getattr(players[0].ai, "raw_affects_selection", None) is not None:
                client.raw_affects_selection = players[0].ai.raw_affects_selection

        result = await _play_game(players[0], client, realtime, portconfig, game_time_limit, rgb_render_config)
        if client.save_replay_path is not None:
            await client.save_replay(client.save_replay_path)
        try:
            await client.leave()
        except ConnectionAlreadyClosed:
            logger.error("Connection was closed before the game ended")
        await client.quit()

        return result
 | 
						|
 | 
						|
 | 
						|
async def _host_game_aiter(
    map_settings,
    players,
    realtime,
    portconfig=None,
    save_replay_as=None,
    game_time_limit=None,
):
    """Async generator hosting game after game on a single SC2 process.

    Each finished game yields its result. A player list sent back into the generator
    replaces ``players`` for the following games. The generator stops when the
    connection is closed before a game ended.
    """
    assert players, "Can't create a game without players"

    assert any(isinstance(p, (Human, Bot)) for p in players)

    async with SC2Process() as server:
        while True:
            await server.ping()

            client = await _setup_host_game(server, map_settings, players, realtime)
            host_player = players[0]
            if not isinstance(host_player, Human):
                if getattr(host_player.ai, "raw_affects_selection", None) is not None:
                    client.raw_affects_selection = host_player.ai.raw_affects_selection

            try:
                result = await _play_game(players[0], client, realtime, portconfig, game_time_limit)

                if save_replay_as is not None:
                    await client.save_replay(save_replay_as)
                await client.leave()
            except ConnectionAlreadyClosed:
                logger.error("Connection was closed before the game ended")
                return

            replacement_players = yield result
            if replacement_players is None:
                continue
            players = replacement_players
 | 
						|
 | 
						|
 | 
						|
def _host_game_iter(*args, **kwargs):
    """Synchronous generator wrapper around ``_host_game_aiter``.

    Every step runs the async generator on the current event loop until it yields the next
    result; a value sent into this generator is forwarded to the async generator.
    """
    async_games = _host_game_aiter(*args, **kwargs)
    pending_players = None
    while True:
        outcome = asyncio.get_event_loop().run_until_complete(async_games.asend(pending_players))
        pending_players = yield outcome
 | 
						|
 | 
						|
 | 
						|
async def _join_game(
    players,
    realtime,
    portconfig,
    save_replay_as=None,
    game_time_limit=None,
):
    """Launch an SC2 process and play a game as ``players[1]``.

    After the game a replay is saved if ``save_replay_as`` is given; the client then
    leaves the game and quits.

    :return: the result of the game for ``players[1]``
    """
    joining_player = players[1]
    async with SC2Process(fullscreen=joining_player.fullscreen) as server:
        await server.ping()

        client = Client(server._ws)
        # Bot can decide if it wants to launch with 'raw_affects_selection=True'
        if not isinstance(players[1], Human):
            if getattr(players[1].ai, "raw_affects_selection", None) is not None:
                client.raw_affects_selection = players[1].ai.raw_affects_selection

        result = await _play_game(players[1], client, realtime, portconfig, game_time_limit)
        if save_replay_as is not None:
            await client.save_replay(save_replay_as)
        try:
            await client.leave()
        except ConnectionAlreadyClosed:
            logger.error("Connection was closed before the game ended")
        await client.quit()

        return result
 | 
						|
 | 
						|
 | 
						|
def get_replay_version(replay_path: Union[str, Path]) -> Tuple[str, str]:
    """Read the game version a replay was recorded with.

    :param replay_path: path of the replay file
    :return: the ``BaseBuild`` and ``DataVersion`` entries of the replay's ``replay.gamemetadata.json``
    """
    with open(replay_path, 'rb') as replay_file:
        # Hand the archive reader an in-memory copy positioned at the start of the data.
        in_memory_replay = BytesIO(replay_file.read())
        extracted_files = mpyq.MPQArchive(in_memory_replay).extract()
        raw_metadata = extracted_files[b"replay.gamemetadata.json"]
        metadata = json.loads(raw_metadata.decode("utf-8"))
        base_build = metadata["BaseBuild"]
        data_version = metadata["DataVersion"]
        return base_build, data_version
 | 
						|
 | 
						|
 | 
						|
# TODO Deprecate run_game function in favor of run_multiple_games
 | 
						|
def run_game(map_settings, players, **kwargs) -> Union[Result, List[Optional[Result]]]:
    """
    Returns a single Result enum if the game was against the built-in computer.
    Returns a list of two Result enums if the game was "Human vs Bot" or "Bot vs Bot".

    :param map_settings: the map to play on
    :param players: the participants; more than one Human/Bot entry starts a host and a join side
    :param kwargs: forwarded to the host side; host-only arguments are stripped for the join side
    :raises AssertionError: if a side did not finish with a Result (its exception is logged first)
    """
    if sum(isinstance(p, (Human, Bot)) for p in players) > 1:
        host_only_args = ["save_replay_as", "rgb_render_config", "random_seed", "sc2_version", "disable_fog"]
        join_kwargs = {k: v for k, v in kwargs.items() if k not in host_only_args}

        portconfig = Portconfig()

        async def run_host_and_join():
            return await asyncio.gather(
                _host_game(map_settings, players, **kwargs, portconfig=portconfig),
                _join_game(players, **join_kwargs, portconfig=portconfig),
                return_exceptions=True
            )

        result = asyncio.run(run_host_and_join())
        assert isinstance(result, list)
        # gather(return_exceptions=True) hands exceptions back as values; report them so the
        # real cause is visible instead of only a bare assertion failure.
        for side, outcome in zip(("host", "join"), result):
            if isinstance(outcome, BaseException):
                logger.error(f"Exception[{outcome!r}] thrown by the {side} side of the game")
        assert all(isinstance(r, Result) for r in result), f"Game did not finish with valid results: {result!r}"
    else:
        result = asyncio.run(_host_game(map_settings, players, **kwargs))
        assert isinstance(result, Result)
    return result
 | 
						|
 | 
						|
 | 
						|
async def play_from_websocket(
    ws_connection: Union[str, ClientWebSocketResponse],
    player: AbstractPlayer,
    realtime: bool = False,
    portconfig: Optional[Portconfig] = None,
    save_replay_as=None,
    game_time_limit: Optional[int] = None,
    should_close=True,
):
    """Use this to play when the match is handled externally e.g. for bot ladder games.
    Portconfig MUST be specified if not playing vs Computer.
    :param ws_connection: either a string("ws://{address}:{port}/sc2api") or a ClientWebSocketResponse object
    :param should_close: closes the connection if True. Use False if something else will reuse the connection
        (a connection opened here from a string is always closed)
    :return: the game result, or None if the connection was closed before the game ended

    e.g. ladder usage: play_from_websocket("ws://127.0.0.1:5162/sc2api", MyBot, False, portconfig=my_PC)
    """
    session = None
    try:
        if isinstance(ws_connection, str):
            session = ClientSession()
            ws_connection = await session.ws_connect(ws_connection, timeout=120)
            should_close = True
        client = Client(ws_connection)
        result = await _play_game(player, client, realtime, portconfig, game_time_limit=game_time_limit)
        if save_replay_as is not None:
            await client.save_replay(save_replay_as)
    except ConnectionAlreadyClosed:
        logger.error("Connection was closed before the game ended")
        return None
    finally:
        try:
            # If connecting failed, ws_connection is still the URL string and has nothing to close;
            # calling .close() on it would mask the original error.
            if should_close and not isinstance(ws_connection, str):
                await ws_connection.close()
        finally:
            # A session created here is owned by this function, so always release it.
            if session is not None:
                await session.close()

    return result
 | 
						|
 | 
						|
 | 
						|
async def run_match(controllers: List[Controller], match: GameMatch, close_ws=True):
    """Host ``match`` on the first controller and play it with every player that needs SC2.

    The i-th player needing SC2 uses ``controllers[i]``. ``BotProcess`` players are run
    through a ``Proxy``, all others through ``play_from_websocket``.

    :param controllers: running SC2 controllers, at least one per player that needs SC2
    :param match: the match description
    :param close_ws: passed as ``should_close`` to ``play_from_websocket``
    :return: mapping of every player to its result, as built by ``process_results``
    """
    await _setup_host_game(controllers[0], **match.host_game_kwargs)

    # Setup portconfig beforehand, so all players use the same ports
    startport = None
    portconfig = None
    if match.needed_sc2_count > 1:
        has_bot_process = any(isinstance(player, BotProcess) for player in match.players)
        if not has_bot_process:
            portconfig = Portconfig()
        else:
            portconfig = Portconfig.contiguous_ports()
            # Most ladder bots generate their server and client ports as [s+2, s+3], [s+4, s+5]
            startport = portconfig.server[0] - 2

    proxies = []
    coros = []
    sc2_players = (candidate for candidate in match.players if candidate.needs_sc2)
    for index, player in enumerate(sc2_players):
        if not isinstance(player, BotProcess):
            game_coro = play_from_websocket(
                controllers[index]._ws,
                player,
                match.realtime,
                portconfig,
                should_close=close_ws,
                game_time_limit=match.game_time_limit,
            )
            coros.append(game_coro)
            continue

        proxy_port = portpicker.pick_unused_port()
        proxy = Proxy(controllers[index], player, proxy_port, match.game_time_limit, match.realtime)
        proxies.append(proxy)
        coros.append(proxy.play_with_proxy(startport))

    async_results = await asyncio.gather(*coros, return_exceptions=True)

    if not isinstance(async_results, list):
        async_results = [async_results]
    for index, outcome in enumerate(async_results):
        if not isinstance(outcome, Exception):
            continue
        # Results are ordered like the players that need SC2, so the same index names the culprit.
        culprit = [p for p in match.players if p.needs_sc2][index]
        logger.error(f"Exception[{outcome}] thrown by {culprit}")

    return process_results(match.players, async_results)
 | 
						|
 | 
						|
 | 
						|
def process_results(players: List[AbstractPlayer], async_results: List[Result]) -> Dict[AbstractPlayer, Result]:
    """Map every player to its result.

    Players needing SC2 take their results from ``async_results`` in order; if more than one
    Victory was reported, all of them are recorded as ``Result.Undecided`` instead.
    Any other player (computer) gets the opposite of the first result, or None when that
    result has no opposite.

    :param players: all participants of the match
    :param async_results: results in the order of the players that need SC2
    :return: dict from player to result
    """
    opposite = {Result.Victory: Result.Defeat, Result.Defeat: Result.Victory, Result.Tie: Result.Tie}
    # More than one reported winner means the reports contradict each other.
    victories = sum(r == Result.Victory for r in async_results)
    outcome: Dict[AbstractPlayer, Result] = {}
    next_index = 0
    for player in players:
        if not player.needs_sc2:  # computer
            outcome[player] = opposite.get(async_results[0])
            continue
        outcome[player] = async_results[next_index] if victories <= 1 else Result.Undecided
        next_index += 1

    return outcome
 | 
						|
 | 
						|
 | 
						|
# pylint: disable=R0912
 | 
						|
async def maintain_SCII_count(count: int, controllers: List[Controller], proc_args: Optional[List[Dict]] = None):
    """Modifies the given list of controllers to reflect the desired amount of SCII processes

    Works in three phases: controllers that are closed or do not answer a ping are removed,
    missing processes are launched (up to three attempts), and surplus processes are shut down.

    :param count: desired number of running SC2 processes
    :param controllers: list that is modified in place
    :param proc_args: keyword-argument dicts for new ``SC2Process`` instances, used cyclically
    :raises RuntimeError: if ``count`` processes could not be reached within three launch attempts
    """
    # kill unhealthy ones.
    if controllers:
        to_remove = []
        alive = await asyncio.wait_for(
            asyncio.gather(*(c.ping() for c in controllers if not c._ws.closed), return_exceptions=True), timeout=20
        )
        i = 0  # for alive
        for controller in controllers:
            if controller._ws.closed:
                if not controller._process._session.closed:
                    await controller._process._session.close()
                to_remove.append(controller)
            else:
                # Only controllers with an open websocket were pinged, so ``alive`` is indexed separately.
                if not isinstance(alive[i], sc_pb.Response):
                    try:
                        await controller._process._close_connection()
                    finally:
                        to_remove.append(controller)
                i += 1
        for c in to_remove:
            c._process._clean(verbose=False)
            if c._process in kill_switch._to_kill:
                kill_switch._to_kill.remove(c._process)
            controllers.remove(c)

    # spawn more
    if len(controllers) < count:
        needed = count - len(controllers)
        if proc_args:
            index = len(controllers) % len(proc_args)
        else:
            proc_args = [{} for _ in range(needed)]
            index = 0
        extra = [SC2Process(**proc_args[(index + offset) % len(proc_args)]) for offset in range(needed)]
        logger.info(f"Creating {needed} more SC2 Processes")
        for _ in range(3):
            if platform.system() == "Linux":
                # Works on linux: start one client after the other
                # pylint: disable=C2801
                new_controllers = [await asyncio.wait_for(sc.__aenter__(), timeout=50) for sc in extra]
            else:
                # Doesn't seem to work on linux: starting 2 clients nearly at the same time
                new_controllers = await asyncio.wait_for(
                    # pylint: disable=C2801
                    asyncio.gather(*[sc.__aenter__() for sc in extra], return_exceptions=True),
                    timeout=50
                )

            controllers.extend(c for c in new_controllers if isinstance(c, Controller))
            if len(controllers) == count:
                await asyncio.wait_for(asyncio.gather(*(c.ping() for c in controllers)), timeout=20)
                break
            # Retry only the processes whose launch did not produce a Controller. Each individual
            # result must be tested; testing the whole list would retry everything, including
            # processes that are already running.
            extra = [extra[i] for i, result in enumerate(new_controllers) if not isinstance(result, Controller)]
        else:
            logger.critical("Could not launch sufficient SC2")
            raise RuntimeError

    # kill excess
    while len(controllers) > count:
        proc = controllers.pop()
        proc = proc._process
        logger.info(f"Removing SCII listening to {proc._port}")
        await proc._close_connection()
        proc._clean(verbose=False)
        if proc in kill_switch._to_kill:
            kill_switch._to_kill.remove(proc)
 | 
						|
 | 
						|
 | 
						|
def run_multiple_games(matches: List[GameMatch]):
    """Blocking wrapper: run ``a_run_multiple_games`` on the current event loop and return its results."""
    event_loop = asyncio.get_event_loop()
    return event_loop.run_until_complete(a_run_multiple_games(matches))
 | 
						|
 | 
						|
 | 
						|
# TODO Catching too general exception Exception (broad-except)
 | 
						|
# pylint: disable=W0703
 | 
						|
async def a_run_multiple_games(matches: List[GameMatch]) -> List[Dict[AbstractPlayer, Result]]:
    """Run multiple matches.
    Non-python bots are supported.
    When playing bot vs bot, this is less likely to fatally crash than repeating run_game()

    :param matches: the matches to play, in order
    :return: one entry per match; None for a match that exited or raised
    """
    if not matches:
        return []

    all_results = []
    controllers = []
    for match in matches:
        match_result = None
        two_sc2_needed = match.needed_sc2_count == 2
        try:
            await maintain_SCII_count(match.needed_sc2_count, controllers, match.sc2_config)
            match_result = await run_match(controllers, match, close_ws=two_sc2_needed)
        except SystemExit as error:
            logger.info(f"Game exit'ed as {error} during match {match}")
        except Exception as error:
            logger.exception(f"Caught unknown exception: {error}")
            logger.info(f"Exception {error} thrown in match {match}")
        finally:
            # Keeping them alive after a non-computer match can cause crashes
            if two_sc2_needed:
                await maintain_SCII_count(0, controllers, match.sc2_config)
            all_results.append(match_result)
    kill_switch.kill_all()
    return all_results
 | 
						|
 | 
						|
 | 
						|
# TODO Catching too general exception Exception (broad-except)
 | 
						|
# pylint: disable=W0703
 | 
						|
async def a_run_multiple_games_nokill(matches: List[GameMatch]) -> List[Dict[AbstractPlayer, Result]]:
    """Run multiple matches while reusing SCII processes.
    Prone to crashes and stalls

    :param matches: the matches to play, in order
    :return: one entry per match; None for a match that exited or raised
    """
    # FIXME: check whether crashes between bot-vs-bot are avoidable or not
    if not matches:
        return []

    # Start the matches
    all_results = []
    controllers = []
    total = len(matches)
    for match in matches:
        logger.info(f"Starting match {1 + len(all_results)} / {total}: {match}")
        match_result = None
        try:
            await maintain_SCII_count(match.needed_sc2_count, controllers, match.sc2_config)
            match_result = await run_match(controllers, match, close_ws=False)
        except SystemExit as error:
            logger.critical(f"Game sys.exit'ed as {error} during match {match}")
        except Exception as error:
            logger.exception(f"Caught unknown exception: {error}")
            logger.info(f"Exception {error} thrown in match {match}")
        finally:
            # Make every controller that is not back in the launched state leave its game.
            for controller in controllers:
                try:
                    await controller.ping()
                    if controller._status != Status.launched:
                        await controller._execute(leave_game=sc_pb.RequestLeaveGame())
                except Exception as error:
                    logger.exception(f"Caught unknown exception: {error}")
                    game_over = isinstance(error, ProtocolError) and error.is_game_over_error
                    if not game_over:
                        logger.info(f"controller {controller.__dict__} threw {error}")

            all_results.append(match_result)

    # Fire the killswitch manually, instead of letting the winning player fire it.
    close_all = asyncio.gather(*(c._process._close_connection() for c in controllers))
    await asyncio.wait_for(close_all, timeout=50)
    kill_switch.kill_all()
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    return all_results
 |