MultiServer: categorize methods
This commit is contained in:
		
							parent
							
								
									acbca78e2d
								
							
						
					
					
						commit
						f4f043ac87
					
				
							
								
								
									
										1
									
								
								Main.py
								
								
								
								
							
							
						
						
									
										1
									
								
								Main.py
								
								
								
								
							| 
						 | 
					@ -377,7 +377,6 @@ def main(args, seed=None):
 | 
				
			||||||
                f.write(bytes([1]))  # version of format
 | 
					                f.write(bytes([1]))  # version of format
 | 
				
			||||||
                f.write(multidata)
 | 
					                f.write(multidata)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
        multidata_task = pool.submit(write_multidata)
 | 
					        multidata_task = pool.submit(write_multidata)
 | 
				
			||||||
        if not check_accessibility_task.result():
 | 
					        if not check_accessibility_task.result():
 | 
				
			||||||
            if not world.can_beat_game():
 | 
					            if not world.can_beat_game():
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										183
									
								
								MultiServer.py
								
								
								
								
							
							
						
						
									
										183
									
								
								MultiServer.py
								
								
								
								
							| 
						 | 
					@ -31,10 +31,11 @@ from worlds import network_data_package, lookup_any_item_id_to_name, lookup_any_
 | 
				
			||||||
import Utils
 | 
					import Utils
 | 
				
			||||||
from Utils import get_item_name_from_id, get_location_name_from_id, \
 | 
					from Utils import get_item_name_from_id, get_location_name_from_id, \
 | 
				
			||||||
    version_tuple, restricted_loads, Version
 | 
					    version_tuple, restricted_loads, Version
 | 
				
			||||||
from NetUtils import Node, Endpoint, ClientStatus, NetworkItem, decode, NetworkPlayer
 | 
					from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
colorama.init()
 | 
					colorama.init()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Client(Endpoint):
 | 
					class Client(Endpoint):
 | 
				
			||||||
    version = Version(0, 0, 0)
 | 
					    version = Version(0, 0, 0)
 | 
				
			||||||
    tags: typing.List[str] = []
 | 
					    tags: typing.List[str] = []
 | 
				
			||||||
