MultiServer: categorize methods

This commit is contained in:
Fabian Dill 2021-08-26 16:19:37 +02:00
parent acbca78e2d
commit f4f043ac87
3 changed files with 114 additions and 118 deletions

View File

@ -377,7 +377,6 @@ def main(args, seed=None):
f.write(bytes([1])) # version of format f.write(bytes([1])) # version of format
f.write(multidata) f.write(multidata)
multidata_task = pool.submit(write_multidata) multidata_task = pool.submit(write_multidata)
if not check_accessibility_task.result(): if not check_accessibility_task.result():
if not world.can_beat_game(): if not world.can_beat_game():

View File

@ -31,10 +31,11 @@ from worlds import network_data_package, lookup_any_item_id_to_name, lookup_any_
import Utils import Utils
from Utils import get_item_name_from_id, get_location_name_from_id, \ from Utils import get_item_name_from_id, get_location_name_from_id, \
version_tuple, restricted_loads, Version version_tuple, restricted_loads, Version
from NetUtils import Node, Endpoint, ClientStatus, NetworkItem, decode, NetworkPlayer from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer
colorama.init() colorama.init()
class Client(Endpoint): class Client(Endpoint):
version = Version(0, 0, 0) version = Version(0, 0, 0)
tags: typing.List[str] = [] tags: typing.List[str] = []
@ -50,9 +51,14 @@ class Client(Endpoint):
self.messageprocessor = client_message_processor(ctx, self) self.messageprocessor = client_message_processor(ctx, self)
self.ctx = weakref.ref(ctx) self.ctx = weakref.ref(ctx)
team_slot = typing.Tuple[int, int] team_slot = typing.Tuple[int, int]
class Context(Node):
class Context:
dumper = staticmethod(encode)
loader = staticmethod(decode)
simple_options = {"hint_cost": int, simple_options = {"hint_cost": int,
"location_check_points": int, "location_check_points": int,
"server_password": str, "server_password": str,
@ -64,8 +70,10 @@ class Context(Node):
def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int, def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int,
hint_cost: int, item_cheat: bool, forfeit_mode: str = "disabled", remaining_mode: str = "disabled", hint_cost: int, item_cheat: bool, forfeit_mode: str = "disabled", remaining_mode: str = "disabled",
auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2): auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2, log_network: bool = False):
super(Context, self).__init__() super(Context, self).__init__()
self.log_network = log_network
self.endpoints = []
self.compatibility: int = compatibility self.compatibility: int = compatibility
self.shutdown_task = None self.shutdown_task = None
self.data_filename = None self.data_filename = None
@ -113,10 +121,70 @@ class Context(Node):
self.seed_name = "" self.seed_name = ""
self.random = random.Random() self.random = random.Random()
def get_hint_cost(self, slot): # General networking
if self.hint_cost:
return max(0, int(self.hint_cost * 0.01 * len(self.locations[slot]))) async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool:
return 0 if not endpoint.socket or not endpoint.socket.open:
return False
msg = self.dumper(msgs)
try:
await endpoint.socket.send(msg)
except websockets.ConnectionClosed:
logging.exception(f"Exception during send_msgs, could not send {msg}")
await self.disconnect(endpoint)
else:
if self.log_network:
logging.info(f"Outgoing message: {msg}")
return True
async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
if not endpoint.socket or not endpoint.socket.open:
return False
try:
await endpoint.socket.send(msg)
except websockets.ConnectionClosed:
logging.exception("Exception during send_encoded_msgs")
await self.disconnect(endpoint)
else:
if self.log_network:
logging.info(f"Outgoing message: {msg}")
return True
def broadcast_all(self, msgs):
msgs = self.dumper(msgs)
for endpoint in self.endpoints:
if endpoint.auth:
asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
def broadcast_team(self, team, msgs):
msgs = self.dumper(msgs)
for client in self.endpoints:
if client.auth and client.team == team:
asyncio.create_task(self.send_encoded_msgs(client, msgs))
async def disconnect(self, endpoint):
if endpoint in self.endpoints:
self.endpoints.remove(endpoint)
await on_client_disconnected(self, endpoint)
# text
def notify_all(self, text):
logging.info("Notice (all): %s" % text)
self.broadcast_all([{"cmd": "Print", "text": text}])
def notify_client(self, client: Client, text: str):
if not client.auth:
return
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
def notify_client_multiple(self, client: Client, texts: typing.List[str]):
if not client.auth:
return
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
# loading
def load(self, multidatapath: str, use_embedded_server_options: bool = False): def load(self, multidatapath: str, use_embedded_server_options: bool = False):
if multidatapath.lower().endswith(".zip"): if multidatapath.lower().endswith(".zip"):
@ -177,27 +245,7 @@ class Context(Node):
server_options = decoded_obj.get("server_options", {}) server_options = decoded_obj.get("server_options", {})
self._set_options(server_options) self._set_options(server_options)
def get_players_package(self): # saving
return [NetworkPlayer(t, p, self.get_aliased_name(t, p), n) for (t, p), n in self.player_names.items()]
def _set_options(self, server_options: dict):
for key, value in server_options.items():
data_type = self.simple_options.get(key, None)
if data_type is not None:
if value not in {False, True, None}: # some can be boolean OR text, such as password
try:
value = data_type(value)
except Exception as e:
try:
raise Exception(f"Could not set server option {key}, skipping.") from e
except Exception as e:
logging.exception(e)
logging.debug(f"Setting server option {key} to {value} from supplied multidata")
setattr(self, key, value)
elif key == "disable_item_cheat":
self.item_cheat = not bool(value)
else:
logging.debug(f"Unrecognized server option {key}")
def save(self, now=False) -> bool: def save(self, now=False) -> bool:
if self.saving: if self.saving:
@ -228,7 +276,7 @@ class Context(Node):
import os import os
name, ext = os.path.splitext(self.data_filename) name, ext = os.path.splitext(self.data_filename)
self.save_filename = name + '.apsave' if ext.lower() in ('.archipelago','.zip') \ self.save_filename = name + '.apsave' if ext.lower() in ('.archipelago','.zip') \
else self.data_filename + '_' + 'apsave' else self.data_filename + '_' + 'apsave'
try: try:
with open(self.save_filename, 'rb') as f: with open(self.save_filename, 'rb') as f:
save_data = restricted_loads(zlib.decompress(f.read())) save_data = restricted_loads(zlib.decompress(f.read()))
@ -256,13 +304,6 @@ class Context(Node):
import atexit import atexit
atexit.register(self._save, True) # make sure we save on exit too atexit.register(self._save, True) # make sure we save on exit too
def recheck_hints(self):
for team, slot in self.hints:
self.hints[team, slot] = {
hint.re_check(self, team) for hint in
self.hints[team, slot]
}
def get_save(self) -> dict: def get_save(self) -> dict:
self.recheck_hints() self.recheck_hints()
d = { d = {
@ -303,43 +344,48 @@ class Context(Node):
logging.info(f'Loaded save file with {sum([len(p) for p in self.received_items.values()])} received items ' logging.info(f'Loaded save file with {sum([len(p) for p in self.received_items.values()])} received items '
f'for {len(self.received_items)} players') f'for {len(self.received_items)} players')
# rest
def get_hint_cost(self, slot):
if self.hint_cost:
return max(0, int(self.hint_cost * 0.01 * len(self.locations[slot])))
return 0
def recheck_hints(self):
for team, slot in self.hints:
self.hints[team, slot] = {
hint.re_check(self, team) for hint in
self.hints[team, slot]
}
def get_players_package(self):
return [NetworkPlayer(t, p, self.get_aliased_name(t, p), n) for (t, p), n in self.player_names.items()]
def _set_options(self, server_options: dict):
for key, value in server_options.items():
data_type = self.simple_options.get(key, None)
if data_type is not None:
if value not in {False, True, None}: # some can be boolean OR text, such as password
try:
value = data_type(value)
except Exception as e:
try:
raise Exception(f"Could not set server option {key}, skipping.") from e
except Exception as e:
logging.exception(e)
logging.debug(f"Setting server option {key} to {value} from supplied multidata")
setattr(self, key, value)
elif key == "disable_item_cheat":
self.item_cheat = not bool(value)
else:
logging.debug(f"Unrecognized server option {key}")
def get_aliased_name(self, team: int, slot: int): def get_aliased_name(self, team: int, slot: int):
if (team, slot) in self.name_aliases: if (team, slot) in self.name_aliases:
return f"{self.name_aliases[team, slot]} ({self.player_names[team, slot]})" return f"{self.name_aliases[team, slot]} ({self.player_names[team, slot]})"
else: else:
return self.player_names[team, slot] return self.player_names[team, slot]
def notify_all(self, text):
logging.info("Notice (all): %s" % text)
self.broadcast_all([{"cmd": "Print", "text": text}])
def notify_client(self, client: Client, text: str):
if not client.auth:
return
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
def notify_client_multiple(self, client: Client, texts: typing.List[str]):
if not client.auth:
return
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
def broadcast_team(self, team, msgs):
msgs = self.dumper(msgs)
for client in self.endpoints:
if client.auth and client.team == team:
asyncio.create_task(self.send_encoded_msgs(client, msgs))
def broadcast_all(self, msgs):
msgs = self.dumper(msgs)
for endpoint in self.endpoints:
if endpoint.auth:
asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
async def disconnect(self, endpoint):
await super(Context, self).disconnect(endpoint)
await on_client_disconnected(self, endpoint)
def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint]): def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint]):
concerns = collections.defaultdict(list) concerns = collections.defaultdict(list)
@ -1431,8 +1477,7 @@ async def main(args: argparse.Namespace):
ctx = Context(args.host, args.port, args.server_password, args.password, args.location_check_points, ctx = Context(args.host, args.port, args.server_password, args.password, args.location_check_points,
args.hint_cost, not args.disable_item_cheat, args.forfeit_mode, args.remaining_mode, args.hint_cost, not args.disable_item_cheat, args.forfeit_mode, args.remaining_mode,
args.auto_shutdown, args.compatibility) args.auto_shutdown, args.compatibility, args.log_network)
ctx.log_network = args.log_network
data_filename = args.multidata data_filename = args.multidata
try: try:

View File

@ -1,6 +1,4 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging
import typing import typing
import enum import enum
from json import JSONEncoder, JSONDecoder from json import JSONEncoder, JSONDecoder
@ -94,52 +92,6 @@ def _object_hook(o: typing.Any) -> typing.Any:
decode = JSONDecoder(object_hook=_object_hook).decode decode = JSONDecoder(object_hook=_object_hook).decode
class Node:
endpoints: typing.List
dumper = staticmethod(encode)
loader = staticmethod(decode)
def __init__(self):
self.endpoints = []
super(Node, self).__init__()
self.log_network = 0
def broadcast_all(self, msgs):
msgs = self.dumper(msgs)
for endpoint in self.endpoints:
asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool:
if not endpoint.socket or not endpoint.socket.open:
return False
msg = self.dumper(msgs)
try:
await endpoint.socket.send(msg)
except websockets.ConnectionClosed:
logging.exception(f"Exception during send_msgs, could not send {msg}")
await self.disconnect(endpoint)
else:
if self.log_network:
logging.info(f"Outgoing message: {msg}")
return True
async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
if not endpoint.socket or not endpoint.socket.open:
return False
try:
await endpoint.socket.send(msg)
except websockets.ConnectionClosed:
logging.exception("Exception during send_encoded_msgs")
await self.disconnect(endpoint)
else:
if self.log_network:
logging.info(f"Outgoing message: {msg}")
return True
async def disconnect(self, endpoint):
if endpoint in self.endpoints:
self.endpoints.remove(endpoint)
class Endpoint: class Endpoint:
socket: websockets.WebSocketServerProtocol socket: websockets.WebSocketServerProtocol