From f4f043ac874c7b2aa508b4f7f39c5fcfd745f585 Mon Sep 17 00:00:00 2001 From: Fabian Dill Date: Thu, 26 Aug 2021 16:19:37 +0200 Subject: [PATCH] MultiServer: categorize methods --- Main.py | 1 - MultiServer.py | 183 ++++++++++++++++++++++++++++++------------------- NetUtils.py | 48 ------------- 3 files changed, 114 insertions(+), 118 deletions(-) diff --git a/Main.py b/Main.py index 3d3cb84b..ea408cd5 100644 --- a/Main.py +++ b/Main.py @@ -377,7 +377,6 @@ def main(args, seed=None): f.write(bytes([1])) # version of format f.write(multidata) - multidata_task = pool.submit(write_multidata) if not check_accessibility_task.result(): if not world.can_beat_game(): diff --git a/MultiServer.py b/MultiServer.py index 38cda876..aa760da6 100644 --- a/MultiServer.py +++ b/MultiServer.py @@ -31,10 +31,11 @@ from worlds import network_data_package, lookup_any_item_id_to_name, lookup_any_ import Utils from Utils import get_item_name_from_id, get_location_name_from_id, \ version_tuple, restricted_loads, Version -from NetUtils import Node, Endpoint, ClientStatus, NetworkItem, decode, NetworkPlayer +from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer colorama.init() + class Client(Endpoint): version = Version(0, 0, 0) tags: typing.List[str] = [] @@ -50,9 +51,14 @@ class Client(Endpoint): self.messageprocessor = client_message_processor(ctx, self) self.ctx = weakref.ref(ctx) + team_slot = typing.Tuple[int, int] -class Context(Node): + +class Context: + dumper = staticmethod(encode) + loader = staticmethod(decode) + simple_options = {"hint_cost": int, "location_check_points": int, "server_password": str, @@ -64,8 +70,10 @@ class Context(Node): def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int, hint_cost: int, item_cheat: bool, forfeit_mode: str = "disabled", remaining_mode: str = "disabled", - auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2): + auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2, log_network: bool = False): super(Context, self).__init__() + self.log_network = log_network + self.endpoints = [] self.compatibility: int = compatibility self.shutdown_task = None self.data_filename = None @@ -113,10 +121,70 @@ class Context(Node): self.seed_name = "" self.random = random.Random() - def get_hint_cost(self, slot): - if self.hint_cost: - return max(0, int(self.hint_cost * 0.01 * len(self.locations[slot]))) - return 0 + # General networking + + async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool: + if not endpoint.socket or not endpoint.socket.open: + return False + msg = self.dumper(msgs) + try: + await endpoint.socket.send(msg) + except websockets.ConnectionClosed: + logging.exception(f"Exception during send_msgs, could not send {msg}") + await self.disconnect(endpoint) + else: + if self.log_network: + logging.info(f"Outgoing message: {msg}") + return True + + async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool: + if not endpoint.socket or not endpoint.socket.open: + return False + try: + await endpoint.socket.send(msg) + except websockets.ConnectionClosed: + logging.exception("Exception during send_encoded_msgs") + await self.disconnect(endpoint) + else: + if self.log_network: + logging.info(f"Outgoing message: {msg}") + return True + + def broadcast_all(self, msgs): + msgs = self.dumper(msgs) + for endpoint in self.endpoints: + if endpoint.auth: + asyncio.create_task(self.send_encoded_msgs(endpoint, msgs)) + + def broadcast_team(self, team, msgs): + msgs = self.dumper(msgs) + for client in self.endpoints: + if client.auth and client.team == team: + asyncio.create_task(self.send_encoded_msgs(client, msgs)) + + async def disconnect(self, endpoint): + if endpoint in self.endpoints: + self.endpoints.remove(endpoint) + await on_client_disconnected(self, endpoint) + + # text + + def notify_all(self, text): + logging.info("Notice (all): %s" % text) + self.broadcast_all([{"cmd": "Print", "text": text}]) + + def notify_client(self, client: Client, text: str): + if not client.auth: + return + logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text)) + asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}])) + + def notify_client_multiple(self, client: Client, texts: typing.List[str]): + if not client.auth: + return + asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts])) + + # loading def load(self, multidatapath: str, use_embedded_server_options: bool = False): if multidatapath.lower().endswith(".zip"): @@ -177,27 +245,7 @@ class Context(Node): server_options = decoded_obj.get("server_options", {}) self._set_options(server_options) - def get_players_package(self): - return [NetworkPlayer(t, p, self.get_aliased_name(t, p), n) for (t, p), n in self.player_names.items()] - - def _set_options(self, server_options: dict): - for key, value in server_options.items(): - data_type = self.simple_options.get(key, None) - if data_type is not None: - if value not in {False, True, None}: # some can be boolean OR text, such as password - try: - value = data_type(value) - except Exception as e: - try: - raise Exception(f"Could not set server option {key}, skipping.") from e - except Exception as e: - logging.exception(e) - logging.debug(f"Setting server option {key} to {value} from supplied multidata") - setattr(self, key, value) - elif key == "disable_item_cheat": - self.item_cheat = not bool(value) - else: - logging.debug(f"Unrecognized server option {key}") + # saving def save(self, now=False) -> bool: if self.saving: @@ -228,7 +276,7 @@ class Context(Node): import os name, ext = os.path.splitext(self.data_filename) self.save_filename = name + '.apsave' if ext.lower() in ('.archipelago','.zip') \ - else self.data_filename + '_' + 'apsave' + else self.data_filename + '_' + 'apsave' try: with open(self.save_filename, 'rb') as f: save_data = restricted_loads(zlib.decompress(f.read())) @@ -256,13 +304,6 @@ class Context(Node): import atexit atexit.register(self._save, True) # make sure we save on exit too - def recheck_hints(self): - for team, slot in self.hints: - self.hints[team, slot] = { - hint.re_check(self, team) for hint in - self.hints[team, slot] - } - def get_save(self) -> dict: self.recheck_hints() d = { @@ -303,43 +344,48 @@ class Context(Node): logging.info(f'Loaded save file with {sum([len(p) for p in self.received_items.values()])} received items ' f'for {len(self.received_items)} players') + # rest + + def get_hint_cost(self, slot): + if self.hint_cost: + return max(0, int(self.hint_cost * 0.01 * len(self.locations[slot]))) + return 0 + + def recheck_hints(self): + for team, slot in self.hints: + self.hints[team, slot] = { + hint.re_check(self, team) for hint in + self.hints[team, slot] + } + + def get_players_package(self): + return [NetworkPlayer(t, p, self.get_aliased_name(t, p), n) for (t, p), n in self.player_names.items()] + + def _set_options(self, server_options: dict): + for key, value in server_options.items(): + data_type = self.simple_options.get(key, None) + if data_type is not None: + if value not in {False, True, None}: # some can be boolean OR text, such as password + try: + value = data_type(value) + except Exception as e: + try: + raise Exception(f"Could not set server option {key}, skipping.") from e + except Exception as e: + logging.exception(e) + logging.debug(f"Setting server option {key} to {value} from supplied multidata") + setattr(self, key, value) + elif key == "disable_item_cheat": + self.item_cheat = not bool(value) + else: + logging.debug(f"Unrecognized server option {key}") + def get_aliased_name(self, team: int, slot: int): if (team, slot) in self.name_aliases: return f"{self.name_aliases[team, slot]} ({self.player_names[team, slot]})" else: return self.player_names[team, slot] - def notify_all(self, text): - logging.info("Notice (all): %s" % text) - self.broadcast_all([{"cmd": "Print", "text": text}]) - - def notify_client(self, client: Client, text: str): - if not client.auth: - return - logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text)) - asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}])) - - def notify_client_multiple(self, client: Client, texts: typing.List[str]): - if not client.auth: - return - asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts])) - - def broadcast_team(self, team, msgs): - msgs = self.dumper(msgs) - for client in self.endpoints: - if client.auth and client.team == team: - asyncio.create_task(self.send_encoded_msgs(client, msgs)) - - def broadcast_all(self, msgs): - msgs = self.dumper(msgs) - for endpoint in self.endpoints: - if endpoint.auth: - asyncio.create_task(self.send_encoded_msgs(endpoint, msgs)) - - async def disconnect(self, endpoint): - await super(Context, self).disconnect(endpoint) - await on_client_disconnected(self, endpoint) - def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint]): concerns = collections.defaultdict(list) @@ -1431,8 +1477,7 @@ async def main(args: argparse.Namespace): ctx = Context(args.host, args.port, args.server_password, args.password, args.location_check_points, args.hint_cost, not args.disable_item_cheat, args.forfeit_mode, args.remaining_mode, - args.auto_shutdown, args.compatibility) - ctx.log_network = args.log_network + args.auto_shutdown, args.compatibility, args.log_network) data_filename = args.multidata try: diff --git a/NetUtils.py b/NetUtils.py index a1e57f85..248a05c3 100644 --- a/NetUtils.py +++ b/NetUtils.py @@ -1,6 +1,4 @@ from __future__ import annotations -import asyncio -import logging import typing import enum from json import JSONEncoder, JSONDecoder @@ -94,52 +92,6 @@ def _object_hook(o: typing.Any) -> typing.Any: decode = JSONDecoder(object_hook=_object_hook).decode -class Node: - endpoints: typing.List - dumper = staticmethod(encode) - loader = staticmethod(decode) - - def __init__(self): - self.endpoints = [] - super(Node, self).__init__() - self.log_network = 0 - - def broadcast_all(self, msgs): - msgs = self.dumper(msgs) - for endpoint in self.endpoints: - asyncio.create_task(self.send_encoded_msgs(endpoint, msgs)) - - async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool: - if not endpoint.socket or not endpoint.socket.open: - return False - msg = self.dumper(msgs) - try: - await endpoint.socket.send(msg) - except websockets.ConnectionClosed: - logging.exception(f"Exception during send_msgs, could not send {msg}") - await self.disconnect(endpoint) - else: - if self.log_network: - logging.info(f"Outgoing message: {msg}") - return True - - async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool: - if not endpoint.socket or not endpoint.socket.open: - return False - try: - await endpoint.socket.send(msg) - except websockets.ConnectionClosed: - logging.exception("Exception during send_encoded_msgs") - await self.disconnect(endpoint) - else: - if self.log_network: - logging.info(f"Outgoing message: {msg}") - return True - - async def disconnect(self, endpoint): - if endpoint in self.endpoints: - self.endpoints.remove(endpoint) - class Endpoint: socket: websockets.WebSocketServerProtocol