WebHost: use a limited process pool to run Rooms (#3214)

This commit is contained in:
Fabian Dill 2024-05-17 12:21:01 +02:00 committed by GitHub
parent 3dbdd048cd
commit 7900e4c9a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 170 additions and 143 deletions

View File

@ -175,11 +175,13 @@ class Context:
all_item_and_group_names: typing.Dict[str, typing.Set[str]] all_item_and_group_names: typing.Dict[str, typing.Set[str]]
all_location_and_group_names: typing.Dict[str, typing.Set[str]] all_location_and_group_names: typing.Dict[str, typing.Set[str]]
non_hintable_names: typing.Dict[str, typing.Set[str]] non_hintable_names: typing.Dict[str, typing.Set[str]]
logger: logging.Logger
def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int, def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int,
hint_cost: int, item_cheat: bool, release_mode: str = "disabled", collect_mode="disabled", hint_cost: int, item_cheat: bool, release_mode: str = "disabled", collect_mode="disabled",
remaining_mode: str = "disabled", auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2, remaining_mode: str = "disabled", auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2,
log_network: bool = False): log_network: bool = False, logger: logging.Logger = logging.getLogger()):
self.logger = logger
super(Context, self).__init__() super(Context, self).__init__()
self.slot_info = {} self.slot_info = {}
self.log_network = log_network self.log_network = log_network
@ -287,12 +289,12 @@ class Context:
try: try:
await endpoint.socket.send(msg) await endpoint.socket.send(msg)
except websockets.ConnectionClosed: except websockets.ConnectionClosed:
logging.exception(f"Exception during send_msgs, could not send {msg}") self.logger.exception(f"Exception during send_msgs, could not send {msg}")
await self.disconnect(endpoint) await self.disconnect(endpoint)
return False return False
else: else:
if self.log_network: if self.log_network:
logging.info(f"Outgoing message: {msg}") self.logger.info(f"Outgoing message: {msg}")
return True return True
async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool: async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
@ -301,12 +303,12 @@ class Context:
try: try:
await endpoint.socket.send(msg) await endpoint.socket.send(msg)
except websockets.ConnectionClosed: except websockets.ConnectionClosed:
logging.exception("Exception during send_encoded_msgs") self.logger.exception("Exception during send_encoded_msgs")
await self.disconnect(endpoint) await self.disconnect(endpoint)
return False return False
else: else:
if self.log_network: if self.log_network:
logging.info(f"Outgoing message: {msg}") self.logger.info(f"Outgoing message: {msg}")
return True return True
async def broadcast_send_encoded_msgs(self, endpoints: typing.Iterable[Endpoint], msg: str) -> bool: async def broadcast_send_encoded_msgs(self, endpoints: typing.Iterable[Endpoint], msg: str) -> bool:
@ -317,11 +319,11 @@ class Context:
try: try:
websockets.broadcast(sockets, msg) websockets.broadcast(sockets, msg)
except RuntimeError: except RuntimeError:
logging.exception("Exception during broadcast_send_encoded_msgs") self.logger.exception("Exception during broadcast_send_encoded_msgs")
return False return False
else: else:
if self.log_network: if self.log_network:
logging.info(f"Outgoing broadcast: {msg}") self.logger.info(f"Outgoing broadcast: {msg}")
return True return True
def broadcast_all(self, msgs: typing.List[dict]): def broadcast_all(self, msgs: typing.List[dict]):
@ -330,7 +332,7 @@ class Context:
async_start(self.broadcast_send_encoded_msgs(endpoints, msgs)) async_start(self.broadcast_send_encoded_msgs(endpoints, msgs))
def broadcast_text_all(self, text: str, additional_arguments: dict = {}): def broadcast_text_all(self, text: str, additional_arguments: dict = {}):
logging.info("Notice (all): %s" % text) self.logger.info("Notice (all): %s" % text)
self.broadcast_all([{**{"cmd": "PrintJSON", "data": [{ "text": text }]}, **additional_arguments}]) self.broadcast_all([{**{"cmd": "PrintJSON", "data": [{ "text": text }]}, **additional_arguments}])
def broadcast_team(self, team: int, msgs: typing.List[dict]): def broadcast_team(self, team: int, msgs: typing.List[dict]):
@ -352,7 +354,7 @@ class Context:
def notify_client(self, client: Client, text: str, additional_arguments: dict = {}): def notify_client(self, client: Client, text: str, additional_arguments: dict = {}):
if not client.auth: if not client.auth:
return return
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text)) self.logger.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
async_start(self.send_msgs(client, [{"cmd": "PrintJSON", "data": [{ "text": text }], **additional_arguments}])) async_start(self.send_msgs(client, [{"cmd": "PrintJSON", "data": [{ "text": text }], **additional_arguments}]))
def notify_client_multiple(self, client: Client, texts: typing.List[str], additional_arguments: dict = {}): def notify_client_multiple(self, client: Client, texts: typing.List[str], additional_arguments: dict = {}):
@ -451,7 +453,7 @@ class Context:
for game_name, data in decoded_obj.get("datapackage", {}).items(): for game_name, data in decoded_obj.get("datapackage", {}).items():
if game_name in game_data_packages: if game_name in game_data_packages:
data = game_data_packages[game_name] data = game_data_packages[game_name]
logging.info(f"Loading embedded data package for game {game_name}") self.logger.info(f"Loading embedded data package for game {game_name}")
self.gamespackage[game_name] = data self.gamespackage[game_name] = data
self.item_name_groups[game_name] = data["item_name_groups"] self.item_name_groups[game_name] = data["item_name_groups"]
if "location_name_groups" in data: if "location_name_groups" in data:
@ -483,7 +485,7 @@ class Context:
with open(self.save_filename, "wb") as f: with open(self.save_filename, "wb") as f:
f.write(zlib.compress(encoded_save)) f.write(zlib.compress(encoded_save))
except Exception as e: except Exception as e:
logging.exception(e) self.logger.exception(e)
return False return False
else: else:
return True return True
@ -501,9 +503,9 @@ class Context:
save_data = restricted_loads(zlib.decompress(f.read())) save_data = restricted_loads(zlib.decompress(f.read()))
self.set_save(save_data) self.set_save(save_data)
except FileNotFoundError: except FileNotFoundError:
logging.error('No save data found, starting a new game') self.logger.error('No save data found, starting a new game')
except Exception as e: except Exception as e:
logging.exception(e) self.logger.exception(e)
self._start_async_saving() self._start_async_saving()
def _start_async_saving(self): def _start_async_saving(self):
@ -520,11 +522,11 @@ class Context:
next_wakeup = (second - get_datetime_second()) % self.auto_save_interval next_wakeup = (second - get_datetime_second()) % self.auto_save_interval
time.sleep(max(1.0, next_wakeup)) time.sleep(max(1.0, next_wakeup))
if self.save_dirty: if self.save_dirty:
logging.debug("Saving via thread.") self.logger.debug("Saving via thread.")
self._save() self._save()
except OperationalError as e: except OperationalError as e:
logging.exception(e) self.logger.exception(e)
logging.info(f"Saving failed. Retry in {self.auto_save_interval} seconds.") self.logger.info(f"Saving failed. Retry in {self.auto_save_interval} seconds.")
else: else:
self.save_dirty = False self.save_dirty = False
self.auto_saver_thread = threading.Thread(target=save_regularly, daemon=True) self.auto_saver_thread = threading.Thread(target=save_regularly, daemon=True)
@ -598,7 +600,7 @@ class Context:
if "stored_data" in savedata: if "stored_data" in savedata:
self.stored_data = savedata["stored_data"] self.stored_data = savedata["stored_data"]
# count items and slots from lists for items_handling = remote # count items and slots from lists for items_handling = remote
logging.info( self.logger.info(
f'Loaded save file with {sum([len(v) for k, v in self.received_items.items() if k[2]])} received items ' f'Loaded save file with {sum([len(v) for k, v in self.received_items.items() if k[2]])} received items '
f'for {sum(k[2] for k in self.received_items)} players') f'for {sum(k[2] for k in self.received_items)} players')
@ -640,13 +642,13 @@ class Context:
try: try:
raise Exception(f"Could not set server option {key}, skipping.") from e raise Exception(f"Could not set server option {key}, skipping.") from e
except Exception as e: except Exception as e:
logging.exception(e) self.logger.exception(e)
logging.debug(f"Setting server option {key} to {value} from supplied multidata") self.logger.debug(f"Setting server option {key} to {value} from supplied multidata")
setattr(self, key, value) setattr(self, key, value)
elif key == "disable_item_cheat": elif key == "disable_item_cheat":
self.item_cheat = not bool(value) self.item_cheat = not bool(value)
else: else:
logging.debug(f"Unrecognized server option {key}") self.logger.debug(f"Unrecognized server option {key}")
def get_aliased_name(self, team: int, slot: int): def get_aliased_name(self, team: int, slot: int):
if (team, slot) in self.name_aliases: if (team, slot) in self.name_aliases:
@ -680,7 +682,7 @@ class Context:
self.hints[team, player].add(hint) self.hints[team, player].add(hint)
new_hint_events.add(player) new_hint_events.add(player)
logging.info("Notice (Team #%d): %s" % (team + 1, format_hint(self, team, hint))) self.logger.info("Notice (Team #%d): %s" % (team + 1, format_hint(self, team, hint)))
for slot in new_hint_events: for slot in new_hint_events:
self.on_new_hint(team, slot) self.on_new_hint(team, slot)
for slot, hint_data in concerns.items(): for slot, hint_data in concerns.items():
@ -739,21 +741,21 @@ async def server(websocket, path: str = "/", ctx: Context = None):
try: try:
if ctx.log_network: if ctx.log_network:
logging.info("Incoming connection") ctx.logger.info("Incoming connection")
await on_client_connected(ctx, client) await on_client_connected(ctx, client)
if ctx.log_network: if ctx.log_network:
logging.info("Sent Room Info") ctx.logger.info("Sent Room Info")
async for data in websocket: async for data in websocket:
if ctx.log_network: if ctx.log_network:
logging.info(f"Incoming message: {data}") ctx.logger.info(f"Incoming message: {data}")
for msg in decode(data): for msg in decode(data):
await process_client_cmd(ctx, client, msg) await process_client_cmd(ctx, client, msg)
except Exception as e: except Exception as e:
if not isinstance(e, websockets.WebSocketException): if not isinstance(e, websockets.WebSocketException):
logging.exception(e) ctx.logger.exception(e)
finally: finally:
if ctx.log_network: if ctx.log_network:
logging.info("Disconnected") ctx.logger.info("Disconnected")
await ctx.disconnect(client) await ctx.disconnect(client)
@ -985,7 +987,7 @@ def register_location_checks(ctx: Context, team: int, slot: int, locations: typi
new_item = NetworkItem(item_id, location, slot, flags) new_item = NetworkItem(item_id, location, slot, flags)
send_items_to(ctx, team, target_player, new_item) send_items_to(ctx, team, target_player, new_item)
logging.info('(Team #%d) %s sent %s to %s (%s)' % ( ctx.logger.info('(Team #%d) %s sent %s to %s (%s)' % (
team + 1, ctx.player_names[(team, slot)], ctx.item_names[item_id], team + 1, ctx.player_names[(team, slot)], ctx.item_names[item_id],
ctx.player_names[(team, target_player)], ctx.location_names[location])) ctx.player_names[(team, target_player)], ctx.location_names[location]))
info_text = json_format_send_event(new_item, target_player) info_text = json_format_send_event(new_item, target_player)
@ -1625,7 +1627,7 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
try: try:
cmd: str = args["cmd"] cmd: str = args["cmd"]
except: except:
logging.exception(f"Could not get command from {args}") ctx.logger.exception(f"Could not get command from {args}")
await ctx.send_msgs(client, [{'cmd': 'InvalidPacket', "type": "cmd", "original_cmd": None, await ctx.send_msgs(client, [{'cmd': 'InvalidPacket', "type": "cmd", "original_cmd": None,
"text": f"Could not get command from {args} at `cmd`"}]) "text": f"Could not get command from {args} at `cmd`"}])
raise raise
@ -1668,7 +1670,7 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict):
if ctx.compatibility == 0 and args['version'] != version_tuple: if ctx.compatibility == 0 and args['version'] != version_tuple:
errors.add('IncompatibleVersion') errors.add('IncompatibleVersion')
if errors: if errors:
logging.info(f"A client connection was refused due to: {errors}, the sent connect information was {args}.") ctx.logger.info(f"A client connection was refused due to: {errors}, the sent connect information was {args}.")
await ctx.send_msgs(client, [{"cmd": "ConnectionRefused", "errors": list(errors)}]) await ctx.send_msgs(client, [{"cmd": "ConnectionRefused", "errors": list(errors)}])
else: else:
team, slot = ctx.connect_names[args['name']] team, slot = ctx.connect_names[args['name']]
@ -2286,7 +2288,7 @@ async def auto_shutdown(ctx, to_cancel=None):
if to_cancel: if to_cancel:
for task in to_cancel: for task in to_cancel:
task.cancel() task.cancel()
logging.info("Shutting down due to inactivity.") ctx.logger.info("Shutting down due to inactivity.")
while not ctx.exit_event.is_set(): while not ctx.exit_event.is_set():
if not ctx.client_activity_timers.values(): if not ctx.client_activity_timers.values():

