Archipelago/worlds/_sc2common/bot/main.py

647 lines
24 KiB
Python

# pylint: disable=W0212
from __future__ import annotations
import asyncio
import json
import platform
import signal
from contextlib import suppress
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import mpyq
import portpicker
from aiohttp import ClientSession, ClientWebSocketResponse
from worlds._sc2common.bot import logger
from s2clientprotocol import sc2api_pb2 as sc_pb
from .bot_ai import BotAI
from .client import Client
from .controller import Controller
from .data import CreateGameError, Result, Status
from .game_state import GameState
from .maps import Map
from .player import AbstractPlayer, Bot, BotProcess, Human
from .portconfig import Portconfig
from .protocol import ConnectionAlreadyClosed, ProtocolError
from .proxy import Proxy
from .sc2process import SC2Process, kill_switch
@dataclass
class GameMatch:
    """Dataclass for hosting a match of SC2.

    This contains all of the needed information for RequestCreateGame.

    :param map_sc2: map the match is played on (forwarded as ``map_settings``)
    :param players: participants of the match
    :param realtime: forwarded unchanged to the game creation request
    :param random_seed: forwarded unchanged to the game creation request
    :param disable_fog: forwarded unchanged to the game creation request
    :param sc2_config: dicts of arguments to unpack into sc2process's construction, one per player
        second sc2_config will be ignored if only one sc2_instance is spawned
        e.g. sc2_args=[{"fullscreen": True}, {}]: only player 1's sc2instance will be fullscreen
        A single dict is accepted as well; configs are repeated cyclically until there is
        one entry per player. The list passed in by the caller is never modified.
    :param game_time_limit: The time (in seconds) until a match is artificially declared a Tie
    """

    map_sc2: Map
    players: List[AbstractPlayer]
    realtime: bool = False
    random_seed: Optional[int] = None
    disable_fog: Optional[bool] = None
    sc2_config: Optional[List[Dict]] = None
    game_time_limit: Optional[int] = None

    def __post_init__(self):
        """Deduplicate player names and normalise ``sc2_config`` to one entry per player."""
        # avoid players sharing names
        if len(self.players) > 1 and self.players[0].name is not None and self.players[0].name == self.players[1].name:
            self.players[1].name += "2"

        if self.sc2_config is not None:
            # Work on a private list: the previous in-place "+=" doubled the list object
            # owned by the caller as a side effect.
            if isinstance(self.sc2_config, dict):
                configs = [self.sc2_config]
            else:
                configs = list(self.sc2_config)
            if not configs:
                configs = [{}]
            # Repeat the given configs cyclically, exactly one per player.
            self.sc2_config = [configs[i % len(configs)] for i in range(len(self.players))]

    @property
    def needed_sc2_count(self) -> int:
        """Number of players whose ``needs_sc2`` flag is truthy."""
        return sum(player.needs_sc2 for player in self.players)

    @property
    def host_game_kwargs(self) -> Dict:
        """Keyword arguments describing this match, keyed as ``_setup_host_game`` expects them."""
        return {
            "map_settings": self.map_sc2,
            "players": self.players,
            "realtime": self.realtime,
            "random_seed": self.random_seed,
            "disable_fog": self.disable_fog,
        }

    def __repr__(self):
        """Readable summary; works for any number of players (name if set, else the player object)."""
        labels = [str(player.name if player.name else player) for player in self.players]
        versus = " vs ".join(labels)
        return f"Map: {self.map_sc2.name}, {versus}, realtime={self.realtime}, seed={self.random_seed}"
async def _play_game_human(client, player_id, realtime, game_time_limit):
    """Keep a human-controlled game running until it ends or the time limit is hit.

    :param client: connected game client; must provide ``observation()``, ``step()`` and ``_game_result``
    :param player_id: key used to look up this player's entry in ``client._game_result``
    :param realtime: when falsy, the game is advanced manually via ``client.step()``
    :param game_time_limit: limit compared against ``game_loop / 22.4``; falsy disables the check
    :return: this player's entry of ``client._game_result``, or ``Result.Tie`` once the limit is exceeded
    """
    while True:
        state = await client.observation()
        if client._game_result:
            return client._game_result[player_id]

        if game_time_limit:
            game_loop = state.observation.observation.game_loop
            if game_loop / 22.4 > game_time_limit:
                # The previous log call read a non-existent attribute path and passed a
                # number as the log message; format a proper message instead.
                logger.info(f"Game time limit reached at game loop {game_loop} ({game_loop / 22.4:.1f}s)")
                return Result.Tie

        if not realtime:
            await client.step()
