MultiServer: categorize methods
This commit is contained in:
		
							parent
							
								
									acbca78e2d
								
							
						
					
					
						commit
						f4f043ac87
					
				
							
								
								
									
										1
									
								
								Main.py
								
								
								
								
							
							
						
						
									
										1
									
								
								Main.py
								
								
								
								
							| 
						 | 
				
			
			@ -377,7 +377,6 @@ def main(args, seed=None):
 | 
			
		|||
                f.write(bytes([1]))  # version of format
 | 
			
		||||
                f.write(multidata)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        multidata_task = pool.submit(write_multidata)
 | 
			
		||||
        if not check_accessibility_task.result():
 | 
			
		||||
            if not world.can_beat_game():
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										181
									
								
								MultiServer.py
								
								
								
								
							
							
						
						
									
										181
									
								
								MultiServer.py
								
								
								
								
							| 
						 | 
				
			
			@ -31,10 +31,11 @@ from worlds import network_data_package, lookup_any_item_id_to_name, lookup_any_
 | 
			
		|||
import Utils
 | 
			
		||||
from Utils import get_item_name_from_id, get_location_name_from_id, \
 | 
			
		||||
    version_tuple, restricted_loads, Version
 | 
			
		||||
from NetUtils import Node, Endpoint, ClientStatus, NetworkItem, decode, NetworkPlayer
 | 
			
		||||
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer
 | 
			
		||||
 | 
			
		||||
colorama.init()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Client(Endpoint):
 | 
			
		||||
    version = Version(0, 0, 0)
 | 
			
		||||
    tags: typing.List[str] = []
 | 
			
		||||
| 
						 | 
				
			
			@ -50,9 +51,14 @@ class Client(Endpoint):
 | 
			
		|||
        self.messageprocessor = client_message_processor(ctx, self)
 | 
			
		||||
        self.ctx = weakref.ref(ctx)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
team_slot = typing.Tuple[int, int]
 | 
			
		||||
 | 
			
		||||
class Context(Node):
 | 
			
		||||
 | 
			
		||||
class Context:
 | 
			
		||||
    dumper = staticmethod(encode)
 | 
			
		||||
    loader = staticmethod(decode)
 | 
			
		||||
 | 
			
		||||
    simple_options = {"hint_cost": int,
 | 
			
		||||
                      "location_check_points": int,
 | 
			
		||||
                      "server_password": str,
 | 
			
		||||
| 
						 | 
				
			
			@ -64,8 +70,10 @@ class Context(Node):
 | 
			
		|||
 | 
			
		||||
    def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int,
 | 
			
		||||
                 hint_cost: int, item_cheat: bool, forfeit_mode: str = "disabled", remaining_mode: str = "disabled",
 | 
			
		||||
                 auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2):
 | 
			
		||||
                 auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2, log_network: bool = False):
 | 
			
		||||
        super(Context, self).__init__()
 | 
			
		||||
        self.log_network = log_network
 | 
			
		||||
        self.endpoints = []
 | 
			
		||||
        self.compatibility: int = compatibility
 | 
			
		||||
        self.shutdown_task = None
 | 
			
		||||
        self.data_filename = None
 | 
			
		||||
| 
						 | 
				
			
			@ -113,10 +121,70 @@ class Context(Node):
 | 
			
		|||
        self.seed_name = ""
 | 
			
		||||
        self.random = random.Random()
 | 
			
		||||
 | 
			
		||||
    def get_hint_cost(self, slot):
 | 
			
		||||
        if self.hint_cost:
 | 
			
		||||
            return max(0, int(self.hint_cost * 0.01 * len(self.locations[slot])))
 | 
			
		||||
        return 0
 | 
			
		||||
    # General networking
 | 
			
		||||
 | 
			
		||||
    async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool:
 | 
			
		||||
        if not endpoint.socket or not endpoint.socket.open:
 | 
			
		||||
            return False
 | 
			
		||||
        msg = self.dumper(msgs)
 | 
			
		||||
        try:
 | 
			
		||||
            await endpoint.socket.send(msg)
 | 
			
		||||
        except websockets.ConnectionClosed:
 | 
			
		||||
            logging.exception(f"Exception during send_msgs, could not send {msg}")
 | 
			
		||||
            await self.disconnect(endpoint)
 | 
			
		||||
        else:
 | 
			
		||||
            if self.log_network:
 | 
			
		||||
                logging.info(f"Outgoing message: {msg}")
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
    async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
 | 
			
		||||
        if not endpoint.socket or not endpoint.socket.open:
 | 
			
		||||
            return False
 | 
			
		||||
        try:
 | 
			
		||||
            await endpoint.socket.send(msg)
 | 
			
		||||
        except websockets.ConnectionClosed:
 | 
			
		||||
            logging.exception("Exception during send_encoded_msgs")
 | 
			
		||||
            await self.disconnect(endpoint)
 | 
			
		||||
        else:
 | 
			
		||||
            if self.log_network:
 | 
			
		||||
                logging.info(f"Outgoing message: {msg}")
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
    def broadcast_all(self, msgs):
 | 
			
		||||
        msgs = self.dumper(msgs)
 | 
			
		||||
        for endpoint in self.endpoints:
 | 
			
		||||
            if endpoint.auth:
 | 
			
		||||
                asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
 | 
			
		||||
 | 
			
		||||
    def broadcast_team(self, team, msgs):
 | 
			
		||||
        msgs = self.dumper(msgs)
 | 
			
		||||
        for client in self.endpoints:
 | 
			
		||||
            if client.auth and client.team == team:
 | 
			
		||||
                asyncio.create_task(self.send_encoded_msgs(client, msgs))
 | 
			
		||||
 | 
			
		||||
    async def disconnect(self, endpoint):
 | 
			
		||||
        if endpoint in self.endpoints:
 | 
			
		||||
            self.endpoints.remove(endpoint)
 | 
			
		||||
        await on_client_disconnected(self, endpoint)
 | 
			
		||||
 | 
			
		||||
    # text
 | 
			
		||||
 | 
			
		||||
    def notify_all(self, text):
 | 
			
		||||
        logging.info("Notice (all): %s" % text)
 | 
			
		||||
        self.broadcast_all([{"cmd": "Print", "text": text}])
 | 
			
		||||
 | 
			
		||||
    def notify_client(self, client: Client, text: str):
 | 
			
		||||
        if not client.auth:
 | 
			
		||||
            return
 | 
			
		||||
        logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
 | 
			
		||||
        asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
 | 
			
		||||
 | 
			
		||||
    def notify_client_multiple(self, client: Client, texts: typing.List[str]):
 | 
			
		||||
        if not client.auth:
 | 
			
		||||
            return
 | 
			
		||||
        asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
 | 
			
		||||
 | 
			
		||||
    # loading
 | 
			
		||||
 | 
			
		||||
    def load(self, multidatapath: str, use_embedded_server_options: bool = False):
 | 
			
		||||
        if multidatapath.lower().endswith(".zip"):
 | 
			
		||||
| 
						 | 
				
			
			@ -177,27 +245,7 @@ class Context(Node):
 | 
			
		|||
            server_options = decoded_obj.get("server_options", {})
 | 
			
		||||
            self._set_options(server_options)
 | 
			
		||||
 | 
			
		||||
    def get_players_package(self):
 | 
			
		||||
        return [NetworkPlayer(t, p, self.get_aliased_name(t, p), n) for (t, p), n in self.player_names.items()]
 | 
			
		||||
 | 
			
		||||
    def _set_options(self, server_options: dict):
 | 
			
		||||
        for key, value in server_options.items():
 | 
			
		||||
            data_type = self.simple_options.get(key, None)
 | 
			
		||||
            if data_type is not None:
 | 
			
		||||
                if value not in {False, True, None}:  # some can be boolean OR text, such as password
 | 
			
		||||
                    try:
 | 
			
		||||
                        value = data_type(value)
 | 
			
		||||
                    except Exception as e:
 | 
			
		||||
                        try:
 | 
			
		||||
                            raise Exception(f"Could not set server option {key}, skipping.") from e
 | 
			
		||||
                        except Exception as e:
 | 
			
		||||
                            logging.exception(e)
 | 
			
		||||
                logging.debug(f"Setting server option {key} to {value} from supplied multidata")
 | 
			
		||||
                setattr(self, key, value)
 | 
			
		||||
            elif key == "disable_item_cheat":
 | 
			
		||||
                self.item_cheat = not bool(value)
 | 
			
		||||
            else:
 | 
			
		||||
                logging.debug(f"Unrecognized server option {key}")
 | 
			
		||||
    # saving
 | 
			
		||||
 | 
			
		||||
    def save(self, now=False) -> bool:
 | 
			
		||||
        if self.saving:
 | 
			
		||||
| 
						 | 
				
			
			@ -256,13 +304,6 @@ class Context(Node):
 | 
			
		|||
            import atexit
 | 
			
		||||
            atexit.register(self._save, True)  # make sure we save on exit too
 | 
			
		||||
 | 
			
		||||
    def recheck_hints(self):
 | 
			
		||||
        for team, slot in self.hints:
 | 
			
		||||
            self.hints[team, slot] = {
 | 
			
		||||
                hint.re_check(self, team) for hint in
 | 
			
		||||
                self.hints[team, slot]
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
    def get_save(self) -> dict:
 | 
			
		||||
        self.recheck_hints()
 | 
			
		||||
        d = {
 | 
			
		||||
| 
						 | 
				
			
			@ -303,43 +344,48 @@ class Context(Node):
 | 
			
		|||
        logging.info(f'Loaded save file with {sum([len(p) for p in self.received_items.values()])} received items '
 | 
			
		||||
                     f'for {len(self.received_items)} players')
 | 
			
		||||
 | 
			
		||||
    # rest
 | 
			
		||||
 | 
			
		||||
    def get_hint_cost(self, slot):
 | 
			
		||||
        if self.hint_cost:
 | 
			
		||||
            return max(0, int(self.hint_cost * 0.01 * len(self.locations[slot])))
 | 
			
		||||
        return 0
 | 
			
		||||
 | 
			
		||||
    def recheck_hints(self):
 | 
			
		||||
        for team, slot in self.hints:
 | 
			
		||||
            self.hints[team, slot] = {
 | 
			
		||||
                hint.re_check(self, team) for hint in
 | 
			
		||||
                self.hints[team, slot]
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
    def get_players_package(self):
 | 
			
		||||
        return [NetworkPlayer(t, p, self.get_aliased_name(t, p), n) for (t, p), n in self.player_names.items()]
 | 
			
		||||
 | 
			
		||||
    def _set_options(self, server_options: dict):
 | 
			
		||||
        for key, value in server_options.items():
 | 
			
		||||
            data_type = self.simple_options.get(key, None)
 | 
			
		||||
            if data_type is not None:
 | 
			
		||||
                if value not in {False, True, None}:  # some can be boolean OR text, such as password
 | 
			
		||||
                    try:
 | 
			
		||||
                        value = data_type(value)
 | 
			
		||||
                    except Exception as e:
 | 
			
		||||
                        try:
 | 
			
		||||
                            raise Exception(f"Could not set server option {key}, skipping.") from e
 | 
			
		||||
                        except Exception as e:
 | 
			
		||||
                            logging.exception(e)
 | 
			
		||||
                logging.debug(f"Setting server option {key} to {value} from supplied multidata")
 | 
			
		||||
                setattr(self, key, value)
 | 
			
		||||
            elif key == "disable_item_cheat":
 | 
			
		||||
                self.item_cheat = not bool(value)
 | 
			
		||||
            else:
 | 
			
		||||
                logging.debug(f"Unrecognized server option {key}")
 | 
			
		||||
 | 
			
		||||
    def get_aliased_name(self, team: int, slot: int):
 | 
			
		||||
        if (team, slot) in self.name_aliases:
 | 
			
		||||
            return f"{self.name_aliases[team, slot]} ({self.player_names[team, slot]})"
 | 
			
		||||
        else:
 | 
			
		||||
            return self.player_names[team, slot]
 | 
			
		||||
 | 
			
		||||
    def notify_all(self, text):
 | 
			
		||||
        logging.info("Notice (all): %s" % text)
 | 
			
		||||
        self.broadcast_all([{"cmd": "Print", "text": text}])
 | 
			
		||||
 | 
			
		||||
    def notify_client(self, client: Client, text: str):
 | 
			
		||||
        if not client.auth:
 | 
			
		||||
            return
 | 
			
		||||
        logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
 | 
			
		||||
        asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
 | 
			
		||||
 | 
			
		||||
    def notify_client_multiple(self, client: Client, texts: typing.List[str]):
 | 
			
		||||
        if not client.auth:
 | 
			
		||||
            return
 | 
			
		||||
        asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
 | 
			
		||||
 | 
			
		||||
    def broadcast_team(self, team, msgs):
 | 
			
		||||
        msgs = self.dumper(msgs)
 | 
			
		||||
        for client in self.endpoints:
 | 
			
		||||
            if client.auth and client.team == team:
 | 
			
		||||
                asyncio.create_task(self.send_encoded_msgs(client, msgs))
 | 
			
		||||
 | 
			
		||||
    def broadcast_all(self, msgs):
 | 
			
		||||
        msgs = self.dumper(msgs)
 | 
			
		||||
        for endpoint in self.endpoints:
 | 
			
		||||
            if endpoint.auth:
 | 
			
		||||
                asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
 | 
			
		||||
 | 
			
		||||
    async def disconnect(self, endpoint):
 | 
			
		||||
        await super(Context, self).disconnect(endpoint)
 | 
			
		||||
        await on_client_disconnected(self, endpoint)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint]):
 | 
			
		||||
    concerns = collections.defaultdict(list)
 | 
			
		||||
| 
						 | 
				
			
			@ -1431,8 +1477,7 @@ async def main(args: argparse.Namespace):
 | 
			
		|||
 | 
			
		||||
    ctx = Context(args.host, args.port, args.server_password, args.password, args.location_check_points,
 | 
			
		||||
                  args.hint_cost, not args.disable_item_cheat, args.forfeit_mode, args.remaining_mode,
 | 
			
		||||
                  args.auto_shutdown, args.compatibility)
 | 
			
		||||
    ctx.log_network = args.log_network
 | 
			
		||||
                  args.auto_shutdown, args.compatibility, args.log_network)
 | 
			
		||||
    data_filename = args.multidata
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										48
									
								
								NetUtils.py
								
								
								
								
							
							
						
						
									
										48
									
								
								NetUtils.py
								
								
								
								
							| 
						 | 
				
			
			@ -1,6 +1,4 @@
 | 
			
		|||