View File

@ -23,6 +23,7 @@ app.jinja_env.filters['all'] = all
app.config["SELFHOST"] = True # application process is in charge of running the websites app.config["SELFHOST"] = True # application process is in charge of running the websites
app.config["GENERATORS"] = 8 # maximum concurrent world gens app.config["GENERATORS"] = 8 # maximum concurrent world gens
app.config["HOSTERS"] = 8 # maximum concurrent room hosters
app.config["SELFLAUNCH"] = True # application process is in charge of launching Rooms. app.config["SELFLAUNCH"] = True # application process is in charge of launching Rooms.
app.config["SELFLAUNCHCERT"] = None # can point to a SSL Certificate to encrypt Room websocket connections app.config["SELFLAUNCHCERT"] = None # can point to a SSL Certificate to encrypt Room websocket connections
app.config["SELFLAUNCHKEY"] = None # can point to a SSL Certificate Key to encrypt Room websocket connections app.config["SELFLAUNCHKEY"] = None # can point to a SSL Certificate Key to encrypt Room websocket connections

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import json import json
import logging import logging
import multiprocessing import multiprocessing
import threading
import time import time
import typing import typing
from uuid import UUID from uuid import UUID
@ -15,16 +14,6 @@ from Utils import restricted_loads
from .locker import Locker, AlreadyRunningException from .locker import Locker, AlreadyRunningException
def launch_room(room: Room, config: dict):
# requires db_session!
if room.last_activity >= datetime.utcnow() - timedelta(seconds=room.timeout):
multiworld = multiworlds.get(room.id, None)
if not multiworld:
multiworld = MultiworldInstance(room, config)
multiworld.start()
def handle_generation_success(seed_id): def handle_generation_success(seed_id):
logging.info(f"Generation finished for seed {seed_id}") logging.info(f"Generation finished for seed {seed_id}")
@ -59,11 +48,8 @@ def init_db(pony_config: dict):
db.generate_mapping() db.generate_mapping()
def autohost(config: dict): def cleanup():
def keep_running(): """delete unowned user-content"""
try:
with Locker("autohost"):
# delete unowned user-content
with db_session: with db_session:
# >>> bool(uuid.UUID(int=0)) # >>> bool(uuid.UUID(int=0))
# True # True
@ -73,7 +59,19 @@ def autohost(config: dict):
# Command gets deleted by ponyorm Cascade Delete, as Room is Required # Command gets deleted by ponyorm Cascade Delete, as Room is Required
if rooms or seeds or slots: if rooms or seeds or slots:
logging.info(f"{rooms} Rooms, {seeds} Seeds and {slots} Slots have been deleted.") logging.info(f"{rooms} Rooms, {seeds} Seeds and {slots} Slots have been deleted.")
run_guardian()
def autohost(config: dict):
def keep_running():
try:
with Locker("autohost"):
cleanup()
hosters = []
for x in range(config["HOSTERS"]):
hoster = MultiworldInstance(config, x)
hosters.append(hoster)
hoster.start()
while 1: while 1:
time.sleep(0.1) time.sleep(0.1)
with db_session: with db_session:
@ -81,7 +79,9 @@ def autohost(config: dict):
room for room in Room if room for room in Room if
room.last_activity >= datetime.utcnow() - timedelta(days=3)) room.last_activity >= datetime.utcnow() - timedelta(days=3))
for room in rooms: for room in rooms:
launch_room(room, config) # we have to filter twice, as the per-room timeout can't currently be PonyORM transpiled.
if room.last_activity >= datetime.utcnow() - timedelta(seconds=room.timeout):
hosters[room.id.int % len(hosters)].start_room(room.id)
except AlreadyRunningException: except AlreadyRunningException:
logging.info("Autohost reports as already running, not starting another.") logging.info("Autohost reports as already running, not starting another.")
@ -132,29 +132,38 @@ multiworlds: typing.Dict[type(Room.id), MultiworldInstance] = {}
class MultiworldInstance(): class MultiworldInstance():
def __init__(self, room: Room, config: dict): def __init__(self, config: dict, id: int):
self.room_id = room.id self.room_ids = set()
self.process: typing.Optional[multiprocessing.Process] = None self.process: typing.Optional[multiprocessing.Process] = None
with guardian_lock:
multiworlds[self.room_id] = self
self.ponyconfig = config["PONY"] self.ponyconfig = config["PONY"]
self.cert = config["SELFLAUNCHCERT"] self.cert = config["SELFLAUNCHCERT"]
self.key = config["SELFLAUNCHKEY"] self.key = config["SELFLAUNCHKEY"]
self.host = config["HOST_ADDRESS"] self.host = config["HOST_ADDRESS"]
self.rooms_to_start = multiprocessing.Queue()
self.rooms_shutting_down = multiprocessing.Queue()
self.name = f"MultiHoster{id}"
def start(self): def start(self):
if self.process and self.process.is_alive(): if self.process and self.process.is_alive():
return False return False
logging.info(f"Spinning up {self.room_id}")
process = multiprocessing.Process(group=None, target=run_server_process, process = multiprocessing.Process(group=None, target=run_server_process,
args=(self.room_id, self.ponyconfig, get_static_server_data(), args=(self.name, self.ponyconfig, get_static_server_data(),
self.cert, self.key, self.host), self.cert, self.key, self.host,
name="MultiHost") self.rooms_to_start, self.rooms_shutting_down),
name=self.name)
process.start() process.start()
# bind after start to prevent thread sync issues with guardian.
self.process = process self.process = process
def start_room(self, room_id):
while not self.rooms_shutting_down.empty():
self.room_ids.remove(self.rooms_shutting_down.get(block=True, timeout=None))
if room_id in self.room_ids:
pass # should already be hosted currently.
else:
self.room_ids.add(room_id)
self.rooms_to_start.put(room_id)
def stop(self): def stop(self):
if self.process: if self.process:
self.process.terminate() self.process.terminate()
@ -168,40 +177,6 @@ class MultiworldInstance():
self.process = None self.process = None
guardian = None
guardian_lock = threading.Lock()
def run_guardian():
global guardian
global multiworlds
with guardian_lock:
if not guardian:
try:
import resource
except ModuleNotFoundError:
pass # unix only module
else:
# Each Server is another file handle, so request as many as we can from the system
file_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
# set soft limit to hard limit
resource.setrlimit(resource.RLIMIT_NOFILE, (file_limit, file_limit))
def guard():
while 1:
time.sleep(1)
done = []
with guardian_lock:
for key, instance in multiworlds.items():
if instance.done():
instance.collect()
done.append(key)
for key in done:
del (multiworlds[key])
guardian = threading.Thread(name="Guardian", target=guard)
from .models import Room, Generation, STATE_QUEUED, STATE_STARTED, STATE_ERROR, db, Seed, Slot from .models import Room, Generation, STATE_QUEUED, STATE_STARTED, STATE_ERROR, db, Seed, Slot
from .customserver import run_server_process, get_static_server_data from .customserver import run_server_process, get_static_server_data
from .generate import gen_game from .generate import gen_game

View File

@ -5,6 +5,7 @@ import collections
import datetime import datetime
import functools import functools
import logging import logging
import multiprocessing
import pickle import pickle
import random import random
import socket import socket
@ -53,17 +54,19 @@ del MultiServer
class DBCommandProcessor(ServerCommandProcessor): class DBCommandProcessor(ServerCommandProcessor):
def output(self, text: str): def output(self, text: str):
logging.info(text) self.ctx.logger.info(text)
class WebHostContext(Context): class WebHostContext(Context):
room_id: int room_id: int
def __init__(self, static_server_data: dict): def __init__(self, static_server_data: dict, logger: logging.Logger):
# static server data is used during _load_game_data to load required data, # static server data is used during _load_game_data to load required data,
# without needing to import worlds system, which takes quite a bit of memory # without needing to import worlds system, which takes quite a bit of memory
self.static_server_data = static_server_data self.static_server_data = static_server_data
super(WebHostContext, self).__init__("", 0, "", "", 1, 40, True, "enabled", "enabled", "enabled", 0, 2) super(WebHostContext, self).__init__("", 0, "", "", 1,
40, True, "enabled", "enabled",
"enabled", 0, 2, logger=logger)
del self.static_server_data del self.static_server_data
self.main_loop = asyncio.get_running_loop() self.main_loop = asyncio.get_running_loop()
self.video = {} self.video = {}
@ -159,24 +162,61 @@ def get_static_server_data() -> dict:
return data return data
def run_server_process(room_id, ponyconfig: dict, static_server_data: dict, def set_up_logging(room_id) -> logging.Logger:
import os
# logger setup
logger = logging.getLogger(f"RoomLogger {room_id}")
# this *should* be empty, but just in case.
for handler in logger.handlers[:]:
logger.removeHandler(handler)
handler.close()
file_handler = logging.FileHandler(
os.path.join(Utils.user_path("logs"), f"{room_id}.txt"),
"a",
encoding="utf-8-sig")
file_handler.setFormatter(logging.Formatter("[%(asctime)s]: %(message)s"))
logger.setLevel(logging.INFO)
logger.addHandler(file_handler)
return logger
def run_server_process(name: str, ponyconfig: dict, static_server_data: dict,
cert_file: typing.Optional[str], cert_key_file: typing.Optional[str], cert_file: typing.Optional[str], cert_key_file: typing.Optional[str],
host: str): host: str, rooms_to_run: multiprocessing.Queue, rooms_shutting_down: multiprocessing.Queue):
Utils.init_logging(name)
try:
import resource
except ModuleNotFoundError:
pass # unix only module
else:
# Each Server is another file handle, so request as many as we can from the system
file_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
# set soft limit to hard limit
resource.setrlimit(resource.RLIMIT_NOFILE, (file_limit, file_limit))
del resource, file_limit
# establish DB connection for multidata and multisave # establish DB connection for multidata and multisave
db.bind(**ponyconfig) db.bind(**ponyconfig)
db.generate_mapping(check_tables=False) db.generate_mapping(check_tables=False)
async def main():
if "worlds" in sys.modules: if "worlds" in sys.modules:
raise Exception("Worlds system should not be loaded in the custom server.") raise Exception("Worlds system should not be loaded in the custom server.")
import gc import gc
Utils.init_logging(str(room_id), write_mode="a") ssl_context = load_server_cert(cert_file, cert_key_file) if cert_file else None
ctx = WebHostContext(static_server_data) del cert_file, cert_key_file, ponyconfig
gc.collect() # free intermediate objects used during setup
loop = asyncio.get_event_loop()
async def start_room(room_id):
try:
logger = set_up_logging(room_id)
ctx = WebHostContext(static_server_data, logger)
ctx.load(room_id) ctx.load(room_id)
ctx.init_save() ctx.init_save()
ssl_context = load_server_cert(cert_file, cert_key_file) if cert_file else None
gc.collect() # free intermediate objects used during setup
try: try:
ctx.server = websockets.serve(functools.partial(server, ctx=ctx), ctx.host, ctx.port, ssl=ssl_context) ctx.server = websockets.serve(functools.partial(server, ctx=ctx), ctx.host, ctx.port, ssl=ssl_context)
@ -195,12 +235,12 @@ def run_server_process(room_id, ponyconfig: dict, static_server_data: dict,
elif wssocket.family == socket.AF_INET: elif wssocket.family == socket.AF_INET:
port = socketname[1] port = socketname[1]
if port: if port:
logging.info(f'Hosting game at {host}:{port}') ctx.logger.info(f'Hosting game at {host}:{port}')
with db_session: with db_session:
room = Room.get(id=ctx.room_id) room = Room.get(id=ctx.room_id)
room.last_port = port room.last_port = port
else: else:
logging.exception("Could not determine port. Likely hosting failure.") ctx.logger.exception("Could not determine port. Likely hosting failure.")
with db_session: with db_session:
ctx.auto_shutdown = Room.get(id=room_id).timeout ctx.auto_shutdown = Room.get(id=room_id).timeout
ctx.shutdown_task = asyncio.create_task(auto_shutdown(ctx, [])) ctx.shutdown_task = asyncio.create_task(auto_shutdown(ctx, []))
@ -211,11 +251,6 @@ def run_server_process(room_id, ponyconfig: dict, static_server_data: dict,
room: Room = Room.get(id=ctx.room_id) room: Room = Room.get(id=ctx.room_id)
room.last_activity = datetime.datetime.utcnow() - datetime.timedelta(seconds=room.timeout + 60) room.last_activity = datetime.datetime.utcnow() - datetime.timedelta(seconds=room.timeout + 60)
logging.info("Shutting down")
with Locker(room_id):
try:
asyncio.run(main())
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit):
with db_session: with db_session:
room = Room.get(id=room_id) room = Room.get(id=room_id)
@ -228,3 +263,17 @@ def run_server_process(room_id, ponyconfig: dict, static_server_data: dict,
# ensure the Room does not spin up again on its own, minute of safety buffer # ensure the Room does not spin up again on its own, minute of safety buffer
room.last_activity = datetime.datetime.utcnow() - datetime.timedelta(minutes=1, seconds=room.timeout) room.last_activity = datetime.datetime.utcnow() - datetime.timedelta(minutes=1, seconds=room.timeout)
raise raise
finally:
rooms_shutting_down.put(room_id)
class Starter(threading.Thread):
def run(self):
while 1:
next_room = rooms_to_run.get(block=True, timeout=None)
asyncio.run_coroutine_threadsafe(start_room(next_room), loop)
logging.info(f"Starting room {next_room} on {name}.")
starter = Starter()
starter.daemon = True
starter.start()
loop.run_forever()