MultiServer: categorize methods
This commit is contained in:
parent
acbca78e2d
commit
f4f043ac87
1
Main.py
1
Main.py
|
@ -377,7 +377,6 @@ def main(args, seed=None):
|
||||||
f.write(bytes([1])) # version of format
|
f.write(bytes([1])) # version of format
|
||||||
f.write(multidata)
|
f.write(multidata)
|
||||||
|
|
||||||
|
|
||||||
multidata_task = pool.submit(write_multidata)
|
multidata_task = pool.submit(write_multidata)
|
||||||
if not check_accessibility_task.result():
|
if not check_accessibility_task.result():
|
||||||
if not world.can_beat_game():
|
if not world.can_beat_game():
|
||||||
|
|
183
MultiServer.py
183
MultiServer.py
|
@ -31,10 +31,11 @@ from worlds import network_data_package, lookup_any_item_id_to_name, lookup_any_
|
||||||
import Utils
|
import Utils
|
||||||
from Utils import get_item_name_from_id, get_location_name_from_id, \
|
from Utils import get_item_name_from_id, get_location_name_from_id, \
|
||||||
version_tuple, restricted_loads, Version
|
version_tuple, restricted_loads, Version
|
||||||
from NetUtils import Node, Endpoint, ClientStatus, NetworkItem, decode, NetworkPlayer
|
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer
|
||||||
|
|
||||||
colorama.init()
|
colorama.init()
|
||||||
|
|
||||||
|
|
||||||
class Client(Endpoint):
|
class Client(Endpoint):
|
||||||
version = Version(0, 0, 0)
|
version = Version(0, 0, 0)
|
||||||
tags: typing.List[str] = []
|
tags: typing.List[str] = []
|
||||||
|
@ -50,9 +51,14 @@ class Client(Endpoint):
|
||||||
self.messageprocessor = client_message_processor(ctx, self)
|
self.messageprocessor = client_message_processor(ctx, self)
|
||||||
self.ctx = weakref.ref(ctx)
|
self.ctx = weakref.ref(ctx)
|
||||||
|
|
||||||
|
|
||||||
team_slot = typing.Tuple[int, int]
|
team_slot = typing.Tuple[int, int]
|
||||||
|
|
||||||
class Context(Node):
|
|
||||||
|
class Context:
|
||||||
|
dumper = staticmethod(encode)
|
||||||
|
loader = staticmethod(decode)
|
||||||
|
|
||||||
simple_options = {"hint_cost": int,
|
simple_options = {"hint_cost": int,
|
||||||
"location_check_points": int,
|
"location_check_points": int,
|
||||||
"server_password": str,
|
"server_password": str,
|
||||||
|
@ -64,8 +70,10 @@ class Context(Node):
|
||||||
|
|
||||||
def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int,
|
def __init__(self, host: str, port: int, server_password: str, password: str, location_check_points: int,
|
||||||
hint_cost: int, item_cheat: bool, forfeit_mode: str = "disabled", remaining_mode: str = "disabled",
|
hint_cost: int, item_cheat: bool, forfeit_mode: str = "disabled", remaining_mode: str = "disabled",
|
||||||
auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2):
|
auto_shutdown: typing.SupportsFloat = 0, compatibility: int = 2, log_network: bool = False):
|
||||||
super(Context, self).__init__()
|
super(Context, self).__init__()
|
||||||
|
self.log_network = log_network
|
||||||
|
self.endpoints = []
|
||||||
self.compatibility: int = compatibility
|
self.compatibility: int = compatibility
|
||||||
self.shutdown_task = None
|
self.shutdown_task = None
|
||||||
self.data_filename = None
|
self.data_filename = None
|
||||||
|
@ -113,10 +121,70 @@ class Context(Node):
|
||||||
self.seed_name = ""
|
self.seed_name = ""
|
||||||
self.random = random.Random()
|
self.random = random.Random()
|
||||||
|
|
||||||
def get_hint_cost(self, slot):
|
# General networking
|
||||||
if self.hint_cost:
|
|
||||||
return max(0, int(self.hint_cost * 0.01 * len(self.locations[slot])))
|
async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool:
|
||||||
return 0
|
if not endpoint.socket or not endpoint.socket.open:
|
||||||
|
return False
|
||||||
|
msg = self.dumper(msgs)
|
||||||
|
try:
|
||||||
|
await endpoint.socket.send(msg)
|
||||||
|
except websockets.ConnectionClosed:
|
||||||
|
logging.exception(f"Exception during send_msgs, could not send {msg}")
|
||||||
|
await self.disconnect(endpoint)
|
||||||
|
else:
|
||||||
|
if self.log_network:
|
||||||
|
logging.info(f"Outgoing message: {msg}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
|
||||||
|
if not endpoint.socket or not endpoint.socket.open:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
await endpoint.socket.send(msg)
|
||||||
|
except websockets.ConnectionClosed:
|
||||||
|
logging.exception("Exception during send_encoded_msgs")
|
||||||
|
await self.disconnect(endpoint)
|
||||||
|
else:
|
||||||
|
if self.log_network:
|
||||||
|
logging.info(f"Outgoing message: {msg}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def broadcast_all(self, msgs):
|
||||||
|
msgs = self.dumper(msgs)
|
||||||
|
for endpoint in self.endpoints:
|
||||||
|
if endpoint.auth:
|
||||||
|
asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
|
||||||
|
|
||||||
|
def broadcast_team(self, team, msgs):
|
||||||
|
msgs = self.dumper(msgs)
|
||||||
|
for client in self.endpoints:
|
||||||
|
if client.auth and client.team == team:
|
||||||
|
asyncio.create_task(self.send_encoded_msgs(client, msgs))
|
||||||
|
|
||||||
|
async def disconnect(self, endpoint):
|
||||||
|
if endpoint in self.endpoints:
|
||||||
|
self.endpoints.remove(endpoint)
|
||||||
|
await on_client_disconnected(self, endpoint)
|
||||||
|
|
||||||
|
# text
|
||||||
|
|
||||||
|
def notify_all(self, text):
|
||||||
|
logging.info("Notice (all): %s" % text)
|
||||||
|
self.broadcast_all([{"cmd": "Print", "text": text}])
|
||||||
|
|
||||||
|
def notify_client(self, client: Client, text: str):
|
||||||
|
if not client.auth:
|
||||||
|
return
|
||||||
|
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
|
||||||
|
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
|
||||||
|
|
||||||
|
def notify_client_multiple(self, client: Client, texts: typing.List[str]):
|
||||||
|
if not client.auth:
|
||||||
|
return
|
||||||
|
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
|
||||||
|
|
||||||
|
# loading
|
||||||
|
|
||||||
def load(self, multidatapath: str, use_embedded_server_options: bool = False):
|
def load(self, multidatapath: str, use_embedded_server_options: bool = False):
|
||||||
if multidatapath.lower().endswith(".zip"):
|
if multidatapath.lower().endswith(".zip"):
|
||||||
|
@ -177,27 +245,7 @@ class Context(Node):
|
||||||
server_options = decoded_obj.get("server_options", {})
|
server_options = decoded_obj.get("server_options", {})
|
||||||
self._set_options(server_options)
|
self._set_options(server_options)
|
||||||
|
|
||||||
def get_players_package(self):
|
# saving
|
||||||
return [NetworkPlayer(t, p, self.get_aliased_name(t, p), n) for (t, p), n in self.player_names.items()]
|
|
||||||
|
|
||||||
def _set_options(self, server_options: dict):
|
|
||||||
for key, value in server_options.items():
|
|
||||||
data_type = self.simple_options.get(key, None)
|
|
||||||
if data_type is not None:
|
|
||||||
if value not in {False, True, None}: # some can be boolean OR text, such as password
|
|
||||||
try:
|
|
||||||
value = data_type(value)
|
|
||||||
except Exception as e:
|
|
||||||
try:
|
|
||||||
raise Exception(f"Could not set server option {key}, skipping.") from e
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception(e)
|
|
||||||
logging.debug(f"Setting server option {key} to {value} from supplied multidata")
|
|
||||||
setattr(self, key, value)
|
|
||||||
elif key == "disable_item_cheat":
|
|
||||||
self.item_cheat = not bool(value)
|
|
||||||
else:
|
|
||||||
logging.debug(f"Unrecognized server option {key}")
|
|
||||||
|
|
||||||
def save(self, now=False) -> bool:
|
def save(self, now=False) -> bool:
|
||||||
if self.saving:
|
if self.saving:
|
||||||
|
@ -228,7 +276,7 @@ class Context(Node):
|
||||||
import os
|
import os
|
||||||
name, ext = os.path.splitext(self.data_filename)
|
name, ext = os.path.splitext(self.data_filename)
|
||||||
self.save_filename = name + '.apsave' if ext.lower() in ('.archipelago','.zip') \
|
self.save_filename = name + '.apsave' if ext.lower() in ('.archipelago','.zip') \
|
||||||
else self.data_filename + '_' + 'apsave'
|
else self.data_filename + '_' + 'apsave'
|
||||||
try:
|
try:
|
||||||
with open(self.save_filename, 'rb') as f:
|
with open(self.save_filename, 'rb') as f:
|
||||||
save_data = restricted_loads(zlib.decompress(f.read()))
|
save_data = restricted_loads(zlib.decompress(f.read()))
|
||||||
|
@ -256,13 +304,6 @@ class Context(Node):
|
||||||
import atexit
|
import atexit
|
||||||
atexit.register(self._save, True) # make sure we save on exit too
|
atexit.register(self._save, True) # make sure we save on exit too
|
||||||
|
|
||||||
def recheck_hints(self):
|
|
||||||
for team, slot in self.hints:
|
|
||||||
self.hints[team, slot] = {
|
|
||||||
hint.re_check(self, team) for hint in
|
|
||||||
self.hints[team, slot]
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_save(self) -> dict:
|
def get_save(self) -> dict:
|
||||||
self.recheck_hints()
|
self.recheck_hints()
|
||||||
d = {
|
d = {
|
||||||
|
@ -303,43 +344,48 @@ class Context(Node):
|
||||||
logging.info(f'Loaded save file with {sum([len(p) for p in self.received_items.values()])} received items '
|
logging.info(f'Loaded save file with {sum([len(p) for p in self.received_items.values()])} received items '
|
||||||
f'for {len(self.received_items)} players')
|
f'for {len(self.received_items)} players')
|
||||||
|
|
||||||
|
# rest
|
||||||
|
|
||||||
|
def get_hint_cost(self, slot):
|
||||||
|
if self.hint_cost:
|
||||||
|
return max(0, int(self.hint_cost * 0.01 * len(self.locations[slot])))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def recheck_hints(self):
|
||||||
|
for team, slot in self.hints:
|
||||||
|
self.hints[team, slot] = {
|
||||||
|
hint.re_check(self, team) for hint in
|
||||||
|
self.hints[team, slot]
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_players_package(self):
|
||||||
|
return [NetworkPlayer(t, p, self.get_aliased_name(t, p), n) for (t, p), n in self.player_names.items()]
|
||||||
|
|
||||||
|
def _set_options(self, server_options: dict):
|
||||||
|
for key, value in server_options.items():
|
||||||
|
data_type = self.simple_options.get(key, None)
|
||||||
|
if data_type is not None:
|
||||||
|
if value not in {False, True, None}: # some can be boolean OR text, such as password
|
||||||
|
try:
|
||||||
|
value = data_type(value)
|
||||||
|
except Exception as e:
|
||||||
|
try:
|
||||||
|
raise Exception(f"Could not set server option {key}, skipping.") from e
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(e)
|
||||||
|
logging.debug(f"Setting server option {key} to {value} from supplied multidata")
|
||||||
|
setattr(self, key, value)
|
||||||
|
elif key == "disable_item_cheat":
|
||||||
|
self.item_cheat = not bool(value)
|
||||||
|
else:
|
||||||
|
logging.debug(f"Unrecognized server option {key}")
|
||||||
|
|
||||||
def get_aliased_name(self, team: int, slot: int):
|
def get_aliased_name(self, team: int, slot: int):
|
||||||
if (team, slot) in self.name_aliases:
|
if (team, slot) in self.name_aliases:
|
||||||
return f"{self.name_aliases[team, slot]} ({self.player_names[team, slot]})"
|
return f"{self.name_aliases[team, slot]} ({self.player_names[team, slot]})"
|
||||||
else:
|
else:
|
||||||
return self.player_names[team, slot]
|
return self.player_names[team, slot]
|
||||||
|
|
||||||
def notify_all(self, text):
|
|
||||||
logging.info("Notice (all): %s" % text)
|
|
||||||
self.broadcast_all([{"cmd": "Print", "text": text}])
|
|
||||||
|
|
||||||
def notify_client(self, client: Client, text: str):
|
|
||||||
if not client.auth:
|
|
||||||
return
|
|
||||||
logging.info("Notice (Player %s in team %d): %s" % (client.name, client.team + 1, text))
|
|
||||||
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text}]))
|
|
||||||
|
|
||||||
def notify_client_multiple(self, client: Client, texts: typing.List[str]):
|
|
||||||
if not client.auth:
|
|
||||||
return
|
|
||||||
asyncio.create_task(self.send_msgs(client, [{"cmd": "Print", "text": text} for text in texts]))
|
|
||||||
|
|
||||||
def broadcast_team(self, team, msgs):
|
|
||||||
msgs = self.dumper(msgs)
|
|
||||||
for client in self.endpoints:
|
|
||||||
if client.auth and client.team == team:
|
|
||||||
asyncio.create_task(self.send_encoded_msgs(client, msgs))
|
|
||||||
|
|
||||||
def broadcast_all(self, msgs):
|
|
||||||
msgs = self.dumper(msgs)
|
|
||||||
for endpoint in self.endpoints:
|
|
||||||
if endpoint.auth:
|
|
||||||
asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
|
|
||||||
|
|
||||||
async def disconnect(self, endpoint):
|
|
||||||
await super(Context, self).disconnect(endpoint)
|
|
||||||
await on_client_disconnected(self, endpoint)
|
|
||||||
|
|
||||||
|
|
||||||
def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint]):
|
def notify_hints(ctx: Context, team: int, hints: typing.List[NetUtils.Hint]):
|
||||||
concerns = collections.defaultdict(list)
|
concerns = collections.defaultdict(list)
|
||||||
|
@ -1431,8 +1477,7 @@ async def main(args: argparse.Namespace):
|
||||||
|
|
||||||
ctx = Context(args.host, args.port, args.server_password, args.password, args.location_check_points,
|
ctx = Context(args.host, args.port, args.server_password, args.password, args.location_check_points,
|
||||||
args.hint_cost, not args.disable_item_cheat, args.forfeit_mode, args.remaining_mode,
|
args.hint_cost, not args.disable_item_cheat, args.forfeit_mode, args.remaining_mode,
|
||||||
args.auto_shutdown, args.compatibility)
|
args.auto_shutdown, args.compatibility, args.log_network)
|
||||||
ctx.log_network = args.log_network
|
|
||||||
data_filename = args.multidata
|
data_filename = args.multidata
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
48
NetUtils.py
48
NetUtils.py
|
@ -1,6 +1,4 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import typing
|
import typing
|
||||||
import enum
|
import enum
|
||||||
from json import JSONEncoder, JSONDecoder
|
from json import JSONEncoder, JSONDecoder
|
||||||
|
@ -94,52 +92,6 @@ def _object_hook(o: typing.Any) -> typing.Any:
|
||||||
decode = JSONDecoder(object_hook=_object_hook).decode
|
decode = JSONDecoder(object_hook=_object_hook).decode
|
||||||
|
|
||||||
|
|
||||||
class Node:
|
|
||||||
endpoints: typing.List
|
|
||||||
dumper = staticmethod(encode)
|
|
||||||
loader = staticmethod(decode)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.endpoints = []
|
|
||||||
super(Node, self).__init__()
|
|
||||||
self.log_network = 0
|
|
||||||
|
|
||||||
def broadcast_all(self, msgs):
|
|
||||||
msgs = self.dumper(msgs)
|
|
||||||
for endpoint in self.endpoints:
|
|
||||||
asyncio.create_task(self.send_encoded_msgs(endpoint, msgs))
|
|
||||||
|
|
||||||
async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bool:
|
|
||||||
if not endpoint.socket or not endpoint.socket.open:
|
|
||||||
return False
|
|
||||||
msg = self.dumper(msgs)
|
|
||||||
try:
|
|
||||||
await endpoint.socket.send(msg)
|
|
||||||
except websockets.ConnectionClosed:
|
|
||||||
logging.exception(f"Exception during send_msgs, could not send {msg}")
|
|
||||||
await self.disconnect(endpoint)
|
|
||||||
else:
|
|
||||||
if self.log_network:
|
|
||||||
logging.info(f"Outgoing message: {msg}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
|
|
||||||
if not endpoint.socket or not endpoint.socket.open:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
await endpoint.socket.send(msg)
|
|
||||||
except websockets.ConnectionClosed:
|
|
||||||
logging.exception("Exception during send_encoded_msgs")
|
|
||||||
await self.disconnect(endpoint)
|
|
||||||
else:
|
|
||||||
if self.log_network:
|
|
||||||
logging.info(f"Outgoing message: {msg}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def disconnect(self, endpoint):
|
|
||||||
if endpoint in self.endpoints:
|
|
||||||
self.endpoints.remove(endpoint)
|
|
||||||
|
|
||||||
|
|
||||||
class Endpoint:
|
class Endpoint:
|
||||||
socket: websockets.WebSocketServerProtocol
|
socket: websockets.WebSocketServerProtocol
|
||||||
|
|
Loading…
Reference in New Issue