# pylint: disable=R0912,R0911,R0914
async def _play_game_ai(
    client: Client, player_id: int, ai: BotAI, realtime: bool, game_time_limit: Optional[int]
) -> Result:
    """Drive a bot through one game: initialise it, then loop observation -> step until the game ends.

    :param client: connected game client used for all requests
    :param player_id: key used to look up this player's entry in ``client._game_result``
    :param ai: the bot whose callbacks (``on_start``, ``on_step``, ``on_end``, ...) are invoked
    :param realtime: if True, observations are requested for a target game loop instead of stepping manually
    :param game_time_limit: limit compared against ``game_loop / 22.4``; falsy disables the check
    :return: the player's entry of ``client._game_result``; ``Result.Defeat`` if start-up callbacks raise;
        ``Result.Tie`` when the time limit is exceeded; ``Result.Undecided`` if the loop runs out
    :raises Exception: anything raised by ``ai.on_step`` is logged and re-raised
    """
    # Latest game state; shared with the nested helpers through ``nonlocal``.
    gs: GameState = None

    async def initialize_first_step() -> Optional[Result]:
        """Prepare the bot and run its start callbacks; return a Result only if the game is already over."""
        nonlocal gs
        ai._initialize_variables()

        game_data = await client.get_game_data()
        game_info = await client.get_game_info()
        ping_response = await client.ping()

        # This game_data will become self.game_data in botAI
        ai._prepare_start(
            client, player_id, game_info, game_data, realtime=realtime, base_build=ping_response.ping.base_build
        )
        state = await client.observation()
        # check game result every time we get the observation
        if client._game_result:
            await ai.on_end(client._game_result[player_id])
            return client._game_result[player_id]
        gs = GameState(state.observation)
        proto_game_info = await client._execute(game_info=sc_pb.RequestGameInfo())
        try:
            ai._prepare_step(gs, proto_game_info)
            await ai.on_before_start()
            ai._prepare_first_step()
            await ai.on_start()
        # TODO Catching too general exception Exception (broad-except)
        # pylint: disable=W0703
        except Exception as e:
            # Any failure during start-up is reported to the bot as a defeat.
            logger.exception(f"Caught unknown exception in AI on_start: {e}")
            logger.error("Resigning due to previous error")
            await ai.on_end(Result.Defeat)
            return Result.Defeat
        # Falling through returns None: start-up succeeded and the game continues.

    result = await initialize_first_step()
    if result is not None:
        return result

    async def run_bot_iteration(iteration: int):
        """Run one bot step: issue events, call ``on_step``, then the bot's post-step hook."""
        nonlocal gs
        logger.debug(f"Running AI step, it={iteration} {gs.game_loop / 22.4:.2f}s")
        # Issue event like unit created or unit destroyed
        await ai.issue_events()
        # In on_step various errors can occur - log properly
        try:
            await ai.on_step(iteration)
        except (AttributeError, ) as e:
            logger.exception(f"Caught exception: {e}")
            raise
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
            raise
        await ai._after_step()
        logger.debug("Running AI step: done")

    # Only used in realtime=True
    previous_state_observation = None
    for iteration in range(10**10):
        if realtime and gs:
            # On realtime=True, might get an error here: sc2.protocol.ProtocolError: ['Not in a game']
            # NOTE(review): when the error is suppressed, ``state`` is not refreshed by this
            # iteration - confirm that ``client._game_result`` is always set in that case.
            with suppress(ProtocolError):
                requested_step = gs.game_loop + client.game_step
                state = await client.observation(requested_step)
                # If the bot took too long in the previous observation, request another observation one frame after
                if state.observation.observation.game_loop > requested_step:
                    logger.debug("Skipped a step in realtime=True")
                    previous_state_observation = state.observation
                    state = await client.observation(state.observation.observation.game_loop + 1)
        else:
            state = await client.observation()

        # check game result every time we get the observation
        if client._game_result:
            await ai.on_end(client._game_result[player_id])
            return client._game_result[player_id]
        gs = GameState(state.observation, previous_state_observation)
        previous_state_observation = None
        logger.debug(f"Score: {gs.score.score}")

        # game_loop / 22.4 is compared against the limit, which GameMatch documents as seconds.
        if game_time_limit and gs.game_loop / 22.4 > game_time_limit:
            await ai.on_end(Result.Tie)
            return Result.Tie
        proto_game_info = await client._execute(game_info=sc_pb.RequestGameInfo())
        ai._prepare_step(gs, proto_game_info)

        await run_bot_iteration(iteration)  # Main bot loop

        if not realtime:
            if not client.in_game:  # Client left (resigned) the game
                await ai.on_end(client._game_result[player_id])
                return client._game_result[player_id]

            # TODO: In bot vs bot, if the other bot ends the game, this bot gets stuck in requesting an observation when using main.py:run_multiple_games
            await client.step()
    # Only reached if the iteration range above is exhausted.
    return Result.Undecided
async def _play_game(
    player: AbstractPlayer,
    client: Client,
    realtime,
    portconfig,
    game_time_limit=None,
    rgb_render_config=None
) -> Result:
    """Join the game as ``player`` and play it to completion.

    Human players are handled by ``_play_game_human``, every other player by
    ``_play_game_ai`` (using ``player.ai``). The player id and the final result are logged.

    :param realtime: must be a bool (asserted)
    :return: whatever the human/AI game loop returned
    """
    assert isinstance(realtime, bool), repr(realtime)

    player_id = await client.join_game(
        player.name,
        player.race,
        portconfig=portconfig,
        rgb_render_config=rgb_render_config,
    )
    logger.info(f"Player {player_id} - {player.name if player.name else str(player)}")

    if not isinstance(player, Human):
        result = await _play_game_ai(client, player_id, player.ai, realtime, game_time_limit)
    else:
        result = await _play_game_human(client, player_id, realtime, game_time_limit)

    logger.info(
        f"Result for player {player_id} - {player.name if player.name else str(player)}: "
        f"{result._name_ if isinstance(result, Result) else result}"
    )
    return result
async def _setup_host_game(
    server: Controller, map_settings, players, realtime, random_seed=None, disable_fog=None, save_replay_as=None
):
    """Ask ``server`` to create a game and wrap its websocket in a ``Client``.

    :return: ``Client`` built from ``server._ws`` and ``save_replay_as``
    :raises RuntimeError: if the create-game response carries an error (also logged as critical)
    """
    r = await server.create_game(map_settings, players, realtime, random_seed, disable_fog)

    # Happy path first: no error field means the game was created.
    if not r.create_game.HasField("error"):
        return Client(server._ws, save_replay_as)

    err = f"Could not create game: {CreateGameError(r.create_game.error)}"
    if r.create_game.HasField("error_details"):
        err += f": {r.create_game.error_details}"
    logger.critical(err)
    raise RuntimeError(err)
async def _host_game(
    map_settings,
    players,
    realtime=False,
    portconfig=None,
    save_replay_as=None,
    game_time_limit=None,
    rgb_render_config=None,
    random_seed=None,
    sc2_version=None,
    disable_fog=None,
):
    """Launch an SC2 process, create the game on it and play it as ``players[0]``.

    After the game the replay is saved if the client has a replay path, the client
    leaves the game (a connection that is already closed is only logged) and quits.

    :return: the result returned by ``_play_game`` for the hosting player
    """
    assert players, "Can't create a game without players"
    assert any(isinstance(p, (Human, Bot)) for p in players)

    host = players[0]
    sc2_process = SC2Process(
        fullscreen=host.fullscreen, render=rgb_render_config is not None, sc2_version=sc2_version
    )
    async with sc2_process as server:
        await server.ping()

        client = await _setup_host_game(
            server, map_settings, players, realtime, random_seed, disable_fog, save_replay_as
        )
        # Bot can decide if it wants to launch with 'raw_affects_selection=True'
        if not isinstance(host, Human):
            raw_affects_selection = getattr(host.ai, "raw_affects_selection", None)
            if raw_affects_selection is not None:
                client.raw_affects_selection = raw_affects_selection

        result = await _play_game(host, client, realtime, portconfig, game_time_limit, rgb_render_config)

        replay_path = client.save_replay_path
        if replay_path is not None:
            await client.save_replay(replay_path)

        try:
            await client.leave()
        except ConnectionAlreadyClosed:
            logger.error("Connection was closed before the game ended")
        await client.quit()

        return result
async def _host_game_aiter(
    map_settings,
    players,
    realtime,
    portconfig=None,
    save_replay_as=None,
    game_time_limit=None,
):
    """Async generator hosting game after game on a single SC2 process.

    Each iteration creates a game, plays it as ``players[0]`` and yields the result.
    A non-None value sent into the generator replaces ``players`` for the next game.
    The generator finishes if the connection was closed before the game ended.
    """
    assert players, "Can't create a game without players"
    assert any(isinstance(p, (Human, Bot)) for p in players)

    async with SC2Process() as server:
        while True:
            await server.ping()

            client = await _setup_host_game(server, map_settings, players, realtime)
            host = players[0]
            if not isinstance(host, Human):
                raw_affects_selection = getattr(host.ai, "raw_affects_selection", None)
                if raw_affects_selection is not None:
                    client.raw_affects_selection = raw_affects_selection

            try:
                result = await _play_game(host, client, realtime, portconfig, game_time_limit)

                if save_replay_as is not None:
                    await client.save_replay(save_replay_as)
                await client.leave()
            except ConnectionAlreadyClosed:
                logger.error("Connection was closed before the game ended")
                return

            # The consumer may send a new player list for the following game.
            replacement = yield result
            if replacement is not None:
                players = replacement
def _host_game_iter(*args, **kwargs):
    """Synchronous generator wrapper around ``_host_game_aiter``.

    Every value sent into this generator is forwarded to the async generator; the
    result of the next hosted game is yielded back.
    """
    host_aiter = _host_game_aiter(*args, **kwargs)
    sent_players = None
    while True:
        game_result = asyncio.get_event_loop().run_until_complete(host_aiter.asend(sent_players))
        sent_players = yield game_result
async def _join_game(
    players,
    realtime,
    portconfig,
    save_replay_as=None,
    game_time_limit=None,
):
    """Launch an SC2 process and play as ``players[1]`` in a game created elsewhere.

    After the game the replay is saved if requested, the client leaves the game
    (a connection that is already closed is only logged) and quits.

    :return: the result returned by ``_play_game`` for the joining player
    """
    joiner = players[1]
    async with SC2Process(fullscreen=joiner.fullscreen) as server:
        await server.ping()

        client = Client(server._ws)
        # Bot can decide if it wants to launch with 'raw_affects_selection=True'
        if not isinstance(joiner, Human):
            raw_affects_selection = getattr(joiner.ai, "raw_affects_selection", None)
            if raw_affects_selection is not None:
                client.raw_affects_selection = raw_affects_selection

        result = await _play_game(joiner, client, realtime, portconfig, game_time_limit)

        if save_replay_as is not None:
            await client.save_replay(save_replay_as)

        try:
            await client.leave()
        except ConnectionAlreadyClosed:
            logger.error("Connection was closed before the game ended")
        await client.quit()

        return result
def get_replay_version(replay_path: Union[str, Path]) -> Tuple[str, str]:
    """Read ``BaseBuild`` and ``DataVersion`` from a replay's ``replay.gamemetadata.json``.

    :param replay_path: path of the replay file (an MPQ archive)
    :return: tuple ``(BaseBuild, DataVersion)`` as stored in the metadata
    """
    with open(replay_path, 'rb') as replay_file:
        # The whole file is loaded into memory and handed to mpyq as a seekable buffer.
        replay_io = BytesIO(replay_file.read())

    archive = mpyq.MPQArchive(replay_io).extract()
    raw_metadata = archive[b"replay.gamemetadata.json"]
    metadata = json.loads(raw_metadata.decode("utf-8"))
    return metadata["BaseBuild"], metadata["DataVersion"]