from __future__ import annotations
 | 
			
		||||
import asyncio
 | 
			
		||||
import logging
 | 
			
		||||
import typing
 | 
			
		||||
import enum
 | 
			
		||||
from json import JSONEncoder, JSONDecoder
 | 
			
		||||
| 
						 | 
				
			
			@ -94,52 +92,6 @@ def _object_hook(o: typing.Any) -> typing.Any:
 | 
			
		|||
decode = JSONDecoder(object_hook=_object_hook).decode
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Node:
 | 
			
		||||
    endpoints: typing.List
 | 
			
		||||
    dumper = staticmethod(encode)
 | 
			
		||||
    loader = staticmethod(decode)
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.endpoints = []
 | 
			
		||||
        super(Node, self).__init__()
 | 
			
		||||
        self.log_network = 0
 | 
			
		||||
 | 
			
		||||
    def broadcast_all(self, msgs):
 | 
			
		||||
        msgs = self.dumper(msgs)
 | 
			
		||||
        for endpoint in self.endpoints:
 | 
			
		||||
            asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
 | 
			
		||||
 | 
			
		||||
    async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool:
 | 
			
		||||
        if not endpoint.socket or not endpoint.socket.open:
 | 
			
		||||
            return False
 | 
			
		||||
        msg = self.dumper(msgs)
 | 
			
		||||
        try:
 | 
			
		||||
            await endpoint.socket.send(msg)
 | 
			
		||||
        except websockets.ConnectionClosed:
 | 
			
		||||
            logging.exception(f"Exception during send_msgs, could not send {msg}")
 | 
			
		||||
            await self.disconnect(endpoint)
 | 
			
		||||
        else:
 | 
			
		||||
            if self.log_network:
 | 
			
		||||
                logging.info(f"Outgoing message: {msg}")
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
    async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
 | 
			
		||||
        if not endpoint.socket or not endpoint.socket.open:
 | 
			
		||||
            return False
 | 
			
		||||
        try:
 | 
			
		||||
            await endpoint.socket.send(msg)
 | 
			
		||||
        except websockets.ConnectionClosed:
 | 
			
		||||
            logging.exception("Exception during send_encoded_msgs")
 | 
			
		||||
            await self.disconnect(endpoint)
 | 
			
		||||
        else:
 | 
			
		||||
            if self.log_network:
 | 
			
		||||
                logging.info(f"Outgoing message: {msg}")
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
    async def disconnect(self, endpoint):
 | 
			
		||||
        if endpoint in self.endpoints:
 | 
			
		||||
            self.endpoints.remove(endpoint)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Endpoint:
 | 
			
		||||
    socket: websockets.WebSocketServerProtocol
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue