234 lines
10 KiB
Python
234 lines
10 KiB
Python
# pylint: disable=W0212
|
|
import asyncio
|
|
import os
|
|
import platform
|
|
import subprocess
|
|
import time
|
|
import traceback
|
|
|
|
from aiohttp import WSMsgType, web
|
|
from worlds._sc2common.bot import logger
|
|
from s2clientprotocol import sc2api_pb2 as sc_pb
|
|
|
|
from .controller import Controller
|
|
from .data import Result, Status
|
|
from .player import BotProcess
|
|
|
|
|
|
class Proxy:
    """
    Class for handling communication between sc2 and an external bot.
    This "middleman" is needed for enforcing time limits, collecting results, and closing things properly.

    Attributes set by ``__init__``:
        controller: controller whose websocket (``_ws``) is used to talk to SC2.
        player: the external bot process description (name, path, command line, stdout target).
        port: port the proxy's own websocket server listens on.
        timeout_loop: game-loop count after which the game is declared a tie, or None for no limit.
        realtime: forwarded to ``player.cmd_line`` when the bot is started.
        result: mapping ``player_id -> Result`` once the outcome is known, else None.
        player_id: id taken from the first ``join_game`` response, else None.
        done: set to True once ``proxy_handler`` has finished.
    """

    def __init__(
        self,
        controller: Controller,
        player: BotProcess,
        proxyport: int,
        game_time_limit: int = None,
        realtime: bool = False,
    ):
        self.controller = controller
        self.player = player
        self.port = proxyport
        # Convert the limit into game loops; it is compared against
        # observation.game_loop in parse_response. A falsy limit (None or 0) disables it.
        self.timeout_loop = game_time_limit * 22.4 if game_time_limit else None
        self.realtime = realtime
        logger.debug(
            f"Proxy Inited with ctrl {controller}({controller._process._port}), player {player}, proxyport {proxyport}, lim {game_time_limit}"
        )

        self.result = None
        self.player_id: int = None
        self.done = False

    async def parse_request(self, msg) -> None:
        """Parse a raw request coming from the bot, adjust it and forward it to SC2.

        - A ``quit`` request is replaced by a ``leave_game`` request.
        - A ``leave_game`` request while the controller is in game is recorded as a
          defeat (surrender) for this player; if the game already ended, one response
          is read from SC2 first.
        - A ``join_game`` request without a player name gets ``self.player.name``.

        :param msg: websocket message whose ``data`` holds a serialized ``sc_pb.Request``.
        """
        request = sc_pb.Request()
        request.ParseFromString(msg.data)
        if request.HasField("quit"):
            request = sc_pb.Request(leave_game=sc_pb.RequestLeaveGame())
        if request.HasField("leave_game"):
            if self.controller._status == Status.in_game:
                logger.info(f"Proxy: player {self.player.name}({self.player_id}) surrenders")
                self.result = {self.player_id: Result.Defeat}
            elif self.controller._status == Status.ended:
                await self.get_response()
        elif request.HasField("join_game") and not request.join_game.HasField("player_name"):
            request.join_game.player_name = self.player.name
        await self.controller._ws.send_bytes(request.SerializeToString())

    # TODO Catching too general exception Exception (broad-except)
    # pylint: disable=W0703
    async def get_response(self):
        """Read one binary message from the SC2 websocket.

        All errors are logged and swallowed. If the first receive is cancelled,
        a second receive is attempted once.

        :return: the received bytes, or None if nothing could be received.
        """
        response_bytes = None
        try:
            response_bytes = await self.controller._ws.receive_bytes()
        except TypeError as e:
            logger.exception("Cannot receive: SC2 Connection already closed.")
            tb = traceback.format_exc()
            logger.error(f"Exception {e}: {tb}")
        except asyncio.CancelledError:
            logger.info(f"Proxy({self.player.name}), caught receive from sc2")
            try:
                retry_bytes = await self.controller._ws.receive_bytes()
                if response_bytes is None:
                    response_bytes = retry_bytes
            # CancelledError is listed explicitly because it does not derive from
            # Exception on Python 3.8+; asyncio.TimeoutError is already an Exception.
            except (asyncio.CancelledError, Exception) as e:
                logger.exception(f"Exception {e}")
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
        return response_bytes

    async def parse_response(self, response_bytes):
        """Parse a raw SC2 response and update controller status, player id and result.

        - Mirrors ``response.status`` into ``self.controller._status``.
        - Picks up ``self.player_id`` from the first ``join_game`` response.
        - While no result is known: takes it from ``observation.player_result`` if
          present, otherwise declares a tie for players 1 and 2 once the game loop
          exceeds ``self.timeout_loop`` and sends a chat message announcing it.

        :param response_bytes: serialized ``sc_pb.Response``.
        :return: the parsed ``sc_pb.Response``.
        """
        response = sc_pb.Response()
        response.ParseFromString(response_bytes)

        if not response.HasField("status"):
            # Was a plain string before, so the response was never interpolated.
            logger.critical(f"Proxy: RESPONSE HAS NO STATUS {response}")
        else:
            new_status = Status(response.status)
            if new_status != self.controller._status:
                logger.info(f"Controller({self.player.name}): {self.controller._status}->{new_status}")
                self.controller._status = new_status

        if self.player_id is None and response.HasField("join_game"):
            self.player_id = response.join_game.player_id
            logger.info(f"Proxy({self.player.name}): got join_game for {self.player_id}")

        if self.result is None and response.HasField("observation"):
            obs: sc_pb.ResponseObservation = response.observation
            if obs.player_result:
                self.result = {pr.player_id: Result(pr.result) for pr in obs.player_result}
            elif (
                self.timeout_loop and obs.HasField("observation") and obs.observation.game_loop > self.timeout_loop
            ):
                self.result = {i: Result.Tie for i in range(1, 3)}
                logger.info(f"Proxy({self.player.name}) timing out")
                act = [sc_pb.Action(action_chat=sc_pb.ActionChat(message="Proxy: Timing out"))]
                await self.controller._execute(action=sc_pb.RequestAction(actions=act))
        return response

    async def get_result(self) -> None:
        """Ask SC2 directly for the game result and store it in ``self.result``.

        Pings the controller; if the status is in_game, in_replay or ended, requests
        an observation and reads ``player_result`` from it. Errors are logged and
        swallowed, leaving ``self.result`` unchanged.
        """
        try:
            res = await self.controller.ping()
            if res.status in {Status.in_game, Status.in_replay, Status.ended}:
                res = await self.controller._execute(observation=sc_pb.RequestObservation())
                if res.HasField("observation") and res.observation.player_result:
                    self.result = {pr.player_id: Result(pr.result) for pr in res.observation.player_result}
        # pylint: disable=W0703
        # TODO Catching too general exception Exception (broad-except)
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")

    async def proxy_handler(self, request):
        """aiohttp handler for the bot's websocket connection.

        For every binary message from the bot: forward it to SC2 (``parse_request``),
        read SC2's answer (``get_response``), post-process it (``parse_response``)
        and send it back to the bot. On exit the game is left if still running,
        the bot websocket is closed and ``self.done`` is set.

        :param request: the incoming aiohttp request to upgrade to a websocket.
        :return: the ``web.WebSocketResponse`` used for the bot connection.
        """
        bot_ws = web.WebSocketResponse(receive_timeout=30)
        await bot_ws.prepare(request)
        try:
            async for msg in bot_ws:
                if msg.data is None:
                    raise TypeError(f"data is None, {msg}")
                if msg.data and msg.type == WSMsgType.BINARY:

                    await self.parse_request(msg)

                    response_bytes = await self.get_response()
                    if response_bytes is None:
                        raise ConnectionError("Could not get response_bytes")

                    new_response = await self.parse_response(response_bytes)
                    await bot_ws.send_bytes(new_response.SerializeToString())

                elif msg.type == WSMsgType.CLOSED:
                    logger.error("Client shutdown")
                else:
                    logger.error("Incorrect message type")
        # pylint: disable=W0703
        # TODO Catching too general exception Exception (broad-except)
        except Exception as e:
            logger.exception(f"Caught unknown exception: {e}")
            # Connection drops and cancellations are expected; anything else gets
            # an additional traceback line tagged with the player name.
            if not isinstance(e, (ConnectionError, asyncio.CancelledError)):
                tb = traceback.format_exc()
                logger.info(f"Proxy({self.player.name}): Caught {e} traceback: {tb}")
        finally:
            try:
                if self.controller._status in {Status.in_game, Status.in_replay}:
                    await self.controller._execute(leave_game=sc_pb.RequestLeaveGame())
                await bot_ws.close()
            # pylint: disable=W0703
            # TODO Catching too general exception Exception (broad-except)
            except Exception as e:
                logger.exception(f"Caught unknown exception during surrender: {e}")
            self.done = True
        return bot_ws

    def _own_result(self):
        """Return this player's entry of ``self.result`` (None if unknown or missing)."""
        if isinstance(self.result, dict):
            # .get() covers both "player_id is still None" and "id not in the result"
            # without inserting a placeholder key into the stored result.
            return self.result.get(self.player_id)
        return self.result

    # pylint: disable=R0912
    async def play_with_proxy(self, startport):
        """Run the proxy server, start the bot process and wait for the game outcome.

        Serves ``proxy_handler`` on ``/sc2api`` at the controller's host and
        ``self.port``, launches the bot with ``self.player.cmd_line(...)``, then polls
        every 5 seconds until a result is known, the handler is done, or the bot or
        SC2 has died. Afterwards the bot process is terminated and the web app is
        cleaned up.

        :param startport: forwarded to ``self.player.cmd_line``.
        :return: this player's ``Result``, or None if no result could be determined.
        """
        logger.info(f"Proxy({self.port}): Starting app")
        app = web.Application()
        app.router.add_route("GET", "/sc2api", self.proxy_handler)
        apprunner = web.AppRunner(app, access_log=None)
        await apprunner.setup()
        appsite = web.TCPSite(apprunner, self.controller._process._host, self.port)
        await appsite.start()

        subproc_args = {"cwd": str(self.player.path), "stderr": subprocess.STDOUT}
        # Start the bot in its own process group on the platforms that support it.
        if platform.system() == "Linux":
            subproc_args["preexec_fn"] = os.setpgrp
        elif platform.system() == "Windows":
            subproc_args["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP

        player_command_line = self.player.cmd_line(self.port, startport, self.controller._process._host, self.realtime)
        logger.info(f"Starting bot with command: {' '.join(player_command_line)}")
        if self.player.stdout is None:
            bot_process = subprocess.Popen(player_command_line, stdout=subprocess.DEVNULL, **subproc_args)
        else:
            with open(self.player.stdout, "w+") as out:
                bot_process = subprocess.Popen(player_command_line, stdout=out, **subproc_args)

        while self.result is None:
            bot_alive = bot_process and bot_process.poll() is None
            sc2_alive = self.controller.running
            if self.done or not (bot_alive and sc2_alive):
                logger.info(
                    f"Proxy({self.port}): {self.player.name} died, "
                    f"bot{(not bot_alive) * ' not'} alive, sc2{(not sc2_alive) * ' not'} alive"
                )
                # Maybe its still possible to retrieve a result
                # NOTE(review): get_response() only reads bytes and discards them here; it
                # never sets self.result. get_result() looks like the intended call -
                # confirm before changing.
                if sc2_alive and not self.done:
                    await self.get_response()
                logger.info(f"Proxy({self.port}): breaking, result {self.result}")
                break
            await asyncio.sleep(5)

        # cleanup
        logger.info(f"({self.port}): cleaning up {self.player!r}")
        for _i in range(3):
            if isinstance(bot_process, subprocess.Popen):
                if bot_process.stdout and not bot_process.stdout.closed:  # should not run anymore
                    logger.info(f"==================output for player {self.player.name}")
                    for line in bot_process.stdout.readlines():
                        logger.opt(raw=True).info(line.decode("utf-8"))
                    bot_process.stdout.close()
                    logger.info("==================")
                bot_process.terminate()
                bot_process.wait()
            # Non-blocking pause: time.sleep() here would stall the whole event loop.
            await asyncio.sleep(0.5)
            if not bot_process or bot_process.poll() is not None:
                break
        else:
            bot_process.terminate()
            bot_process.wait()
        try:
            await apprunner.cleanup()
        # pylint: disable=W0703
        # TODO Catching too general exception Exception (broad-except)
        except Exception as e:
            logger.exception(f"Caught unknown exception during cleaning: {e}")
        return self._own_result()
|