Core: some typing and docs in various parts of the interface (#1060)

* some typing and docs in various parts of the interface

* fix whitespace in docstring

Co-authored-by: black-sliver <59490463+black-sliver@users.noreply.github.com>

* suggested changes from discussion

* remove redundant import

* adjust type for json messages

* for options module detection:
 module.lower().endswith("options")

Co-authored-by: black-sliver <59490463+black-sliver@users.noreply.github.com>
This commit is contained in:
Doug Hoskisson 2022-09-28 14:54:10 -07:00 committed by GitHub
parent 8bc8b412a3
commit c96b6d7b95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 96 additions and 84 deletions

View File

@ -40,6 +40,7 @@ class MultiWorld():
plando_connections: List
worlds: Dict[int, auto_world]
groups: Dict[int, Group]
regions: List[Region]
itempool: List[Item]
is_race: bool = False
precollected_items: Dict[int, List[Item]]
@ -50,6 +51,7 @@ class MultiWorld():
non_local_items: Dict[int, Options.NonLocalItems]
progression_balancing: Dict[int, Options.ProgressionBalancing]
completion_condition: Dict[int, Callable[[CollectionState], bool]]
exclude_locations: Dict[int, Options.ExcludeLocations]
class AttributeProxy():
def __init__(self, rule):
@ -993,7 +995,7 @@ class Entrance:
return False
def connect(self, region: Region, addresses=None, target=None):
def connect(self, region: Region, addresses: Any = None, target: Any = None) -> None:
self.connected_region = region
self.target = target
self.addresses = addresses
@ -1081,7 +1083,7 @@ class Location:
show_in_spoiler: bool = True
progress_type: LocationProgressType = LocationProgressType.DEFAULT
always_allow = staticmethod(lambda item, state: False)
access_rule = staticmethod(lambda state: True)
access_rule: Callable[[CollectionState], bool] = staticmethod(lambda state: True)
item_rule = staticmethod(lambda item: True)
item: Optional[Item] = None

View File

@ -132,12 +132,12 @@ class CommonContext:
# defaults
starting_reconnect_delay: int = 5
current_reconnect_delay: int = starting_reconnect_delay
command_processor: type(CommandProcessor) = ClientCommandProcessor
command_processor: typing.Type[CommandProcessor] = ClientCommandProcessor
ui = None
ui_task: typing.Optional[asyncio.Task] = None
input_task: typing.Optional[asyncio.Task] = None
keep_alive_task: typing.Optional[asyncio.Task] = None
server_task: typing.Optional[asyncio.Task] = None
ui_task: typing.Optional["asyncio.Task[None]"] = None
input_task: typing.Optional["asyncio.Task[None]"] = None
keep_alive_task: typing.Optional["asyncio.Task[None]"] = None
server_task: typing.Optional["asyncio.Task[None]"] = None
server: typing.Optional[Endpoint] = None
server_version: Version = Version(0, 0, 0)
current_energy_link_value: int = 0 # to display in UI, gets set by server
@ -146,7 +146,7 @@ class CommonContext:
# remaining type info
slot_info: typing.Dict[int, NetworkSlot]
server_address: str
server_address: typing.Optional[str]
password: typing.Optional[str]
hint_cost: typing.Optional[int]
player_names: typing.Dict[int, str]
@ -154,6 +154,7 @@ class CommonContext:
# locations
locations_checked: typing.Set[int] # local state
locations_scouted: typing.Set[int]
items_received: typing.List[NetworkItem]
missing_locations: typing.Set[int] # server state
checked_locations: typing.Set[int] # server state
server_locations: typing.Set[int] # all locations the server knows of, missing_location | checked_locations
@ -163,7 +164,7 @@ class CommonContext:
# current message box through kvui
_messagebox = None
def __init__(self, server_address, password):
def __init__(self, server_address: typing.Optional[str], password: typing.Optional[str]) -> None:
# server state
self.server_address = server_address
self.username = None
@ -243,7 +244,8 @@ class CommonContext:
if self.server_task is not None:
await self.server_task
async def send_msgs(self, msgs):
async def send_msgs(self, msgs: typing.List[typing.Any]) -> None:
""" `msgs` JSON serializable """
if not self.server or not self.server.socket.open or self.server.socket.closed:
return
await self.server.socket.send(encode(msgs))
@ -271,7 +273,7 @@ class CommonContext:
logger.info('Enter slot name:')
self.auth = await self.console_input()
async def send_connect(self, **kwargs):
async def send_connect(self, **kwargs: typing.Any) -> None:
payload = {
'cmd': 'Connect',
'password': self.password, 'name': self.auth, 'version': Utils.version_tuple,
@ -282,7 +284,7 @@ class CommonContext:
payload.update(kwargs)
await self.send_msgs([payload])
async def console_input(self):
async def console_input(self) -> str:
self.input_requests += 1
return await self.input_queue.get()
@ -390,7 +392,7 @@ class CommonContext:
# DeathLink hooks
def on_deathlink(self, data: dict):
def on_deathlink(self, data: typing.Dict[str, typing.Any]) -> None:
"""Gets dispatched when a new DeathLink is triggered by another linked player."""
self.last_death_link = max(data["time"], self.last_death_link)
text = data.get("cause", "")
@ -477,7 +479,7 @@ async def keep_alive(ctx: CommonContext, seconds_between_checks=100):
seconds_elapsed = 0
async def server_loop(ctx: CommonContext, address=None):
async def server_loop(ctx: CommonContext, address: typing.Optional[str] = None) -> None:
if ctx.server and ctx.server.socket:
logger.error('Already connected')
return
@ -722,7 +724,7 @@ async def console_loop(ctx: CommonContext):
logger.exception(e)
def get_base_parser(description=None):
def get_base_parser(description: typing.Optional[str] = None):
import argparse
parser = argparse.ArgumentParser(description=description)
parser.add_argument('--connect', default=None, help='Address of the multiworld host.')

View File

@ -100,7 +100,7 @@ _encode = JSONEncoder(
).encode
def encode(obj):
def encode(obj: typing.Any) -> str:
return _encode(_scan_for_TypedTuples(obj))

View File

@ -165,6 +165,7 @@ class FreeText(Option):
class NumericOption(Option[int], numbers.Integral):
default = 0
# note: some of the `typing.Any`` here is a result of unresolved issue in python standards
# `int` is not a `numbers.Integral` according to the official typestubs
# (even though isinstance(5, numbers.Integral) == True)
@ -628,7 +629,7 @@ class VerifyKeys:
class OptionDict(Option[typing.Dict[str, typing.Any]], VerifyKeys):
default = {}
default: typing.Dict[str, typing.Any] = {}
supports_weighting = False
def __init__(self, value: typing.Dict[str, typing.Any]):
@ -659,7 +660,7 @@ class ItemDict(OptionDict):
class OptionList(Option[typing.List[typing.Any]], VerifyKeys):
default = []
default: typing.List[typing.Any] = []
supports_weighting = False
def __init__(self, value: typing.List[typing.Any]):

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import shutil
import json
import bsdiff4
import bsdiff4 # type: ignore
import yaml
import os
import lzma
@ -10,7 +10,7 @@ import threading
import concurrent.futures
import zipfile
import sys
from typing import Tuple, Optional, Dict, Any, Union, BinaryIO
from typing import ClassVar, List, Tuple, Optional, Dict, Any, Union, BinaryIO
import ModuleUpdate
ModuleUpdate.update()
@ -21,10 +21,10 @@ current_patch_version = 5
class AutoPatchRegister(type):
patch_types: Dict[str, APDeltaPatch] = {}
file_endings: Dict[str, APDeltaPatch] = {}
patch_types: ClassVar[Dict[str, AutoPatchRegister]] = {}
file_endings: ClassVar[Dict[str, AutoPatchRegister]] = {}
def __new__(cls, name: str, bases, dct: Dict[str, Any]):
def __new__(cls, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoPatchRegister:
# construct class
new_class = super().__new__(cls, name, bases, dct)
if "game" in dct:
@ -35,10 +35,11 @@ class AutoPatchRegister(type):
return new_class
@staticmethod
def get_handler(file: str) -> Optional[type(APDeltaPatch)]:
def get_handler(file: str) -> Optional[AutoPatchRegister]:
for file_ending, handler in AutoPatchRegister.file_endings.items():
if file.endswith(file_ending):
return handler
return None
class APContainer:
@ -61,34 +62,36 @@ class APContainer:
self.player_name = player_name
self.server = server
def write(self, file: Optional[Union[str, BinaryIO]] = None):
if not self.path and not file:
def write(self, file: Optional[Union[str, BinaryIO]] = None) -> None:
zip_file = file if file else self.path
if not zip_file:
raise FileNotFoundError(f"Cannot write {self.__class__.__name__} due to no path provided.")
with zipfile.ZipFile(file if file else self.path, "w", self.compression_method, True, self.compression_level) \
with zipfile.ZipFile(zip_file, "w", self.compression_method, True, self.compression_level) \
as zf:
if file:
self.path = zf.filename
self.write_contents(zf)
def write_contents(self, opened_zipfile: zipfile.ZipFile):
def write_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
manifest = self.get_manifest()
try:
manifest = json.dumps(manifest)
manifest_str = json.dumps(manifest)
except Exception as e:
raise Exception(f"Manifest {manifest} did not convert to json.") from e
else:
opened_zipfile.writestr("archipelago.json", manifest)
opened_zipfile.writestr("archipelago.json", manifest_str)
def read(self, file: Optional[Union[str, BinaryIO]] = None):
def read(self, file: Optional[Union[str, BinaryIO]] = None) -> None:
"""Read data into patch object. file can be file-like, such as an outer zip file's stream."""
if not self.path and not file:
zip_file = file if file else self.path
if not zip_file:
raise FileNotFoundError(f"Cannot read {self.__class__.__name__} due to no path provided.")
with zipfile.ZipFile(file if file else self.path, "r") as zf:
with zipfile.ZipFile(zip_file, "r") as zf:
if file:
self.path = zf.filename
self.read_contents(zf)
def read_contents(self, opened_zipfile: zipfile.ZipFile):
def read_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
with opened_zipfile.open("archipelago.json", "r") as f:
manifest = json.load(f)
if manifest["compatible_version"] > self.version:
@ -98,7 +101,7 @@ class APContainer:
self.server = manifest["server"]
self.player_name = manifest["player_name"]
def get_manifest(self) -> dict:
def get_manifest(self) -> Dict[str, Any]:
return {
"server": self.server, # allow immediate connection to server in multiworld. Empty string otherwise
"player": self.player,
@ -114,17 +117,17 @@ class APDeltaPatch(APContainer, metaclass=AutoPatchRegister):
"""An APContainer that additionally has delta.bsdiff4
containing a delta patch to get the desired file, often a rom."""
hash = Optional[str] # base checksum of source file
hash: Optional[str] # base checksum of source file
patch_file_ending: str = ""
delta: Optional[bytes] = None
result_file_ending: str = ".sfc"
source_data: bytes
def __init__(self, *args, patched_path: str = "", **kwargs):
def __init__(self, *args: Any, patched_path: str = "", **kwargs: Any) -> None:
self.patched_path = patched_path
super(APDeltaPatch, self).__init__(*args, **kwargs)
def get_manifest(self) -> dict:
def get_manifest(self) -> Dict[str, Any]:
manifest = super(APDeltaPatch, self).get_manifest()
manifest["base_checksum"] = self.hash
manifest["result_file_ending"] = self.result_file_ending
@ -205,15 +208,19 @@ def generate_yaml(patch: bytes, metadata: Optional[dict] = None, game: str = GAM
return patch.encode(encoding="utf-8-sig")
def generate_patch(rom: bytes, metadata: Optional[dict] = None, game: str = GAME_ALTTP) -> bytes:
def generate_patch(rom: bytes, metadata: Optional[Dict[str, Any]] = None, game: str = GAME_ALTTP) -> bytes:
if metadata is None:
metadata = {}
patch = bsdiff4.diff(get_base_rom_data(game), rom)
return generate_yaml(patch, metadata, game)
def create_patch_file(rom_file_to_patch: str, server: str = "", destination: str = None,
player: int = 0, player_name: str = "", game: str = GAME_ALTTP) -> str:
def create_patch_file(rom_file_to_patch: str,
server: str = "",
destination: Optional[str] = None,
player: int = 0,
player_name: str = "",
game: str = GAME_ALTTP) -> str:
meta = {"server": server, # allow immediate connection to server in multiworld. Empty string otherwise
"player_id": player,
"player_name": player_name}
@ -229,19 +236,19 @@ def create_patch_file(rom_file_to_patch: str, server: str = "", destination: str
return target
def create_rom_bytes(patch_file: str, ignore_version: bool = False) -> Tuple[dict, str, bytearray]:
def create_rom_bytes(patch_file: str, ignore_version: bool = False) -> Tuple[Dict[str, Any], str, bytearray]:
data = Utils.parse_yaml(lzma.decompress(load_bytes(patch_file)).decode("utf-8-sig"))
game_name = data["game"]
if not ignore_version and data["compatible_version"] > current_patch_version:
raise RuntimeError("Patch file is incompatible with this patcher, likely an update is required.")
patched_data = bsdiff4.patch(get_base_rom_data(game_name), data["patch"])
patched_data: bytearray = bsdiff4.patch(get_base_rom_data(game_name), data["patch"])
rom_hash = patched_data[int(0x7FC0):int(0x7FD5)]
data["meta"]["hash"] = "".join(chr(x) for x in rom_hash)
target = os.path.splitext(patch_file)[0] + ".sfc"
return data["meta"], target, patched_data
def get_base_rom_data(game: str):
def get_base_rom_data(game: str) -> bytes:
if game == GAME_ALTTP:
from worlds.alttp.Rom import get_base_rom_bytes
elif game == "alttp": # old version for A Link to the Past
@ -260,7 +267,7 @@ def get_base_rom_data(game: str):
return get_base_rom_bytes()
def create_rom_file(patch_file: str) -> Tuple[dict, str]:
def create_rom_file(patch_file: str) -> Tuple[Dict[str, Any], str]:
auto_handler = AutoPatchRegister.get_handler(patch_file)
if auto_handler:
handler: APDeltaPatch = auto_handler(patch_file)
@ -293,7 +300,7 @@ def write_lzma(data: bytes, path: str):
f.write(data)
def read_rom(stream, strip_header=True) -> bytearray:
def read_rom(stream: BinaryIO, strip_header: bool = True) -> bytearray:
"""Reads rom into bytearray and optionally strips off any smc header"""
buffer = bytearray(stream.read())
if strip_header and len(buffer) % 0x400 == 0x200:
@ -321,7 +328,7 @@ if __name__ == "__main__":
elif rom.endswith(".apbp"):
print(f"Applying patch {rom}")
data, target = create_rom_file(rom)
#romfile, adjusted = Utils.get_adjuster_settings(target)
# romfile, adjusted = Utils.get_adjuster_settings(target)
adjuster_settings = Utils.get_adjuster_settings(GAME_ALTTP)
adjusted = False
if adjuster_settings:
@ -385,21 +392,9 @@ if __name__ == "__main__":
if 'server' in data:
Utils.persistent_store("servers", data['hash'], data['server'])
print(f"Host is {data['server']}")
elif rom.endswith(".apm3"):
print(f"Applying patch {rom}")
data, target = create_rom_file(rom)
print(f"Created rom {target}.")
if 'server' in data:
Utils.persistent_store("servers", data['hash'], data['server'])
print(f"Host is {data['server']}")
elif rom.endswith(".apsmz"):
print(f"Applying patch {rom}")
data, target = create_rom_file(rom)
print(f"Created rom {target}.")
if 'server' in data:
Utils.persistent_store("servers", data['hash'], data['server'])
print(f"Host is {data['server']}")
elif rom.endswith(".apdkc3"):
elif rom.endswith(".apm3") \
or rom.endswith(".apsmz") \
or rom.endswith(".apdkc3"):
print(f"Applying patch {rom}")
data, target = create_rom_file(rom)
print(f"Created rom {target}.")
@ -410,8 +405,7 @@ if __name__ == "__main__":
elif rom.endswith(".zip"):
print(f"Updating host in patch files contained in {rom}")
def _handle_zip_file_entry(zfinfo: zipfile.ZipInfo, server: str):
def _handle_zip_file_entry(zfinfo: zipfile.ZipInfo, server: str) -> str:
data = zfr.read(zfinfo)
if zfinfo.filename.endswith(".apbp") or \
zfinfo.filename.endswith(".apm3") or \
@ -421,8 +415,7 @@ if __name__ == "__main__":
zfw.writestr(zfinfo, data)
return zfinfo.filename
futures = []
futures: List[concurrent.futures.Future[str]] = []
with zipfile.ZipFile(rom, "r") as zfr:
updated_zip = os.path.splitext(rom)[0] + "_updated.zip"
with zipfile.ZipFile(updated_zip, "w", compression=zipfile.ZIP_DEFLATED,

View File

@ -217,8 +217,11 @@ def get_public_ipv6() -> str:
return ip
OptionsType = typing.Dict[str, typing.Dict[str, typing.Any]]
@cache_argsless
def get_default_options() -> dict:
def get_default_options() -> OptionsType:
# Refer to host.yaml for comments as to what all these options mean.
options = {
"general_options": {
@ -290,7 +293,7 @@ def get_default_options() -> dict:
return options
def update_options(src: dict, dest: dict, filename: str, keys: list) -> dict:
def update_options(src: dict, dest: dict, filename: str, keys: list) -> OptionsType:
for key, value in src.items():
new_keys = keys.copy()
new_keys.append(key)
@ -310,9 +313,9 @@ def update_options(src: dict, dest: dict, filename: str, keys: list) -> dict:
@cache_argsless
def get_options() -> dict:
def get_options() -> OptionsType:
filenames = ("options.yaml", "host.yaml")
locations = []
locations: typing.List[str] = []
if os.path.join(os.getcwd()) != local_path():
locations += filenames # use files from cwd only if it's not the local_path
locations += [user_path(filename) for filename in filenames]
@ -353,7 +356,7 @@ def persistent_load() -> typing.Dict[str, dict]:
return storage
def get_adjuster_settings(game_name: str):
def get_adjuster_settings(game_name: str) -> typing.Dict[str, typing.Any]:
adjuster_settings = persistent_load().get("adjuster", {}).get(game_name, {})
return adjuster_settings
@ -392,7 +395,8 @@ class RestrictedUnpickler(pickle.Unpickler):
# Options and Plando are unpickled by WebHost -> Generate
if module == "worlds.generic" and name in {"PlandoItem", "PlandoConnection"}:
return getattr(self.generic_properties_module, name)
if module.endswith("Options"):
# pep 8 specifies that modules should have "all-lowercase names" (options, not Options)
if module.lower().endswith("options"):
if module == "Options":
mod = self.options_module
else:

View File

@ -64,7 +64,10 @@ def create():
for game_name, world in AutoWorldRegister.world_types.items():
all_options = {**Options.per_game_common_options, **world.option_definitions}
all_options: typing.Dict[str, Options.AssembleOptions] = {
**Options.per_game_common_options,
**world.option_definitions
}
with open(local_path("WebHostLib", "templates", "options.yaml")) as f:
file_data = f.read()
res = Template(file_data).render(

View File

@ -22,7 +22,7 @@ def upload_zip_to_db(zfile: zipfile.ZipFile, owner=None, meta={"race": False}, s
if not owner:
owner = session["_id"]
infolist = zfile.infolist()
slots = set()
slots: typing.Set[Slot] = set()
spoiler = ""
multidata = None
for file in infolist:

View File

@ -16,6 +16,10 @@ Then run any of the starting point scripts, like Generate.py, and the included M
required modules and after pressing enter proceed to install everything automatically.
After this, you should be able to run the programs.
* With yaml(s) in the `Players` folder, `Generate.py` will generate the multiworld archive.
* `MultiServer.py`, with the filename of the generated archive as a command line parameter, will host the multiworld locally.
* `--log_network` is a command line parameter useful for debugging.
## Windows

View File

@ -20,7 +20,7 @@ class TestBase(unittest.TestCase):
for location in world.get_locations():
if location.name not in excluded:
with self.subTest("Location should be reached", location=location):
self.assertTrue(location.can_reach(state))
self.assertTrue(location.can_reach(state), f"{location.name} unreachable")
with self.subTest("Completion Condition"):
self.assertTrue(world.can_beat_game(state))

View File

@ -3,9 +3,9 @@ from __future__ import annotations
import logging
import sys
import pathlib
from typing import Dict, FrozenSet, Set, Tuple, List, Optional, TextIO, Any, Callable, Union, TYPE_CHECKING
from typing import Dict, FrozenSet, Set, Tuple, List, Optional, TextIO, Any, Callable, Type, Union, TYPE_CHECKING
from Options import Option
from Options import AssembleOptions
from BaseClasses import CollectionState
if TYPE_CHECKING:
@ -13,7 +13,7 @@ if TYPE_CHECKING:
class AutoWorldRegister(type):
world_types: Dict[str, type(World)] = {}
world_types: Dict[str, Type[World]] = {}
def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoWorldRegister:
if "web" in dct:
@ -120,7 +120,7 @@ class World(metaclass=AutoWorldRegister):
"""A World object encompasses a game's Items, Locations, Rules and additional data or functionality required.
A Game should have its own subclass of World in which it defines the required data structures."""
option_definitions: Dict[str, Option[Any]] = {} # link your Options mapping
option_definitions: Dict[str, AssembleOptions] = {} # link your Options mapping
game: str # name the game
topology_present: bool = False # indicate if world type has any meaningful layout/pathing
@ -229,7 +229,8 @@ class World(metaclass=AutoWorldRegister):
pass
def post_fill(self) -> None:
"""Optional Method that is called after regular fill. Can be used to do adjustments before output generation."""
"""Optional Method that is called after regular fill. Can be used to do adjustments before output generation.
This happens before progression balancing, so the items may not be in their final locations yet."""
def generate_output(self, output_directory: str) -> None:
"""This method gets called from a threadpool, do not use world.random here.
@ -237,7 +238,9 @@ class World(metaclass=AutoWorldRegister):
pass
def fill_slot_data(self) -> Dict[str, Any]: # json of WebHostLib.models.Slot
"""Fill in the slot_data field in the Connected network package."""
"""Fill in the `slot_data` field in the `Connected` network package.
This is a way the generator can give custom data to the client.
The client will receive this as JSON in the `Connected` response."""
return {}
def extend_hint_information(self, hint_data: Dict[int, Dict[int, str]]):

View File

@ -1,6 +1,6 @@
import typing
from BaseClasses import LocationProgressType
from BaseClasses import LocationProgressType, MultiWorld
if typing.TYPE_CHECKING:
import BaseClasses
@ -37,14 +37,14 @@ def locality_rules(world, player: int):
forbid_items_for_player(location, world.non_local_items[player].value, player)
def exclusion_rules(world, player: int, exclude_locations: typing.Set[str]):
def exclusion_rules(world: MultiWorld, player: int, exclude_locations: typing.Set[str]) -> None:
for loc_name in exclude_locations:
try:
location = world.get_location(loc_name, player)
except KeyError as e: # failed to find the given location. Check if it's a legitimate location
if loc_name not in world.worlds[player].location_name_to_id:
raise Exception(f"Unable to exclude location {loc_name} in player {player}'s world.") from e
else:
else:
add_item_rule(location, lambda i: not (i.advancement or i.useful))
location.progress_type = LocationProgressType.EXCLUDED