MultiServer: categorize methods
This commit is contained in:
parent
acbca78e2d
commit
f4f043ac87
1
Main.py
1
Main.py
|
@ -377,7 +377,6 @@ def main(args, seed=None):
|
|||
f.write(bytes([1])) # version of format
|
||||
f.write(multidata)
|
||||
|
||||
|
||||
multidata_task = pool.submit(write_multidata)
|
||||
if not check_accessibility_task.result():
|
||||
if not world.can_beat_game():
|
||||
|
|
183
MultiServer.py
183
MultiServer.py
|
@ -31,10 +31,11 @@ from worlds import network_data_package, lookup_any_item_id_to_name, lookup_any_
|
|||
import Utils
|
||||
from Utils import get_item_name_from_id, get_location_name_from_id, \
|
||||
version_tuple, restricted_loads, Version
|
||||
from NetUtils import Node, Endpoint, ClientStatus, NetworkItem, decode, NetworkPlayer
|
||||
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer
|
||||
|
||||
colorama.init()
|
||||
|
||||
|
||||
class Client(Endpoint):
|
||||
version = Version(0, 0, 0)
|
||||
tags: typing.List[str] = []
|
||||
|
@ -50,9 +51,14 @@ class Client(Endpoint):
|
|||
self.messageprocessor = client_message_processor(ctx, self)
|
||||
self.ctx = weakref.ref(ctx)
|
||||
|
||||
|
||||
team_slot = typing.Tuple[int, int]
|
||||
|
||||
class Context(Node):
|
||||
|
||||
class Context:
|
||||
dumper = staticmethod(encode)
|
||||
loader = staticmethod(decode)
|
||||
|
||||
simple_options = {"hint_cost": int,
|
||||
"location_check_points": int,
|
||||
"server_password": str,
|
||||
|
@ -64,8 +70,10 @@ class Context(Node):
|
|||
|
||||
def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int,
|
||||
hint_cost: int, item_cheat: bool, forfeit_mode: str = "disabled", remaining_mode: str = "disabled",
|
||||
auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2):
|
||||
auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2, log_network: bool = False):
|
||||
super(Context, self).__init__()
|
||||
self.log_network = log_network
|
||||
self.endpoints = []
|
||||
self.compatibility: int = compatibility
|
||||
self.shutdown_task = None
|
||||
self.data_filename = None
|
||||
|
@ -113,10 +121,70 @@ class Context(Node):
|
|||
self.seed_name = ""
|
||||
self.random = random.Random()
|
||||
|
||||
def get_hint_cost(self, slot):
|
||||
if self.hint_cost:
|
||||
return max(0, int(self.hint_cost * 0.01 * len(self.locations[slot])))
|
||||
return 0
|
||||
# General networking
|
||||
|
||||
async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool:
|
||||
if not endpoint.socket or not endpoint.socket.open:
|
||||
return False
|
||||
msg = self.dumper(msgs)
|
||||
try:
|
||||
await endpoint.socket.send(msg)
|
||||
except websockets.ConnectionClosed:
|
||||
logging.exception(f"Exception during send_msgs, could not send {msg}")
|
||||
await self.disconnect(endpoint)
|
||||
else:
|
||||
if self.log_network:
|
||||
logging.info(f"Outgoing message: {msg}")
|
||||
return True
|
||||
|
||||
async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
|
||||
if not endpoint.socket or not endpoint.socket.open:
|
||||
return False
|
||||
try:
|
||||
await endpoint.socket.send(msg)
|
||||
except websockets.ConnectionClosed:
|
||||
logging.exception("Exception during send_encoded_msgs")
|
||||
await self.disconnect(endpoint)
|
||||
else:
|
||||
if self.log_network:
|
||||
logging.info(f"Outgoing message: {msg}")
|
||||
return True
|
||||
|
||||
def broadcast_all(self, msgs):
|
||||
msgs = self.dumper(msgs)
|
||||
for endpoint in self.endpoints:
|
||||
if endpoint.auth:
|
||||
asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
|
||||
|
||||
def broadcast_team(self, team, msgs):
|
||||
msgs = self.dumper(msgs)
|
||||
for client in self.endpoints:
|
||||
if client.auth and client.team == team:
|
||||
asyncio.create_task(self.send_encoded_msgs(client, msgs))
|
||||
|
||||
async def disconnect(self, endpoint):
|
||||
if endpoint in self.endpoints:
|
||||
self.endpoints.remove(endpoint)
|
||||
await on_client_disconnected(self, endpoint)
|
||||
|
||||
# text
|
||||
|
||||
def notify_all(self, text):
|
||||
logging.info("Notice (all): %s" % text)
|
||||
self.broadcast_all([{"cmd": "Print", "text": text}])
|
||||
|
||||
def notify_client(self, client: Client, text: str):
|
||||
if not client.auth:
|
||||
return
|
||||
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
|
||||
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
|
||||
|
||||
def notify_client_multiple(self, client: Client, texts: typing.List[str]):
|
||||
if not client.auth:
|
||||
return
|
||||
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
|
||||
|
||||
# loading
|
||||
|
||||
def load(self, multidatapath: str, use_embedded_server_options: bool = False):
|
||||
if multidatapath.lower().endswith(".zip"):
|
||||
|
@ -177,27 +245,7 @@ class Context(Node):
|
|||
server_options = decoded_obj.get("server_options", {})
|
||||
self._set_options(server_options)
|
||||
|
||||
def get_players_package(self):
|
||||
return [NetworkPlayer(t, p, self.get_aliased_name(t, p), n) for (t, p), n in self.player_names.items()]
|
||||
|
||||
def _set_options(self, server_options: dict):
|
||||
for key, value in server_options.items():
|
||||
data_type = self.simple_options.get(key, None)
|
||||
if data_type is not None:
|
||||
if value not in {False, True, None}: # some can be boolean OR text, such as password
|
||||
try:
|
||||
value = data_type(value)
|
||||
except Exception as e:
|
||||
try:
|
||||
raise Exception(f"Could not set server option {key}, skipping.") from e
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
logging.debug(f"Setting server option {key} to {value} from supplied multidata")
|
||||
setattr(self, key, value)
|
||||
elif key == "disable_item_cheat":
|
||||
self.item_cheat = not bool(value)
|
||||
else:
|
||||
logging.debug(f"Unrecognized server option {key}")
|
||||
# saving
|
||||
|
||||
def save(self, now=False) -> bool:
|
||||
if self.saving:
|
||||
|
@ -228,7 +276,7 @@ class Context(Node):
|
|||
import os
|
||||
name, ext = os.path.splitext(self.data_filename)
|
||||
self.save_filename = name + '.apsave' if ext.lower() in ('.archipelago','.zip') \
|
||||
else self.data_filename + '_' + 'apsave'
|
||||
else self.data_filename + '_' + 'apsave'
|
||||
try:
|
||||
with open(self.save_filename, 'rb') as f:
|
||||
save_data = restricted_loads(zlib.decompress(f.read()))
|
||||
|
@ -256,13 +304,6 @@ class Context(Node):
|
|||
import atexit
|
||||
atexit.register(self._save, True) # make sure we save on exit too
|
||||
|
||||
def recheck_hints(self):
|
||||
for team, slot in self.hints:
|
||||
self.hints[team, slot] = {
|
||||
hint.re_check(self, team) for hint in
|
||||
self.hints[team, slot]
|
||||
}
|
||||
|
||||
def get_save(self) -> dict:
|
||||
self.recheck_hints()
|
||||
d = {
|
||||
|
@ -303,43 +344,48 @@ class Context(Node):
|
|||
logging.info(f'Loaded save file with {sum([len(p) for p in self.received_items.values()])} received items '
|
||||
f'for {len(self.received_items)} players')
|
||||
|
||||
# rest
|
||||
|
||||
def get_hint_cost(self, slot):
|
||||
if self.hint_cost:
|
||||
return max(0, int(self.hint_cost * 0.01 * len(self.locations[slot])))
|
||||
return 0
|
||||
|
||||
def recheck_hints(self):
|
||||
for team, slot in self.hints:
|
||||
self.hints[team, slot] = {
|
||||
hint.re_check(self, team) for hint in
|
||||
self.hints[team, slot]
|
||||
}
|
||||
|
||||
def get_players_package(self):
|
||||
return [NetworkPlayer(t, p, self.get_aliased_name(t, p), n) for (t, p), n in self.player_names.items()]
|
||||
|
||||
def _set_options(self, server_options: dict):
|
||||
for key, value in server_options.items():
|
||||
data_type = self.simple_options.get(key, None)
|
||||
if data_type is not None:
|
||||
if value not in {False, True, None}: # some can be boolean OR text, such as password
|
||||
try:
|
||||
value = data_type(value)
|
||||
except Exception as e:
|
||||
try:
|
||||
raise Exception(f"Could not set server option {key}, skipping.") from e
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
logging.debug(f"Setting server option {key} to {value} from supplied multidata")
|
||||
setattr(self, key, value)
|
||||
elif key == "disable_item_cheat":
|
||||
self.item_cheat = not bool(value)
|
||||
else:
|
||||
logging.debug(f"Unrecognized server option {key}")
|
||||
|
||||
def get_aliased_name(self, team: int, slot: int):
|
||||
if (team, slot) in self.name_aliases:
|
||||
return f"{self.name_aliases[team, slot]} ({self.player_names[team, slot]})"
|
||||
else:
|
||||
return self.player_names[team, slot]
|
||||
|
||||
def notify_all(self, text):
|
||||
logging.info("Notice (all): %s" % text)
|
||||
self.broadcast_all([{"cmd": "Print", "text": text}])
|
||||
|
||||
def notify_client(self, client: Client, text: str):
|
||||
if not client.auth:
|
||||
return
|
||||
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
|
||||
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
|
||||
|
||||
def notify_client_multiple(self, client: Client, texts: typing.List[str]):
|
||||
if not client.auth:
|
||||
return
|
||||
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
|
||||
|
||||
def broadcast_team(self, team, msgs):
|
||||
msgs = self.dumper(msgs)
|
||||
for client in self.endpoints:
|
||||
if client.auth and client.team == team:
|
||||
asyncio.create_task(self.send_encoded_msgs(client, msgs))
|
||||
|
||||
def broadcast_all(self, msgs):
|
||||
msgs = self.dumper(msgs)
|
||||
for endpoint in self.endpoints:
|
||||
if endpoint.auth:
|
||||
asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
|
||||
|
||||
async def disconnect(self, endpoint):
|
||||
await super(Context, self).disconnect(endpoint)
|
||||
await on_client_disconnected(self, endpoint)
|
||||
|
||||
|
||||
def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint]):
|
||||
concerns = collections.defaultdict(list)
|
||||
|
@ -1431,8 +1477,7 @@ async def main(args: argparse.Namespace):
|
|||
|
||||
ctx = Context(args.host, args.port, args.server_password, args.password, args.location_check_points,
|
||||
args.hint_cost, not args.disable_item_cheat, args.forfeit_mode, args.remaining_mode,
|
||||
args.auto_shutdown, args.compatibility)
|
||||
ctx.log_network = args.log_network
|
||||
args.auto_shutdown, args.compatibility, args.log_network)
|
||||
data_filename = args.multidata
|
||||
|
||||
try:
|
||||
|
|
48
NetUtils.py
48
NetUtils.py
|
@ -1,6 +1,4 @@
|
|||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
import enum
|
||||
from json import JSONEncoder, JSONDecoder
|
||||
|
@ -94,52 +92,6 @@ def _object_hook(o: typing.Any) -> typing.Any:
|
|||
decode = JSONDecoder(object_hook=_object_hook).decode
|
||||
|
||||
|
||||
class Node:
|
||||
endpoints: typing.List
|
||||
dumper = staticmethod(encode)
|
||||
loader = staticmethod(decode)
|
||||
|
||||
def __init__(self):
|
||||
self.endpoints = []
|
||||
super(Node, self).__init__()
|
||||
self.log_network = 0
|
||||
|
||||
def broadcast_all(self, msgs):
|
||||
msgs = self.dumper(msgs)
|
||||
for endpoint in self.endpoints:
|
||||
asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
|
||||
|
||||
async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool:
|
||||
if not endpoint.socket or not endpoint.socket.open:
|
||||
return False
|
||||
msg = self.dumper(msgs)
|
||||
try:
|
||||
await endpoint.socket.send(msg)
|
||||
except websockets.ConnectionClosed:
|
||||
logging.exception(f"Exception during send_msgs, could not send {msg}")
|
||||
await self.disconnect(endpoint)
|
||||
else:
|
||||
if self.log_network:
|
||||
logging.info(f"Outgoing message: {msg}")
|
||||
return True
|
||||
|
||||
async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
|
||||
if not endpoint.socket or not endpoint.socket.open:
|
||||
return False
|
||||
try:
|
||||
await endpoint.socket.send(msg)
|
||||
except websockets.ConnectionClosed:
|
||||
logging.exception("Exception during send_encoded_msgs")
|
||||
await self.disconnect(endpoint)
|
||||
else:
|
||||
if self.log_network:
|
||||
logging.info(f"Outgoing message: {msg}")
|
||||
return True
|
||||
|
||||
async def disconnect(self, endpoint):
|
||||
if endpoint in self.endpoints:
|
||||
self.endpoints.remove(endpoint)
|
||||
|
||||
|
||||
class Endpoint:
|
||||
socket: websockets.WebSocketServerProtocol
|
||||
|
|
Loading…
Reference in New Issue