diff --git a/MultiServer.py b/MultiServer.py index 3bf14524..3f741da5 100644 --- a/MultiServer.py +++ b/MultiServer.py @@ -13,6 +13,7 @@ import datetime import threading import random import pickle +import itertools import ModuleUpdate import NetUtils @@ -44,7 +45,6 @@ class Client(Endpoint): def __init__(self, socket: websockets.WebSocketServerProtocol, ctx: Context): super().__init__(socket) self.auth = False - self.name = None self.team = None self.slot = None self.send_index = 0 @@ -52,6 +52,13 @@ class Client(Endpoint): self.messageprocessor = client_message_processor(ctx, self) self.ctx = weakref.ref(ctx) + @property + def name(self) -> str: + ctx = self.ctx() + if ctx: + return ctx.player_names[self.team, self.slot] + return "Deallocated" + team_slot = typing.Tuple[int, int] @@ -69,6 +76,8 @@ class Context: "collect_mode": str, "item_cheat": bool, "compatibility": int} + # team -> slot id -> list of clients authenticated to slot. + clients: typing.Dict[int, typing.Dict[int, typing.List[Client]]] def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int, hint_cost: int, item_cheat: bool, forfeit_mode: str = "disabled", collect_mode="disabled", @@ -77,6 +86,7 @@ class Context: super(Context, self).__init__() self.log_network = log_network self.endpoints = [] + self.clients = {} self.compatibility: int = compatibility self.shutdown_task = None self.data_filename = None @@ -170,23 +180,25 @@ class Context: logging.info(f"Outgoing broadcast: {msg}") return True - def broadcast_all(self, msgs): + def broadcast_all(self, msgs: typing.List[dict]): msgs = self.dumper(msgs) endpoints = (endpoint for endpoint in self.endpoints if endpoint.auth) asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs)) - def broadcast_team(self, team: int, msgs): + def broadcast_team(self, team: int, msgs: typing.List[dict]): msgs = self.dumper(msgs) - endpoints = (endpoint for endpoint in self.endpoints if endpoint.auth and endpoint.team == team) + endpoints = (endpoint for endpoint in itertools.chain.from_iterable(self.clients[team].values())) asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs)) - def broadcast(self, endpoints: typing.Iterable[Endpoint], msgs): + def broadcast(self, endpoints: typing.Iterable[Client], msgs: typing.List[dict]): msgs = self.dumper(msgs) asyncio.create_task(self.broadcast_send_encoded_msgs(endpoints, msgs)) - async def disconnect(self, endpoint): + async def disconnect(self, endpoint: Client): if endpoint in self.endpoints: self.endpoints.remove(endpoint) + if endpoint.slot and endpoint in self.clients[endpoint.team][endpoint.slot]: + self.clients[endpoint.team][endpoint.slot].remove(endpoint) await on_client_disconnected(self, endpoint) # text @@ -243,8 +255,11 @@ class Context: for player, version in clients_ver.items(): self.minimum_client_versions[player] = Utils.Version(*version) + self.clients = {} for team, names in enumerate(decoded_obj['names']): + self.clients[team] = {} for player, name in enumerate(names, 1): + self.clients[team][player] = [] self.player_names[team, player] = name self.player_name_lookup[name] = team, player self.seed_name = decoded_obj["seed_name"] @@ -420,22 +435,20 @@ def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint]): for text in (format_hint(ctx, team, hint) for hint in hints): logging.info("Notice (Team #%d): %s" % (team + 1, text)) - for client in ctx.endpoints: - if client.auth and client.team == team: - client_hints = concerns[client.slot] - if client_hints: + for slot, clients in ctx.clients[team].items(): + client_hints = concerns[slot] + if client_hints: + for client in clients: asyncio.create_task(ctx.send_msgs(client, client_hints)) -def update_aliases(ctx: Context, team: int, client: typing.Optional[Client] = None): +def update_aliases(ctx: Context, team: int): cmd = ctx.dumper([{"cmd": "RoomUpdate", "players": ctx.get_players_package()}]) - if client is None: - for client in ctx.endpoints: - if client.team == team and client.auth: - asyncio.create_task(ctx.send_encoded_msgs(client, cmd)) - else: - asyncio.create_task(ctx.send_encoded_msgs(client, cmd)) + + for clients in ctx.clients[team].values(): + for client in clients: + asyncio.create_task(ctx.send_encoded_msgs(client, cmd)) async def server(websocket, path, ctx: Context): @@ -463,13 +476,19 @@ async def server(websocket, path, ctx: Context): async def on_client_connected(ctx: Context, client: Client): + players = [] + for team, clients in ctx.clients.items(): + for slot, connected_clients in clients.items(): + if connected_clients: + name = ctx.player_names[team, slot] + players.append( + NetworkPlayer(team, slot, + ctx.name_aliases.get((team, slot), name), name) + ) await ctx.send_msgs(client, [{ 'cmd': 'RoomInfo', 'password': bool(ctx.password), - 'players': [ - NetworkPlayer(client.team, client.slot, ctx.name_aliases.get((client.team, client.slot), client.name), - client.name) for client - in ctx.endpoints if client.auth], + 'players': players, # tags are for additional features in the communication. # Name them by feature or fork, as you feel is appropriate. 'tags': ctx.tags, @@ -552,21 +571,21 @@ def get_received_items(ctx: Context, team: int, player: int) -> typing.List[Netw def send_new_items(ctx: Context): - for client in ctx.endpoints: - if client.auth: # can't send to disconnected client - items = get_received_items(ctx, client.team, client.slot) - if len(items) > client.send_index: - asyncio.create_task(ctx.send_msgs(client, [{ - "cmd": "ReceivedItems", - "index": client.send_index, - "items": items[client.send_index:]}])) - client.send_index = len(items) + for team, clients in ctx.clients.items(): + for slot, clients in clients.items(): + items = get_received_items(ctx, team, slot) + for client in clients: + if len(items) > client.send_index: + asyncio.create_task(ctx.send_msgs(client, [{ + "cmd": "ReceivedItems", + "index": client.send_index, + "items": items[client.send_index:]}])) + client.send_index = len(items) def update_checked_locations(ctx: Context, team: int, slot: int): - for client in ctx.endpoints: - if client.team == team and client.slot == slot: - ctx.send_msgs(client, [{"cmd": "RoomUpdate", "checked_locations": get_checked_checks(ctx, client)}]) + ctx.broadcast(ctx.clients[team][slot], + [{"cmd": "RoomUpdate", "checked_locations": get_checked_checks(ctx, team, slot)}]) def forfeit_player(ctx: Context, team: int, slot: int): @@ -618,10 +637,11 @@ def register_location_checks(ctx: Context, team: int, slot: int, locations: typi ctx.location_checks[team, slot] |= new_locations send_new_items(ctx) - for client in ctx.endpoints: - if client.team == team and client.slot == slot: - asyncio.create_task(ctx.send_msgs(client, [{"cmd": "RoomUpdate", - "hint_points": get_client_points(ctx, client)}])) + ctx.broadcast(ctx.clients[team][slot], [{ + "cmd": "RoomUpdate", + "hint_points": get_slot_points(ctx, team, slot), + "checked_locations": locations, # duplicated data, but used for coop + }]) ctx.save() @@ -965,7 +985,7 @@ class ClientMessageProcessor(CommonCommandProcessor): def _cmd_missing(self) -> bool: """List all missing location checks from the server's perspective""" - locations = get_missing_checks(self.ctx, self.client) + locations = get_missing_checks(self.ctx, self.client.team, self.client.slot) if locations: texts = [f'Missing: {get_location_name_from_id(location)}' for location in locations] @@ -1117,16 +1137,16 @@ class ClientMessageProcessor(CommonCommandProcessor): return self.get_hints(location, True) -def get_checked_checks(ctx: Context, client: Client) -> typing.List[int]: +def get_checked_checks(ctx: Context, team: int, slot: int) -> typing.List[int]: return [location_id for - location_id in ctx.locations[client.slot] if - location_id in ctx.location_checks[client.team, client.slot]] + location_id in ctx.locations[slot] if + location_id in ctx.location_checks[team, slot]] -def get_missing_checks(ctx: Context, client: Client) -> typing.List[int]: +def get_missing_checks(ctx: Context, team: int, slot: int) -> typing.List[int]: return [location_id for - location_id in ctx.locations[client.slot] if - location_id not in ctx.location_checks[client.team, client.slot]] + location_id in ctx.locations[slot] if + location_id not in ctx.location_checks[team, slot]] def get_client_points(ctx: Context, client: Client) -> int: @@ -1134,6 +1154,11 @@ def get_client_points(ctx: Context, client: Client) -> int: ctx.get_hint_cost(client.slot) * ctx.hints_used[client.team, client.slot]) +def get_slot_points(ctx: Context, team: int, slot: int) -> int: + return (ctx.location_check_points * len(ctx.location_checks[team, slot]) - + ctx.get_hint_cost(slot) * ctx.hints_used[team, slot]) + + async def process_client_cmd(ctx: Context, client: Client, args: dict): try: cmd: str = args["cmd"] @@ -1166,26 +1191,6 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict): game = ctx.games[slot] if "IgnoreGame" not in args["tags"] and args['game'] != game: errors.add('InvalidGame') - # this can only ever be 0 or 1 elements - clients = [c for c in ctx.endpoints if c.auth and c.slot == slot and c.team == team] - if clients: - # likely same player with a "ghosted" slot. We bust the ghost. - if "uuid" in args and ctx.client_ids[team, slot] == args["uuid"]: - await ctx.send_msgs(clients[0], [{"cmd": "Print", "text": "You are getting kicked " - "by yourself reconnecting."}]) - await clients[0].socket.close() # we have to await the DC of the ghost, so not to create data pasta - client.name = ctx.player_names[(team, slot)] - client.team = team - client.slot = slot - else: - errors.add('SlotAlreadyTaken') - else: - client.name = ctx.player_names[(team, slot)] - client.team = team - client.slot = slot - minver = ctx.minimum_client_versions[slot] - if minver > args['version']: - errors.add('IncompatibleVersion') # only exact version match allowed if ctx.compatibility == 0 and args['version'] != version_tuple: @@ -1194,16 +1199,23 @@ async def process_client_cmd(ctx: Context, client: Client, args: dict): logging.info(f"A client connection was refused due to: {errors}, the sent connect information was {args}.") await ctx.send_msgs(client, [{"cmd": "ConnectionRefused", "errors": list(errors)}]) else: + team, slot = ctx.connect_names[args['name']] + client.team = team + client.slot = slot + minver = ctx.minimum_client_versions[slot] + if minver > args['version']: + errors.add('IncompatibleVersion') ctx.client_ids[client.team, client.slot] = args["uuid"] client.auth = True + ctx.clients[team][slot].append(client) client.version = args['version'] client.tags = args['tags'] reply = [{ "cmd": "Connected", "team": client.team, "slot": client.slot, "players": ctx.get_players_package(), - "missing_locations": get_missing_checks(ctx, client), - "checked_locations": get_checked_checks(ctx, client), + "missing_locations": get_missing_checks(ctx, team, slot), + "checked_locations": get_checked_checks(ctx, team, slot), # get is needed for old multidata that was sparsely populated "slot_data": ctx.slot_data.get(client.slot, {}) }] @@ -1303,20 +1315,6 @@ class ServerCommandProcessor(CommonCommandProcessor): def default(self, raw: str): self.ctx.notify_all('[Server]: ' + raw) - @mark_raw - def _cmd_kick(self, player_name: str) -> bool: - """Kick specified player from the server""" - for client in self.ctx.endpoints: - if client.auth and client.name.lower() == player_name.lower() and client.socket and not client.socket.closed: - asyncio.create_task(client.socket.close()) - self.output(f"Kicked {self.ctx.get_aliased_name(client.team, client.slot)}") - if self.ctx.commandprocessor.client == client: - self.ctx.commandprocessor.client = None - return True - - self.output(f"Could not find player {player_name} to kick") - return False - def _cmd_save(self) -> bool: """Save current state to multidata""" if self.ctx.saving: diff --git a/NetUtils.py b/NetUtils.py index ee89ccc8..8b2e08c8 100644 --- a/NetUtils.py +++ b/NetUtils.py @@ -62,7 +62,7 @@ def _scan_for_TypedTuples(obj: typing.Any) -> typing.Any: data = obj._asdict() data["class"] = obj.__class__.__name__ return data - if isinstance(obj, (tuple, list)): + if isinstance(obj, (tuple, list, set)): return tuple(_scan_for_TypedTuples(o) for o in obj) if isinstance(obj, dict): return {key: _scan_for_TypedTuples(value) for key, value in obj.items()}