# TODO Deprecate run_game function in favor of run_multiple_games
def run_game(map_settings, players, **kwargs) -> Union[Result, List[Optional[Result]]]:
    """
    Returns a single Result enum if the game was against the built-in computer.
    Returns a list of two Result enums if the game was "Human vs Bot" or "Bot vs Bot".

    With more than one Human/Bot player, a host and a join coroutine are run
    concurrently on a shared Portconfig; host-only keyword arguments are not passed
    to the joining side.
    """
    participant_count = sum(isinstance(p, (Human, Bot)) for p in players)

    if participant_count <= 1:
        result: Result = asyncio.run(_host_game(map_settings, players, **kwargs))
        assert isinstance(result, Result)
        return result

    host_only_args = ["save_replay_as", "rgb_render_config", "random_seed", "sc2_version", "disable_fog"]
    join_kwargs = {}
    for key, value in kwargs.items():
        if key not in host_only_args:
            join_kwargs[key] = value

    portconfig = Portconfig()

    async def run_host_and_join():
        host_coro = _host_game(map_settings, players, **kwargs, portconfig=portconfig)
        join_coro = _join_game(players, **join_kwargs, portconfig=portconfig)
        return await asyncio.gather(host_coro, join_coro, return_exceptions=True)

    results: List[Result] = asyncio.run(run_host_and_join())
    assert isinstance(results, list)
    assert all(isinstance(r, Result) for r in results)
    return results
async def play_from_websocket(
    ws_connection: Union[str, ClientWebSocketResponse],
    player: AbstractPlayer,
    realtime: bool = False,
    portconfig: Portconfig = None,
    save_replay_as=None,
    game_time_limit: int = None,
    should_close=True,
):
    """Use this to play when the match is handled externally e.g. for bot ladder games.
    Portconfig MUST be specified if not playing vs Computer.

    :param ws_connection: either a string("ws://{address}:{port}/sc2api") or a ClientWebSocketResponse object
    :param player: the player to play as
    :param realtime: forwarded to the game loop
    :param portconfig: forwarded to the join request
    :param save_replay_as: if not None, the replay is saved there after the game
    :param game_time_limit: forwarded to the game loop
    :param should_close: closes the connection if True. Use False if something else will reuse the connection
        A connection opened here from a string address is always closed.
    :return: the game result, or None if the connection was closed before the game ended

    e.g. ladder usage: play_from_websocket("ws://127.0.0.1:5162/sc2api", MyBot, False, portconfig=my_PC)
    """
    session = None
    # Tracked separately from ``ws_connection`` so that cleanup never tries to close
    # the address string when opening the websocket itself failed.
    websocket = None
    try:
        if isinstance(ws_connection, str):
            session = ClientSession()
            websocket = await session.ws_connect(ws_connection, timeout=120)
            # We opened this connection ourselves, so we are responsible for closing it.
            should_close = True
        else:
            websocket = ws_connection
        client = Client(websocket)
        result = await _play_game(player, client, realtime, portconfig, game_time_limit=game_time_limit)
        if save_replay_as is not None:
            await client.save_replay(save_replay_as)
    except ConnectionAlreadyClosed:
        logger.error("Connection was closed before the game ended")
        return None
    finally:
        try:
            if should_close and websocket is not None:
                await websocket.close()
        finally:
            # Always release the session, even if closing the websocket raised.
            if session:
                await session.close()

    return result
async def run_match(controllers: List[Controller], match: GameMatch, close_ws=True):
    """Create the match on ``controllers[0]`` and play it with one controller per SC2-needing player.

    BotProcess players are played through a ``Proxy``; all other players through
    ``play_from_websocket``. Exceptions raised by a player's coroutine are logged and
    passed on to ``process_results`` in place of a result.

    :return: mapping of player to result as built by ``process_results``
    """
    await _setup_host_game(controllers[0], **match.host_game_kwargs)

    # Setup portconfig beforehand, so all players use the same ports
    startport = None
    portconfig = None
    if match.needed_sc2_count > 1:
        has_bot_process = any(isinstance(player, BotProcess) for player in match.players)
        if not has_bot_process:
            portconfig = Portconfig()
        else:
            portconfig = Portconfig.contiguous_ports()
            # Most ladder bots generate their server and client ports as [s+2, s+3], [s+4, s+5]
            startport = portconfig.server[0] - 2

    proxies = []
    coros = []
    sc2_players = (candidate for candidate in match.players if candidate.needs_sc2)
    for index, player in enumerate(sc2_players):
        if not isinstance(player, BotProcess):
            coros.append(
                play_from_websocket(
                    controllers[index]._ws,
                    player,
                    match.realtime,
                    portconfig,
                    should_close=close_ws,
                    game_time_limit=match.game_time_limit,
                )
            )
            continue

        proxy_port = portpicker.pick_unused_port()
        proxy = Proxy(controllers[index], player, proxy_port, match.game_time_limit, match.realtime)
        proxies.append(proxy)
        coros.append(proxy.play_with_proxy(startport))

    async_results = await asyncio.gather(*coros, return_exceptions=True)

    if not isinstance(async_results, list):
        async_results = [async_results]
    for i, a in enumerate(async_results):
        if isinstance(a, Exception):
            logger.error(f"Exception[{a}] thrown by {[p for p in match.players if p.needs_sc2][i]}")

    return process_results(match.players, async_results)
def process_results(players: List[AbstractPlayer], async_results: List[Result]) -> Dict[AbstractPlayer, Result]:
    """Map every player to a result.

    Players that need SC2 consume ``async_results`` in order; if more than one
    victory was reported they are all marked ``Result.Undecided``. Any other player
    (computer) gets the opposite of the first result, or None if that result has no
    opposite.
    """
    opposite = {Result.Victory: Result.Defeat, Result.Defeat: Result.Victory, Result.Tie: Result.Tie}
    # Does not depend on the player, so it is computed once up front.
    victories = sum(r == Result.Victory for r in async_results)

    outcome: Dict[AbstractPlayer, Result] = {}
    next_index = 0
    for player in players:
        if not player.needs_sc2:  # computer
            outcome[player] = opposite.get(async_results[0])
            continue

        outcome[player] = async_results[next_index] if victories <= 1 else Result.Undecided
        next_index += 1

    return outcome
# pylint: disable=R0912
async def maintain_SCII_count(count: int, controllers: List[Controller], proc_args: List[Dict] = None):
    """Modifies the given list of controllers to reflect the desired amount of SCII processes

    Works in three phases: drop controllers whose websocket is closed or that do not
    answer a ping, launch new SC2 processes until ``count`` is reached (up to three
    attempts), and finally shut down surplus processes.

    :param count: desired number of controllers
    :param controllers: list that is modified in place
    :param proc_args: keyword-argument dicts for new ``SC2Process`` instances, used cyclically
    :raises RuntimeError: if not enough SC2 processes could be launched
    """
    # kill unhealthy ones.
    if controllers:
        to_remove = []
        alive = await asyncio.wait_for(
            asyncio.gather(*(c.ping() for c in controllers if not c._ws.closed), return_exceptions=True), timeout=20
        )
        i = 0  # for alive
        for controller in controllers:
            if controller._ws.closed:
                if not controller._process._session.closed:
                    await controller._process._session.close()
                to_remove.append(controller)
            else:
                # ``alive`` only holds entries for controllers with an open websocket,
                # so its index advances only in this branch.
                if not isinstance(alive[i], sc_pb.Response):
                    try:
                        await controller._process._close_connection()
                    finally:
                        to_remove.append(controller)
                i += 1
        for c in to_remove:
            c._process._clean(verbose=False)
            if c._process in kill_switch._to_kill:
                kill_switch._to_kill.remove(c._process)
            controllers.remove(c)

    # spawn more
    if len(controllers) < count:
        needed = count - len(controllers)
        if proc_args:
            index = len(controllers) % len(proc_args)
        else:
            proc_args = [{} for _ in range(needed)]
            index = 0

        extra = [SC2Process(**proc_args[(index + offset) % len(proc_args)]) for offset in range(needed)]
        logger.info(f"Creating {needed} more SC2 Processes")

        for _ in range(3):
            if platform.system() == "Linux":
                # Works on linux: start one client after the other
                # pylint: disable=C2801
                new_controllers = [await asyncio.wait_for(sc.__aenter__(), timeout=50) for sc in extra]
            else:
                # Doesnt seem to work on linux: starting 2 clients nearly at the same time
                new_controllers = await asyncio.wait_for(
                    # pylint: disable=C2801
                    asyncio.gather(*[sc.__aenter__() for sc in extra], return_exceptions=True),
                    timeout=50
                )

            controllers.extend(c for c in new_controllers if isinstance(c, Controller))
            if len(controllers) == count:
                await asyncio.wait_for(asyncio.gather(*(c.ping() for c in controllers)), timeout=20)
                break
            # Retry only the processes that failed to launch. The previous filter tested
            # the whole list instead of each entry, so every process was retried,
            # including those that had already started.
            extra = [
                process for process, launched in zip(extra, new_controllers) if not isinstance(launched, Controller)
            ]
        else:
            logger.critical("Could not launch sufficient SC2")
            raise RuntimeError("Could not launch sufficient SC2")

    # kill excess
    while len(controllers) > count:
        proc = controllers.pop()
        proc = proc._process
        logger.info(f"Removing SCII listening to {proc._port}")
        await proc._close_connection()
        proc._clean(verbose=False)
        if proc in kill_switch._to_kill:
            kill_switch._to_kill.remove(proc)