| 
						 | 
					@ -50,9 +51,14 @@ class Client(Endpoint):
 | 
				
			||||||
        self.messageprocessor = client_message_processor(ctx, self)
 | 
					        self.messageprocessor = client_message_processor(ctx, self)
 | 
				
			||||||
        self.ctx = weakref.ref(ctx)
 | 
					        self.ctx = weakref.ref(ctx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
team_slot = typing.Tuple[int, int]
 | 
					team_slot = typing.Tuple[int, int]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Context(Node):
 | 
					
 | 
				
			||||||
 | 
					class Context:
 | 
				
			||||||
 | 
					    dumper = staticmethod(encode)
 | 
				
			||||||
 | 
					    loader = staticmethod(decode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    simple_options = {"hint_cost": int,
 | 
					    simple_options = {"hint_cost": int,
 | 
				
			||||||
                      "location_check_points": int,
 | 
					                      "location_check_points": int,
 | 
				
			||||||
                      "server_password": str,
 | 
					                      "server_password": str,
 | 
				
			||||||
| 
						 | 
					@ -64,8 +70,10 @@ class Context(Node):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int,
 | 
					    def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int,
 | 
				
			||||||
                 hint_cost: int, item_cheat: bool, forfeit_mode: str = "disabled", remaining_mode: str = "disabled",
 | 
					                 hint_cost: int, item_cheat: bool, forfeit_mode: str = "disabled", remaining_mode: str = "disabled",
 | 
				
			||||||
                 auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2):
 | 
					                 auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2, log_network: bool = False):
 | 
				
			||||||
        super(Context, self).__init__()
 | 
					        super(Context, self).__init__()
 | 
				
			||||||
 | 
					        self.log_network = log_network
 | 
				
			||||||
 | 
					        self.endpoints = []
 | 
				
			||||||
        self.compatibility: int = compatibility
 | 
					        self.compatibility: int = compatibility
 | 
				
			||||||
        self.shutdown_task = None
 | 
					        self.shutdown_task = None
 | 
				
			||||||
        self.data_filename = None
 | 
					        self.data_filename = None
 | 
				
			||||||
| 
						 | 
					@ -113,10 +121,70 @@ class Context(Node):
 | 
				
			||||||
        self.seed_name = ""
 | 
					        self.seed_name = ""
 | 
				
			||||||
        self.random = random.Random()
 | 
					        self.random = random.Random()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_hint_cost(self, slot):
 | 
					    # General networking
 | 
				
			||||||
        if self.hint_cost:
 | 
					
 | 
				
			||||||
            return max(0, int(self.hint_cost * 0.01 * len(self.locations[slot])))
 | 
					    async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool:
 | 
				
			||||||
        return 0
 | 
					        if not endpoint.socket or not endpoint.socket.open:
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        msg = self.dumper(msgs)
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            await endpoint.socket.send(msg)
 | 
				
			||||||
 | 
					        except websockets.ConnectionClosed:
 | 
				
			||||||
 | 
					            logging.exception(f"Exception during send_msgs, could not send {msg}")
 | 
				
			||||||
 | 
					            await self.disconnect(endpoint)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            if self.log_network:
 | 
				
			||||||
 | 
					                logging.info(f"Outgoing message: {msg}")
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
 | 
				
			||||||
 | 
					        if not endpoint.socket or not endpoint.socket.open:
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            await endpoint.socket.send(msg)
 | 
				
			||||||
 | 
					        except websockets.ConnectionClosed:
 | 
				
			||||||
 | 
					            logging.exception("Exception during send_encoded_msgs")
 | 
				
			||||||
 | 
					            await self.disconnect(endpoint)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            if self.log_network:
 | 
				
			||||||
 | 
					                logging.info(f"Outgoing message: {msg}")
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def broadcast_all(self, msgs):
 | 
				
			||||||
 | 
					        msgs = self.dumper(msgs)
 | 
				
			||||||
 | 
					        for endpoint in self.endpoints:
 | 
				
			||||||
 | 
					            if endpoint.auth:
 | 
				
			||||||
 | 
					                asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def broadcast_team(self, team, msgs):
 | 
				
			||||||
 | 
					        msgs = self.dumper(msgs)
 | 
				
			||||||
 | 
					        for client in self.endpoints:
 | 
				
			||||||
 | 
					            if client.auth and client.team == team:
 | 
				
			||||||
 | 
					                asyncio.create_task(self.send_encoded_msgs(client, msgs))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def disconnect(self, endpoint):
 | 
				
			||||||
 | 
					        if endpoint in self.endpoints:
 | 
				
			||||||
 | 
					            self.endpoints.remove(endpoint)
 | 
				
			||||||
 | 
					        await on_client_disconnected(self, endpoint)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # text
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def notify_all(self, text):
 | 
				
			||||||
 | 
					        logging.info("Notice (all): %s" % text)
 | 
				
			||||||
 | 
					        self.broadcast_all([{"cmd": "Print", "text": text}])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def notify_client(self, client: Client, text: str):
 | 
				
			||||||
 | 
					        if not client.auth:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
 | 
				
			||||||
 | 
					        asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def notify_client_multiple(self, client: Client, texts: typing.List[str]):
 | 
				
			||||||
 | 
					        if not client.auth:
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					        asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # loading
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def load(self, multidatapath: str, use_embedded_server_options: bool = False):
 | 
					    def load(self, multidatapath: str, use_embedded_server_options: bool = False):
 | 
				
			||||||
        if multidatapath.lower().endswith(".zip"):
 | 
					        if multidatapath.lower().endswith(".zip"):
 | 
				
			||||||
| 
						 | 
					@ -177,27 +245,7 @@ class Context(Node):
 | 
				
			||||||
            server_options = decoded_obj.get("server_options", {})
 | 
					            server_options = decoded_obj.get("server_options", {})
 | 
				
			||||||
            self._set_options(server_options)
 | 
					            self._set_options(server_options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_players_package(self):
 | 
					    # saving
 | 
				
			||||||
        return [NetworkPlayer(t, p, self.get_aliased_name(t, p), n) for (t, p), n in self.player_names.items()]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _set_options(self, server_options: dict):
 | 
					 | 
				
			||||||
        for key, value in server_options.items():
 | 
					 | 
				
			||||||
            data_type = self.simple_options.get(key, None)
 | 
					 | 
				
			||||||
            if data_type is not None:
 | 
					 | 
				
			||||||
                if value not in {False, True, None}:  # some can be boolean OR text, such as password
 | 
					 | 
				
			||||||
                    try:
 | 
					 | 
				
			||||||
                        value = data_type(value)
 | 
					 | 
				
			||||||
                    except Exception as e:
 | 
					 | 
				
			||||||
                        try:
 | 
					 | 
				
			||||||
                            raise Exception(f"Could not set server option {key}, skipping.") from e
 | 
					 | 
				
			||||||
                        except Exception as e:
 | 
					 | 
				
			||||||
                            logging.exception(e)
 | 
					 | 
				
			||||||
                logging.debug(f"Setting server option {key} to {value} from supplied multidata")
 | 
					 | 
				
			||||||
                setattr(self, key, value)
 | 
					 | 
				
			||||||
            elif key == "disable_item_cheat":
 | 
					 | 
				
			||||||
                self.item_cheat = not bool(value)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                logging.debug(f"Unrecognized server option {key}")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def save(self, now=False) -> bool:
 | 
					    def save(self, now=False) -> bool:
 | 
				
			||||||
        if self.saving:
 | 
					        if self.saving:
 | 
				
			||||||
| 
						 | 
					@ -228,7 +276,7 @@ class Context(Node):
 | 
				
			||||||
                import os
 | 
					                import os
 | 
				
			||||||
                name, ext = os.path.splitext(self.data_filename)
 | 
					                name, ext = os.path.splitext(self.data_filename)
 | 
				
			||||||
                self.save_filename = name + '.apsave' if ext.lower() in ('.archipelago','.zip') \
 | 
					                self.save_filename = name + '.apsave' if ext.lower() in ('.archipelago','.zip') \
 | 
				
			||||||
                                     else self.data_filename + '_' + 'apsave'
 | 
					                    else self.data_filename + '_' + 'apsave'
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                with open(self.save_filename, 'rb') as f:
 | 
					                with open(self.save_filename, 'rb') as f:
 | 
				
			||||||
                    save_data = restricted_loads(zlib.decompress(f.read()))
 | 
					                    save_data = restricted_loads(zlib.decompress(f.read()))
 | 
				
			||||||
| 
						 | 
					@ -256,13 +304,6 @@ class Context(Node):
 | 
				
			||||||
            import atexit
 | 
					            import atexit
 | 
				
			||||||
            atexit.register(self._save, True)  # make sure we save on exit too
 | 
					            atexit.register(self._save, True)  # make sure we save on exit too
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def recheck_hints(self):
 | 
					 | 
				
			||||||
        for team, slot in self.hints:
 | 
					 | 
				
			||||||
            self.hints[team, slot] = {
 | 
					 | 
				
			||||||
                hint.re_check(self, team) for hint in
 | 
					 | 
				
			||||||
                self.hints[team, slot]
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_save(self) -> dict:
 | 
					    def get_save(self) -> dict:
 | 
				
			||||||
        self.recheck_hints()
 | 
					        self.recheck_hints()
 | 
				
			||||||
        d = {
 | 
					        d = {
 | 
				
			||||||
| 
						 | 
					@ -303,43 +344,48 @@ class Context(Node):
 | 
				
			||||||
        logging.info(f'Loaded save file with {sum([len(p) for p in self.received_items.values()])} received items '
 | 
					        logging.info(f'Loaded save file with {sum([len(p) for p in self.received_items.values()])} received items '
 | 
				
			||||||
                     f'for {len(self.received_items)} players')
 | 
					                     f'for {len(self.received_items)} players')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # rest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_hint_cost(self, slot):
 | 
				
			||||||
 | 
					        if self.hint_cost:
 | 
				
			||||||
 | 
					            return max(0, int(self.hint_cost * 0.01 * len(self.locations[slot])))
 | 
				
			||||||
 | 
					        return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recheck_hints(self):
 | 
				
			||||||
 | 
					        for team, slot in self.hints:
 | 
				
			||||||
 | 
					            self.hints[team, slot] = {
 | 
				
			||||||
 | 
					                hint.re_check(self, team) for hint in
 | 
				
			||||||
 | 
					                self.hints[team, slot]
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_players_package(self):
 | 
				
			||||||
 | 
					        return [NetworkPlayer(t, p, self.get_aliased_name(t, p), n) for (t, p), n in self.player_names.items()]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _set_options(self, server_options: dict):
 | 
				
			||||||
 | 
					        for key, value in server_options.items():
 | 
				
			||||||
 | 
					            data_type = self.simple_options.get(key, None)
 | 
				
			||||||
 | 
					            if data_type is not None:
 | 
				
			||||||
 | 
					                if value not in {False, True, None}:  # some can be boolean OR text, such as password
 | 
				
			||||||
 | 
					                    try:
 | 
				
			||||||
 | 
					                        value = data_type(value)
 | 
				
			||||||
 | 
					                    except Exception as e:
 | 
				
			||||||
 | 
					                        try:
 | 
				
			||||||
 | 
					                            raise Exception(f"Could not set server option {key}, skipping.") from e
 | 
				
			||||||
 | 
					                        except Exception as e:
 | 
				
			||||||
 | 
					                            logging.exception(e)
 | 
				
			||||||
 | 
					                logging.debug(f"Setting server option {key} to {value} from supplied multidata")
 | 
				
			||||||
 | 
					                setattr(self, key, value)
 | 
				
			||||||
 | 
					            elif key == "disable_item_cheat":
 | 
				
			||||||
 | 
					                self.item_cheat = not bool(value)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                logging.debug(f"Unrecognized server option {key}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_aliased_name(self, team: int, slot: int):
 | 
					    def get_aliased_name(self, team: int, slot: int):
 | 
				
			||||||
        if (team, slot) in self.name_aliases:
 | 
					        if (team, slot) in self.name_aliases:
 | 
				
			||||||
            return f"{self.name_aliases[team, slot]} ({self.player_names[team, slot]})"
 | 
					            return f"{self.name_aliases[team, slot]} ({self.player_names[team, slot]})"
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return self.player_names[team, slot]
 | 
					            return self.player_names[team, slot]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def notify_all(self, text):
 | 
					 | 
				
			||||||
        logging.info("Notice (all): %s" % text)
 | 
					 | 
				
			||||||
        self.broadcast_all([{"cmd": "Print", "text": text}])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def notify_client(self, client: Client, text: str):
 | 
					 | 
				
			||||||
        if not client.auth:
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
        logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
 | 
					 | 
				
			||||||
        asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def notify_client_multiple(self, client: Client, texts: typing.List[str]):
 | 
					 | 
				
			||||||
        if not client.auth:
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
        asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def broadcast_team(self, team, msgs):
 | 
					 | 
				
			||||||
        msgs = self.dumper(msgs)
 | 
					 | 
				
			||||||
        for client in self.endpoints:
 | 
					 | 
				
			||||||
            if client.auth and client.team == team:
 | 
					 | 
				
			||||||
                asyncio.create_task(self.send_encoded_msgs(client, msgs))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def broadcast_all(self, msgs):
 | 
					 | 
				
			||||||
        msgs = self.dumper(msgs)
 | 
					 | 
				
			||||||
        for endpoint in self.endpoints:
 | 
					 | 
				
			||||||
            if endpoint.auth:
 | 
					 | 
				
			||||||
                asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def disconnect(self, endpoint):
 | 
					 | 
				
			||||||
        await super(Context, self).disconnect(endpoint)
 | 
					 | 
				
			||||||
        await on_client_disconnected(self, endpoint)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint]):
 | 
					def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint]):
 | 
				
			||||||
    concerns = collections.defaultdict(list)
 | 
					    concerns = collections.defaultdict(list)
 | 
				
			||||||
| 
						 | 
					@ -1431,8 +1477,7 @@ async def main(args: argparse.Namespace):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ctx = Context(args.host, args.port, args.server_password, args.password, args.location_check_points,
 | 
					    ctx = Context(args.host, args.port, args.server_password, args.password, args.location_check_points,
 | 
				
			||||||
                  args.hint_cost, not args.disable_item_cheat, args.forfeit_mode, args.remaining_mode,
 | 
					                  args.hint_cost, not args.disable_item_cheat, args.forfeit_mode, args.remaining_mode,
 | 
				
			||||||
                  args.auto_shutdown, args.compatibility)
 | 
					                  args.auto_shutdown, args.compatibility, args.log_network)
 | 
				
			||||||
    ctx.log_network = args.log_network
 | 
					 | 
				
			||||||
    data_filename = args.multidata
 | 
					    data_filename = args.multidata
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										48
									
								
								NetUtils.py
								
								
								
								
							
							
						
						
									
										48
									
								
								NetUtils.py
								
								
								
								
							| 
						 | 
					@ -1,6 +1,4 @@
 | 
				
			||||||
from __future__ import annotations
 | 
					from __future__ import annotations
 | 
				
			||||||
import asyncio
 | 
					 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
import typing
 | 
					import typing
 | 
				
			||||||
import enum
 | 
					import enum
 | 
				
			||||||
from json import JSONEncoder, JSONDecoder
 | 
					from json import JSONEncoder, JSONDecoder
 | 
				
			||||||
| 
						 | 
					@ -94,52 +92,6 @@ def _object_hook(o: typing.Any) -> typing.Any:
 | 
				
			||||||
decode = JSONDecoder(object_hook=_object_hook).decode
 | 
					decode = JSONDecoder(object_hook=_object_hook).decode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Node:
 | 
					 | 
				
			||||||
    endpoints: typing.List
 | 
					 | 
				
			||||||
    dumper = staticmethod(encode)
 | 
					 | 
				
			||||||
    loader = staticmethod(decode)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self):
 | 
					 | 
				
			||||||
        self.endpoints = []
 | 
					 | 
				
			||||||
        super(Node, self).__init__()
 | 
					 | 
				
			||||||
        self.log_network = 0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def broadcast_all(self, msgs):
 | 
					 | 
				
			||||||
        msgs = self.dumper(msgs)
 | 
					 | 
				
			||||||
        for endpoint in self.endpoints:
 | 
					 | 
				
			||||||
            asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool:
 | 
					 | 
				
			||||||
        if not endpoint.socket or not endpoint.socket.open:
 | 
					 | 
				
			||||||
            return False
 | 
					 | 
				
			||||||
        msg = self.dumper(msgs)
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            await endpoint.socket.send(msg)
 | 
					 | 
				
			||||||
        except websockets.ConnectionClosed:
 | 
					 | 
				
			||||||
            logging.exception(f"Exception during send_msgs, could not send {msg}")
 | 
					 | 
				
			||||||
            await self.disconnect(endpoint)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if self.log_network:
 | 
					 | 
				
			||||||
                logging.info(f"Outgoing message: {msg}")
 | 
					 | 
				
			||||||
            return True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
 | 
					 | 
				
			||||||
        if not endpoint.socket or not endpoint.socket.open:
 | 
					 | 
				
			||||||
            return False
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            await endpoint.socket.send(msg)
 | 
					 | 
				
			||||||
        except websockets.ConnectionClosed:
 | 
					 | 
				
			||||||
            logging.exception("Exception during send_encoded_msgs")
 | 
					 | 
				
			||||||
            await self.disconnect(endpoint)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if self.log_network:
 | 
					 | 
				
			||||||
                logging.info(f"Outgoing message: {msg}")
 | 
					 | 
				
			||||||
            return True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def disconnect(self, endpoint):
 | 
					 | 
				
			||||||
        if endpoint in self.endpoints:
 | 
					 | 
				
			||||||
            self.endpoints.remove(endpoint)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Endpoint:
 | 
					class Endpoint:
 | 
				
			||||||
    socket: websockets.WebSocketServerProtocol
 | 
					    socket: websockets.WebSocketServerProtocol
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue