diff --git a/Items.py b/Items.py index 58466a1c..67118c22 100644 --- a/Items.py +++ b/Items.py @@ -62,8 +62,8 @@ item_table = {'Bow': (True, False, None, 0x0B, 'You have\nchosen the\narcher cla 'Progressive Glove': (True, False, None, 0x61, 'a way to lift\nheavier things', 'and the lift upgrade', 'body-building kid', 'some glove for sale', 'fungus for gloves', 'body-building boy lifts again', 'a glove'), 'Silver Arrows': (True, False, None, 0x58, 'Do you fancy\nsilver tipped\narrows?', 'and the ganonsbane', 'ganon-killing kid', 'ganon doom for sale', 'fungus for pork', 'archer boy shines again', 'the silver arrows'), 'Green Pendant': (True, False, 'Crystal', (0x04, 0x38, 0x62, 0x00, 0x69, 0x01), None, None, None, None, None, None, None), - 'Red Pendant': (True, False, 'Crystal', (0x02, 0x34, 0x60, 0x00, 0x69, 0x02), None, None, None, None, None, None, None), - 'Blue Pendant': (True, False, 'Crystal', (0x01, 0x32, 0x60, 0x00, 0x69, 0x03), None, None, None, None, None, None, None), + 'Blue Pendant': (True, False, 'Crystal', (0x02, 0x34, 0x60, 0x00, 0x69, 0x02), None, None, None, None, None, None, None), + 'Red Pendant': (True, False, 'Crystal', (0x01, 0x32, 0x60, 0x00, 0x69, 0x03), None, None, None, None, None, None, None), 'Triforce': (True, False, None, 0x6A, '\n YOU WIN!', 'and the triforce', 'victorious kid', 'victory for sale', 'fungus for the win', 'greedy boy wins game again', 'the Triforce'), 'Power Star': (True, False, None, 0x6B, 'a small victory', 'and the power star', 'star-struck kid', 'star for sale', 'see stars with shroom', 'mario powers up again', 'a Power Star'), 'Triforce Piece': (True, False, None, 0x6C, 'a small victory', 'and the thirdforce', 'triangular kid', 'triangle for sale', 'fungus for triangle', 'wise boy has triangle again', 'a Triforce Piece'), diff --git a/MultiClient.py b/MultiClient.py index 1cf3cc3b..22a86c35 100644 --- a/MultiClient.py +++ b/MultiClient.py @@ -2,10 +2,10 @@ import argparse import asyncio import json import logging -import typing import urllib.parse import atexit +from Utils import get_item_name_from_id, get_location_name_from_address, ReceivedItem exit_func = atexit.register(input, "Press enter to close.") @@ -18,16 +18,10 @@ import websockets import prompt_toolkit from prompt_toolkit.patch_stdout import patch_stdout -import Items import Regions import Utils -class ReceivedItem(typing.NamedTuple): - item: int - location: int - player: int - class Context: def __init__(self, snes_address, server_address, password, found_items): self.snes_address = snes_address @@ -609,7 +603,7 @@ async def server_loop(ctx : Context, address = None): logging.info('Connecting to multiworld server at %s' % address) try: - ctx.socket = await websockets.connect(address, port=port, ping_timeout=None, ping_interval=None) + ctx.socket = await websockets.connect(address, port=port, ping_timeout=60, ping_interval=30) logging.info('Connected') ctx.server_address = address @@ -771,26 +765,89 @@ async def server_auth(ctx: Context, password_requested): ctx.awaiting_rom = False ctx.auth = ctx.rom.copy() await send_msgs(ctx.socket, [['Connect', { - 'password': ctx.password, 'rom': ctx.auth, 'version': [1, 2, 0], 'tags': get_tags(ctx) + 'password': ctx.password, 'rom': ctx.auth, 'version': [1, 3, 0], 'tags': get_tags(ctx) }]]) async def console_input(ctx : Context): ctx.input_requests += 1 return await ctx.input_queue.get() + async def disconnect(ctx: Context): if ctx.socket is not None and not ctx.socket.closed: await ctx.socket.close() if ctx.server_task is not None: await ctx.server_task + async def connect(ctx: Context, address=None): await disconnect(ctx) ctx.server_task = asyncio.create_task(server_loop(ctx, address)) -async def console_loop(ctx : Context): +from MultiServer import CommandProcessor + + +class ClientCommandProcessor(CommandProcessor): + def __init__(self, ctx: Context): + self.ctx = ctx + + def _cmd_exit(self): + """Close connections and client""" + self.ctx.exit_event.set() + + def _cmd_snes(self, snes_address: str = ""): + """Connect to a snes. Optionally include network address of a snes to connect to, otherwise show available devices""" + self.ctx.snes_reconnect_address = None + asyncio.create_task(snes_connect(self.ctx, snes_address if snes_address else self.ctx.snes_address)) + + def _cmd_snes_close(self): + """Close connection to a currently connected snes""" + self.ctx.snes_reconnect_address = None + if self.ctx.snes_socket is not None and not self.ctx.snes_socket.closed: + asyncio.create_task(self.ctx.snes_socket.close()) + + def _cmd_connect(self, address: str = ""): + """Connect to a MultiWorld Server""" + self.ctx.server_address = None + asyncio.create_task(connect(self.ctx, address if address else None)) + + def _cmd_disconnect(self): + """Disconnect from a MultiWorld Server""" + self.ctx.server_address = None + asyncio.create_task(disconnect(self.ctx)) + + def _cmd_received(self): + """List all received items""" + logging.info('Received items:') + for index, item in enumerate(self.ctx.items_received, 1): + logging.info('%s from %s (%s) (%d/%d in list)' % ( + color(get_item_name_from_id(item.item), 'red', 'bold'), + color(self.ctx.player_names[item.player], 'yellow'), + get_location_name_from_address(item.location), index, len(self.ctx.items_received))) + + def _cmd_missing(self): + """List all missing location checks""" + for location in [k for k, v in Regions.location_table.items() if type(v[0]) is int]: + if location not in self.ctx.locations_checked: + logging.info('Missing: ' + location) + + def _cmd_show_items(self, toggle: str = ""): + """Toggle showing of items received across the team""" + if toggle: + self.ctx.found_items = toggle.lower() in {"1", "true", "on"} + else: + self.ctx.found_items = not self.ctx.found_items + logging.info(f"Set showing team items to {self.ctx.found_items}") + asyncio.create_task(send_msgs(self.ctx.socket, [['UpdateTags', get_tags(self.ctx)]])) + + def default(self, raw: str): + asyncio.create_task(send_msgs(self.ctx.socket, [['Say', raw]])) + + +async def console_loop(ctx: Context): session = prompt_toolkit.PromptSession() + commandprocessor = ClientCommandProcessor(ctx) while not ctx.exit_event.is_set(): try: with patch_stdout(): @@ -804,67 +861,11 @@ async def console_loop(ctx : Context): command = input_text.split() if not command: continue - - if command[0][:1] != '/': - asyncio.create_task(send_msgs(ctx.socket, [['Say', input_text]])) - continue - - precommand = command[0][1:] - - if precommand == 'exit': - ctx.exit_event.set() - - elif precommand == 'snes': - ctx.snes_reconnect_address = None - asyncio.create_task(snes_connect(ctx, command[1] if len(command) > 1 else ctx.snes_address)) - - elif precommand in {'snes_close', 'snes_quit'}: - ctx.snes_reconnect_address = None - if ctx.snes_socket is not None and not ctx.snes_socket.closed: - await ctx.snes_socket.close() - - elif precommand in {'connect', 'reconnect'}: - ctx.server_address = None - asyncio.create_task(connect(ctx, command[1] if len(command) > 1 else None)) - - elif precommand == 'disconnect': - ctx.server_address = None - asyncio.create_task(disconnect(ctx)) - - - elif precommand == 'received': - logging.info('Received items:') - for index, item in enumerate(ctx.items_received, 1): - logging.info('%s from %s (%s) (%d/%d in list)' % ( - color(get_item_name_from_id(item.item), 'red', 'bold'), - color(ctx.player_names[item.player], 'yellow'), - get_location_name_from_address(item.location), index, len(ctx.items_received))) - - elif precommand == 'missing': - for location in [k for k, v in Regions.location_table.items() if type(v[0]) is int]: - if location not in ctx.locations_checked: - logging.info('Missing: ' + location) - - elif precommand == "show_items": - if len(command) > 1: - ctx.found_items = command[1].lower() in {"1", "true", "on"} - else: - ctx.found_items = not ctx.found_items - logging.info(f"Set showing team items to {ctx.found_items}") - asyncio.create_task(send_msgs(ctx.socket, [['UpdateTags', get_tags(ctx)]])) - - elif precommand == "license": - with open("LICENSE") as f: - logging.info(f.read()) + commandprocessor(input_text) except Exception as e: logging.exception(e) await snes_flush_writes(ctx) -def get_item_name_from_id(code): - return Items.lookup_id_to_name.get(code, f'Unknown item (ID:{code})') - -def get_location_name_from_address(address): - return Regions.lookup_id_to_name.get(address, f'Unknown location (ID:{address})') async def track_locations(ctx : Context, roomid, roomdata): new_locations = [] diff --git a/MultiServer.py b/MultiServer.py index 1fdd90ec..9656fe4a 100644 --- a/MultiServer.py +++ b/MultiServer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import argparse import asyncio import functools @@ -20,7 +22,7 @@ from fuzzywuzzy import process as fuzzy_process import Items import Regions import Utils -from MultiClient import ReceivedItem, get_item_name_from_id, get_location_name_from_address +from Utils import get_item_name_from_id, get_location_name_from_address, ReceivedItem console_names = frozenset(set(Items.item_table) | set(Regions.location_table)) @@ -29,7 +31,7 @@ class Client: version: typing.List[int] = [0, 0, 0] tags: typing.List[str] = [] - def __init__(self, socket: websockets.server.WebSocketServerProtocol): + def __init__(self, socket: websockets.server.WebSocketServerProtocol, ctx: Context): self.socket = socket self.auth = False self.name = None @@ -38,6 +40,7 @@ class Client: self.send_index = 0 self.tags = [] self.version = [0, 0, 0] + self.messageprocessor = ClientMessageProcessor(ctx, self) @property def wants_item_notification(self): @@ -68,6 +71,7 @@ class Context: self.hints_sent = collections.defaultdict(set) self.item_cheat = item_cheat self.running = True + self.commandprocessor = ServerCommandProcessor(self) def get_save(self) -> dict: return { @@ -141,7 +145,7 @@ def notify_hints(ctx: Context, team: int, hints: typing.List[Utils.Hint]): asyncio.create_task(send_msgs(client.socket, payload)) async def server(websocket, path, ctx: Context): - client = Client(websocket) + client = Client(websocket, ctx) ctx.clients.append(client) try: @@ -169,7 +173,7 @@ async def on_client_connected(ctx: Context, client: Client): # tags are for additional features in the communication. # Name them by feature or fork, as you feel is appropriate. 'tags': ['Berserker'], - 'version': [1, 2, 0] + 'version': [1, 3, 0] }]]) async def on_client_disconnected(ctx: Context, client: Client): @@ -330,6 +334,182 @@ def get_intended_text(input_text: str, possible_answers: typing.Iterable[str]= c return picks[0][0], False, f"Too many close matches, did you mean {picks[0][0]}?" +class CommandMeta(type): + def __new__(cls, name, bases, attrs): + commands = attrs["commands"] = {} + for base in bases: + commands.update(base.commands) + commands.update({name[5:].lower(): method for name, method in attrs.items() if + name.startswith("_cmd_")}) + return super(CommandMeta, cls).__new__(cls, name, bases, attrs) + + +class CommandProcessor(metaclass=CommandMeta): + commands: typing.Dict[str, typing.Callable] + marker = "/" + + def output(self, text: str): + print(text) + + def __call__(self, raw: str): + if not raw: + return + try: + command = raw.split() + basecommand = command[0] + if basecommand[0] == self.marker: + method = self.commands.get(basecommand[1:].lower(), None) + if not method: + self._error_unknown_command(basecommand[1:]) + else: + method(self, *command[1:]) + else: + self.default(raw) + except Exception as e: + self._error_parsing_command(e) + + def get_help_text(self) -> str: + s = "" + for command, method in self.commands.items(): + spec = inspect.signature(method).parameters + argtext = "" + for argname, parameter in spec.items(): + if argname == "self": + continue + + if isinstance(parameter.default, str): + if not parameter.default: + argname = f"[{argname}]" + else: + argname += "=" + parameter.default + argtext += argname + argtext += " " + s += f"{self.marker}{command} {argtext}\n {method.__doc__}\n" + return s + + def _cmd_help(self): + """Returns the help listing""" + self.output(self.get_help_text()) + + def _cmd_license(self): + """Returns the licensing information""" + with open("LICENSE") as f: + self.output(f.read()) + + def default(self, raw: str): + self.output("Echo: " + raw) + + def _error_unknown_command(self, raw: str): + self.output(f"Could not find command {raw}. Known commands: {', '.join(self.commands)}") + + def _error_parsing_command(self, exception: Exception): + self.output(str(exception)) + + +class ClientMessageProcessor(CommandProcessor): + marker = "!" + ctx: Context + + def __init__(self, ctx: Context, client: Client): + self.ctx = ctx + self.client = client + + def output(self, text): + notify_client(self.client, text) + + def default(self, raw: str): + pass # default is client sending just text + + def _cmd_players(self): + """Get information about connected and missing players""" + notify_all(self.ctx, get_connected_players_string(self.ctx)) + + def _cmd_forfeit(self): + """Surrender and send your remaining items out to their recipients""" + forfeit_player(self.ctx, self.client.team, self.client.slot) + + def _cmd_countdown(self, seconds: str = "10"): + """Start a countdown in seconds""" + try: + timer = int(seconds) + except ValueError: + timer = 10 + asyncio.create_task(countdown(self.ctx, timer)) + + def _cmd_getitem(self, *item_name: str): + """Cheat in an item""" + item_name = " ".join(item_name) + if self.ctx.item_cheat: + item_name, usable, response = get_intended_text(item_name, Items.item_table.keys()) + if usable: + new_item = ReceivedItem(Items.item_table[item_name][3], -1, self.client.slot) + get_received_items(self.ctx, self.client.team, self.client.slot).append(new_item) + notify_all(self.ctx, 'Cheat console: sending "' + item_name + '" to ' + self.client.name) + send_new_items(self.ctx) + else: + self.output(response) + else: + self.output("Cheating is disabled.") + + def _cmd_hint(self, *item_or_location: str): + """Use !hint {item_name/location_name}, for example !hint Lamp or !hint Link's House. """ + points_available = self.ctx.location_check_points * len( + self.ctx.location_checks[self.client.team, self.client.slot]) - \ + self.ctx.hint_cost * self.ctx.hints_used[self.client.team, self.client.slot] + item_or_location = " ".join(item_or_location) + if not item_or_location: + self.output(f"A hint costs {self.ctx.hint_cost} points. " + f"You have {points_available} points.") + for item_name in self.ctx.hints_sent[self.client.team, self.client.slot]: + if item_name in Items.item_table: # item name + hints = collect_hints(self.ctx, self.client.team, self.client.slot, item_name) + else: # location name + hints = collect_hints_location(self.ctx, self.client.team, self.client.slot, item_name) + notify_hints(self.ctx, self.client.team, hints) + else: + item_name, usable, response = get_intended_text(item_or_location) + if usable: + if item_name in Items.hint_blacklist: + self.output(f"Sorry, \"{item_name}\" is marked as non-hintable.") + hints = [] + elif item_name in Items.item_table: # item name + hints = collect_hints(self.ctx, self.client.team, self.client.slot, item_name) + else: # location name + hints = collect_hints_location(self.ctx, self.client.team, self.client.slot, item_name) + + if hints: + if item_name in self.ctx.hints_sent[self.client.team, self.client.slot]: + notify_hints(self.ctx, self.client.team, hints) + self.output("Hint was previously used, no points deducted.") + else: + found = 0 + for hint in hints: + found += 1 - hint.found + if not found: + notify_hints(self.ctx, self.client.team, hints) + self.output("No new items found, no points deducted.") + else: + if self.ctx.hint_cost: + can_pay = points_available // (self.ctx.hint_cost * found) >= 1 + else: + can_pay = True + + if can_pay: + self.ctx.hints_used[self.client.team, self.client.slot] += found + self.ctx.hints_sent[self.client.team, self.client.slot].add(item_name) + notify_hints(self.ctx, self.client.team, hints) + save(self.ctx) + else: + notify_client(self.client, f"You can't afford the hint. " + f"You have {points_available} points and need at least " + f"{self.ctx.hint_cost}, " + f"more if multiple items are still to be found.") + else: + self.output("Nothing found. Item/Location may not exist.") + else: + self.output(response) + + async def process_client_cmd(ctx: Context, client: Client, cmd, args): if type(cmd) is not str: await send_msgs(client.socket, [['InvalidCmd']]) @@ -372,133 +552,56 @@ async def process_client_cmd(ctx: Context, client: Client, cmd, args): await send_msgs(client.socket, reply) await on_client_joined(ctx, client) - if not client.auth: - return + if client.auth: + if cmd == 'Sync': + items = get_received_items(ctx, client.team, client.slot) + if items: + client.send_index = len(items) + await send_msgs(client.socket, [['ReceivedItems', (0, tuplize_received_items(items))]]) - if cmd == 'Sync': - items = get_received_items(ctx, client.team, client.slot) - if items: - client.send_index = len(items) - await send_msgs(client.socket, [['ReceivedItems', (0, tuplize_received_items(items))]]) + elif cmd == 'LocationChecks': + if type(args) is not list: + await send_msgs(client.socket, [['InvalidArguments', 'LocationChecks']]) + return + register_location_checks(ctx, client.team, client.slot, args) - if cmd == 'LocationChecks': - if type(args) is not list: - await send_msgs(client.socket, [['InvalidArguments', 'LocationChecks']]) - return - register_location_checks(ctx, client.team, client.slot, args) - - if cmd == 'LocationScouts': - if type(args) is not list: - await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']]) - return - locs = [] - for location in args: - if type(location) is not int or 0 >= location > len(Regions.location_table): + elif cmd == 'LocationScouts': + if type(args) is not list: await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']]) return - loc_name = list(Regions.location_table.keys())[location - 1] - target_item, target_player = ctx.locations[(Regions.location_table[loc_name][0], client.slot)] + locs = [] + for location in args: + if type(location) is not int or 0 >= location > len(Regions.location_table): + await send_msgs(client.socket, [['InvalidArguments', 'LocationScouts']]) + return + loc_name = list(Regions.location_table.keys())[location - 1] + target_item, target_player = ctx.locations[(Regions.location_table[loc_name][0], client.slot)] - replacements = {'SmallKey': 0xA2, 'BigKey': 0x9D, 'Compass': 0x8D, 'Map': 0x7D} - item_type = [i[2] for i in Items.item_table.values() if type(i[3]) is int and i[3] == target_item] - if item_type: - target_item = replacements.get(item_type[0], target_item) + replacements = {'SmallKey': 0xA2, 'BigKey': 0x9D, 'Compass': 0x8D, 'Map': 0x7D} + item_type = [i[2] for i in Items.item_table.values() if type(i[3]) is int and i[3] == target_item] + if item_type: + target_item = replacements.get(item_type[0], target_item) - locs.append([loc_name, location, target_item, target_player]) + locs.append([loc_name, location, target_item, target_player]) - logging.info(f"{client.name} in team {client.team+1} scouted {', '.join([l[0] for l in locs])}") - await send_msgs(client.socket, [['LocationInfo', [l[1:] for l in locs]]]) + # logging.info(f"{client.name} in team {client.team+1} scouted {', '.join([l[0] for l in locs])}") + await send_msgs(client.socket, [['LocationInfo', [l[1:] for l in locs]]]) - if cmd == 'UpdateTags': - if not args or type(args) is not list: - await send_msgs(client.socket, [['InvalidArguments', 'UpdateTags']]) - return - client.tags = args + elif cmd == 'UpdateTags': + if not args or type(args) is not list: + await send_msgs(client.socket, [['InvalidArguments', 'UpdateTags']]) + return + client.tags = args - if cmd == 'Say': - if type(args) is not str or not args.isprintable(): - await send_msgs(client.socket, [['InvalidArguments', 'Say']]) - return + if cmd == 'Say': + if type(args) is not str or not args.isprintable(): + await send_msgs(client.socket, [['InvalidArguments', 'Say']]) + return - notify_all(ctx, client.name + ': ' + args) + notify_all(ctx, client.name + ': ' + args) + print(args) + client.messageprocessor(args) - if args.startswith('!players'): - notify_all(ctx, get_connected_players_string(ctx)) - elif args.startswith('!forfeit'): - forfeit_player(ctx, client.team, client.slot) - elif args.startswith('!countdown'): - try: - timer = int(args.split()[1]) - except (IndexError, ValueError): - timer = 10 - asyncio.create_task(countdown(ctx, timer)) - elif args.startswith('!getitem') and ctx.item_cheat: - item_name = args[9:].lower() - item_name, usable, response = get_intended_text(item_name, Items.item_table.keys()) - if usable: - new_item = ReceivedItem(Items.item_table[item_name][3], -1, client.slot) - get_received_items(ctx, client.team, client.slot).append(new_item) - notify_all(ctx, 'Cheat console: sending "' + item_name + '" to ' + client.name) - send_new_items(ctx) - else: - notify_client(client, response) - elif args.startswith("!hint"): - points_available = ctx.location_check_points * len(ctx.location_checks[client.team, client.slot]) - \ - ctx.hint_cost * ctx.hints_used[client.team, client.slot] - item_name = args[6:] - - if not item_name: - notify_client(client, "Use !hint {item_name/location_name}, " - "for example !hint Lamp or !hint Link's House. " - f"A hint costs {ctx.hint_cost} points. " - f"You have {points_available} points.") - for item_name in ctx.hints_sent[client.team, client.slot]: - if item_name in Items.item_table: # item name - hints = collect_hints(ctx, client.team, client.slot, item_name) - else: # location name - hints = collect_hints_location(ctx, client.team, client.slot, item_name) - notify_hints(ctx, client.team, hints) - else: - item_name, usable, response = get_intended_text(item_name) - if usable: - if item_name in Items.hint_blacklist: - notify_client(client, f"Sorry, \"{item_name}\" is marked as non-hintable.") - hints = [] - elif item_name in Items.item_table: # item name - hints = collect_hints(ctx, client.team, client.slot, item_name) - else: # location name - hints = collect_hints_location(ctx, client.team, client.slot, item_name) - - if hints: - if item_name in ctx.hints_sent[client.team, client.slot]: - notify_hints(ctx, client.team, hints) - notify_client(client, "Hint was previously used, no points deducted.") - else: - found = 0 - for hint in hints: - found += 1 - hint.found - if not found: - notify_hints(ctx, client.team, hints) - notify_client(client, "No new items found, no points deducted.") - else: - if ctx.hint_cost: - can_pay = points_available // (ctx.hint_cost * found) >= 1 - else: - can_pay = True - - if can_pay: - ctx.hints_used[client.team, client.slot] += found - ctx.hints_sent[client.team, client.slot].add(item_name) - notify_hints(ctx, client.team, hints) - save(ctx) - else: - notify_client(client, f"You can't afford the hint. " - f"You have {points_available} points and need at least {ctx.hint_cost}, " - f"more if multiple items are still to be found.") - else: - notify_client(client, "Nothing found. Item/Location may not exist.") - else: - notify_client(client, response) def set_password(ctx: Context, password): @@ -506,48 +609,6 @@ def set_password(ctx: Context, password): logging.warning('Password set to ' + password if password else 'Password disabled') -class CommandProcessor(): - commands: typing.Dict[str, typing.Callable] - - def __init__(self): - self.commands = {name[5:].lower(): method for name, method in inspect.getmembers(self) if - name.startswith("_cmd_")} - - def output(self, text: str): - print(text) - - def __call__(self, raw: str): - if not raw: - return - command = raw.split() - basecommand = command[0] - if basecommand[0] == "/": - method = self.commands.get(basecommand[1:].lower(), None) - if not method: - self._error_unknown_command(basecommand[1:]) - else: - method(*command[1:]) - else: - self.default(raw) - - def get_help_text(self) -> str: - s = "" - for command, method in self.commands.items(): - spec = inspect.signature(method).parameters - s += f"/{command} {' '.join(spec)}\n {method.__doc__}\n" - return s - - def _cmd_help(self): - """Returns the help listing""" - self.output(self.get_help_text()) - - def default(self, raw: str): - self.output("Echo: " + raw) - - def _error_unknown_command(self, raw: str): - self.output(f"Could not find command {raw}. Known commands: {', '.join(self.commands)}") - - class ServerCommandProcessor(CommandProcessor): ctx: Context @@ -577,9 +638,10 @@ class ServerCommandProcessor(CommandProcessor): asyncio.create_task(self.ctx.server.ws_server._close()) self.ctx.running = False - def _cmd_password(self, new_password: str = ""): + def _cmd_password(self, *new_password: str): """Set the server password. Leave the password text empty to remove the password""" - set_password(self.ctx, new_password if new_password else None) + + set_password(self.ctx, " ".join(new_password) if new_password else None) def _cmd_forfeit(self, player_name: str): """Send out the remaining items from a player's game to their intended recipients""" @@ -634,12 +696,11 @@ class ServerCommandProcessor(CommandProcessor): async def console(ctx: Context): session = prompt_toolkit.PromptSession() - cmd_processor = ServerCommandProcessor(ctx) while ctx.running: with patch_stdout(): input_text = await session.prompt_async() try: - cmd_processor(input_text) + ctx.commandprocessor(input_text) except: import traceback traceback.print_exc() diff --git a/Patch.py b/Patch.py index def899e8..5b7e791f 100644 --- a/Patch.py +++ b/Patch.py @@ -3,6 +3,10 @@ import yaml import os import lzma import hashlib +import threading +import concurrent.futures +import zipfile +import sys from typing import Tuple, Optional import Utils @@ -11,7 +15,7 @@ from Rom import JAP10HASH, read_rom base_rom_bytes = None -def get_base_rom_bytes(file_name: str = None) -> bytes: +def get_base_rom_bytes(file_name: str = "") -> bytes: global base_rom_bytes if not base_rom_bytes: options = Utils.get_options() @@ -29,7 +33,7 @@ def get_base_rom_bytes(file_name: str = None) -> bytes: return base_rom_bytes -def generate_patch(rom: bytes, metadata=None) -> bytes: +def generate_patch(rom: bytes, metadata: Optional[dict] = None) -> bytes: if metadata is None: metadata = {} patch = bsdiff4.diff(get_base_rom_bytes(), rom) @@ -47,7 +51,7 @@ def create_patch_file(rom_file_to_patch: str, server: str = "") -> str: return target -def create_rom_file(patch_file) -> Tuple[dict, str]: +def create_rom_file(patch_file: str) -> Tuple[dict, str]: data = Utils.parse_yaml(lzma.decompress(load_bytes(patch_file)).decode("utf-8-sig")) patched_data = bsdiff4.patch(get_base_rom_bytes(), data["patch"]) target = os.path.splitext(patch_file)[0] + ".sfc" @@ -56,7 +60,14 @@ def create_rom_file(patch_file) -> Tuple[dict, str]: return data["meta"], target -def load_bytes(path: str): +def update_patch_data(patch_data: bytes, server: str = "") -> bytes: + data = Utils.parse_yaml(lzma.decompress(patch_data).decode("utf-8-sig")) + data["meta"]["server"] = server + bytes = generate_patch(data["patch"], data["meta"]) + return lzma.compress(bytes) + + +def load_bytes(path: str) -> bytes: with open(path, "rb") as f: return f.read() @@ -66,11 +77,51 @@ def write_lzma(data: bytes, path: str): f.write(data) if __name__ == "__main__": - ipv4 = Utils.get_public_ipv4() - import sys + host = Utils.get_public_ipv4() + options = Utils.get_options()['server_options'] + if options['host']: + host = options['host'] + address = f"{host}:{options['port']}" + ziplock = threading.Lock() + print(f"Host for patches to be created is {address}") + + Processed = False for rom in sys.argv: - if rom.endswith(".sfc"): - print(f"Creating patch for {rom}") - result = create_patch_file(rom, ipv4) - print(f"Created patch {result}") + try: + if rom.endswith(".sfc"): + print(f"Creating patch for {rom}") + result = create_patch_file(rom, address) + print(f"Created patch {result}") + elif rom.endswith(".bmbp"): + print(f"Applying patch {rom}") + data, target = create_rom_file(rom) + print(f"Created rom {target}.") + if 'server' in data: + print(f"Host is {data['server']}") + elif rom.endswith(".zip"): + print(f"Updating host in patch files contained in {rom}") + def _handle_zip_file_entry(zfinfo : zipfile.ZipInfo, server: str): + data = zfr.read(zfinfo) + if zfinfo.filename.endswith(".bmbp"): + data = update_patch_data(data, server) + with ziplock: + zfw.writestr(zfinfo, data) + return zfinfo.filename + + with concurrent.futures.ThreadPoolExecutor() as pool: + futures = [] + with zipfile.ZipFile(rom, "r") as zfr: + updated_zip = os.path.splitext(rom)[0] + "_updated.zip" + with zipfile.ZipFile(updated_zip, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zfw: + for zfname in zfr.namelist(): + futures.append(pool.submit(_handle_zip_file_entry, zfr.getinfo(zfname), address)) + for future in futures: + print(f"File {future.result()} added to {os.path.split(updated_zip)[1]}") + + except: + import traceback + traceback.print_exc() + + if Processed: + input("Press enter to close.") \ No newline at end of file diff --git a/Utils.py b/Utils.py index 95c4ea05..068d4078 100644 --- a/Utils.py +++ b/Utils.py @@ -184,3 +184,19 @@ def get_options() -> dict: else: raise FileNotFoundError(f"Could not find {locations[1]} to load options.") return get_options.options + + +def get_item_name_from_id(code): + import Items + return Items.lookup_id_to_name.get(code, f'Unknown item (ID:{code})') + + +def get_location_name_from_address(address): + import Regions + return Regions.lookup_id_to_name.get(address, f'Unknown location (ID:{address})') + + +class ReceivedItem(typing.NamedTuple): + item: int + location: int + player: int diff --git a/easy.yaml b/easy.yaml index f71954d1..e89443bf 100644 --- a/easy.yaml +++ b/easy.yaml @@ -160,7 +160,7 @@ rom: sprite: # Enter the name of your preferred sprite and weight it appropriately random: 0 randomonhit: 0 - link: 1 + link: 1 # to add other sprites: open the gui/Creator, go to adjust, select a sprite and write down the name the gui calls it disablemusic: off # If "on", all in-game music will be disabled extendedmsu: on # If "on", V31 extended MSU support will be available quickswap: # Enable switching items by pressing the L+R shoulder buttons