def run_multiple_games(matches: List[GameMatch]):
    """Blocking wrapper: run ``a_run_multiple_games(matches)`` on the current event loop and return its result."""
    event_loop = asyncio.get_event_loop()
    return event_loop.run_until_complete(a_run_multiple_games(matches))
# TODO Catching too general exception Exception (broad-except)
# pylint: disable=W0703
async def a_run_multiple_games(matches: List[GameMatch]) -> List[Dict[AbstractPlayer, Result]]:
    """Run multiple matches.
    Non-python bots are supported.
    When playing bot vs bot, this is less likely to fatally crash than repeating run_game()

    :return: one entry per match, in order; None for a match that raised or exited
    """
    if not matches:
        return []

    match_results = []
    sc2_controllers = []
    for m in matches:
        outcome = None
        two_instances = m.needed_sc2_count == 2
        try:
            await maintain_SCII_count(m.needed_sc2_count, sc2_controllers, m.sc2_config)
            outcome = await run_match(sc2_controllers, m, close_ws=two_instances)
        except SystemExit as e:
            logger.info(f"Game exit'ed as {e} during match {m}")
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
            logger.info(f"Exception {e} thrown in match {m}")
        finally:
            # Keeping them alive after a non-computer match can cause crashes
            if two_instances:
                await maintain_SCII_count(0, sc2_controllers, m.sc2_config)
            match_results.append(outcome)

    kill_switch.kill_all()
    return match_results
# TODO Catching too general exception Exception (broad-except)
# pylint: disable=W0703
async def a_run_multiple_games_nokill(matches: List[GameMatch]) -> List[Dict[AbstractPlayer, Result]]:
    """Run multiple matches while reusing SCII processes.
    Prone to crashes and stalls

    :return: one entry per match, in order; None for a match that raised or exited
    """
    # FIXME: check whether crashes between bot-vs-bot are avoidable or not
    if not matches:
        return []

    # Start the matches
    results = []
    controllers = []
    for m in matches:
        logger.info(f"Starting match {1 + len(results)} / {len(matches)}: {m}")
        outcome = None
        try:
            await maintain_SCII_count(m.needed_sc2_count, controllers, m.sc2_config)
            outcome = await run_match(controllers, m, close_ws=False)
        except SystemExit as e:
            logger.critical(f"Game sys.exit'ed as {e} during match {m}")
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
            logger.info(f"Exception {e} thrown in match {m}")
        finally:
            # Make every controller leave its game so the process can be reused.
            for c in controllers:
                try:
                    await c.ping()
                    if c._status != Status.launched:
                        await c._execute(leave_game=sc_pb.RequestLeaveGame())
                except Exception as e:
                    logger.exception(f"Caught unknown exception: {e}")
                    if not isinstance(e, ProtocolError) or not e.is_game_over_error:
                        logger.info(f"controller {c.__dict__} threw {e}")
            results.append(outcome)

    # Fire the killswitch manually, instead of letting the winning player fire it.
    close_all = asyncio.gather(*[c._process._close_connection() for c in controllers])
    await asyncio.wait_for(close_all, timeout=50)
    kill_switch.kill_all()
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    return results