diff --git a/BaseClasses.py b/BaseClasses.py index 61f3f8f6..29264f34 100644 --- a/BaseClasses.py +++ b/BaseClasses.py @@ -11,8 +11,10 @@ from argparse import Namespace from collections import Counter, deque from collections.abc import Collection, MutableSequence from enum import IntEnum, IntFlag -from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, NamedTuple, Optional, Set, Tuple, \ - TypedDict, Union, Type, ClassVar +from typing import (AbstractSet, Any, Callable, ClassVar, Dict, Iterable, Iterator, List, Mapping, NamedTuple, + Optional, Protocol, Set, Tuple, Union, Type) + +from typing_extensions import NotRequired, TypedDict import NetUtils import Options @@ -22,16 +24,16 @@ if typing.TYPE_CHECKING: from worlds import AutoWorld -class Group(TypedDict, total=False): +class Group(TypedDict): name: str game: str world: "AutoWorld.World" - players: Set[int] - item_pool: Set[str] - replacement_items: Dict[int, Optional[str]] - local_items: Set[str] - non_local_items: Set[str] - link_replacement: bool + players: AbstractSet[int] + item_pool: NotRequired[Set[str]] + replacement_items: NotRequired[Dict[int, Optional[str]]] + local_items: NotRequired[Set[str]] + non_local_items: NotRequired[Set[str]] + link_replacement: NotRequired[bool] class ThreadBarrierProxy: @@ -48,6 +50,11 @@ class ThreadBarrierProxy: "Please use multiworld.per_slot_randoms[player] or randomize ahead of output.") +class HasNameAndPlayer(Protocol): + name: str + player: int + + class MultiWorld(): debug_types = False player_name: Dict[int, str] @@ -156,7 +163,7 @@ class MultiWorld(): self.start_inventory_from_pool: Dict[int, Options.StartInventoryPool] = {} for player in range(1, players + 1): - def set_player_attr(attr, val): + def set_player_attr(attr: str, val) -> None: self.__dict__.setdefault(attr, {})[player] = val set_player_attr('plando_items', []) set_player_attr('plando_texts', {}) @@ -165,13 +172,13 @@ class MultiWorld(): set_player_attr('completion_condition', lambda state: True) self.worlds = {} self.per_slot_randoms = Utils.DeprecateDict("Using per_slot_randoms is now deprecated. Please use the " - "world's random object instead (usually self.random)") + "world's random object instead (usually self.random)") self.plando_options = PlandoOptions.none def get_all_ids(self) -> Tuple[int, ...]: return self.player_ids + tuple(self.groups) - def add_group(self, name: str, game: str, players: Set[int] = frozenset()) -> Tuple[int, Group]: + def add_group(self, name: str, game: str, players: AbstractSet[int] = frozenset()) -> Tuple[int, Group]: """Create a group with name and return the assigned player ID and group. If a group of this name already exists, the set of players is extended instead of creating a new one.""" from worlds import AutoWorld @@ -195,7 +202,7 @@ class MultiWorld(): return new_id, new_group - def get_player_groups(self, player) -> Set[int]: + def get_player_groups(self, player: int) -> Set[int]: return {group_id for group_id, group in self.groups.items() if player in group["players"]} def set_seed(self, seed: Optional[int] = None, secure: bool = False, name: Optional[str] = None): @@ -258,7 +265,7 @@ class MultiWorld(): "link_replacement": replacement_prio.index(item_link["link_replacement"]), } - for name, item_link in item_links.items(): + for _name, item_link in item_links.items(): current_item_name_groups = AutoWorld.AutoWorldRegister.world_types[item_link["game"]].item_name_groups pool = set() local_items = set() @@ -388,7 +395,7 @@ class MultiWorld(): return tuple(world for player, world in self.worlds.items() if player not in self.groups and self.game[player] == game_name) - def get_name_string_for_object(self, obj) -> str: + def get_name_string_for_object(self, obj: HasNameAndPlayer) -> str: return obj.name if self.players == 1 else f'{obj.name} ({self.get_player_name(obj.player)})' def get_player_name(self, player: int) -> str: @@ -439,7 +446,7 @@ class MultiWorld(): def get_items(self) -> List[Item]: return [loc.item for loc in self.get_filled_locations()] + self.itempool - def find_item_locations(self, item, player: int, resolve_group_locations: bool = False) -> List[Location]: + def find_item_locations(self, item: str, player: int, resolve_group_locations: bool = False) -> List[Location]: if resolve_group_locations: player_groups = self.get_player_groups(player) return [location for location in self.get_locations() if @@ -448,7 +455,7 @@ class MultiWorld(): return [location for location in self.get_locations() if location.item and location.item.name == item and location.item.player == player] - def find_item(self, item, player: int) -> Location: + def find_item(self, item: str, player: int) -> Location: return next(location for location in self.get_locations() if location.item and location.item.name == item and location.item.player == player) @@ -806,7 +813,7 @@ class CollectionState(): if found >= count: return True return False - + def has_from_list_unique(self, items: Iterable[str], player: int, count: int) -> bool: """Returns True if the state contains at least `count` items matching any of the item names from a list. Ignores duplicates of the same item.""" @@ -821,7 +828,7 @@ class CollectionState(): def count_from_list(self, items: Iterable[str], player: int) -> int: """Returns the cumulative count of items from a list present in state.""" return sum(self.prog_items[player][item_name] for item_name in items) - + def count_from_list_unique(self, items: Iterable[str], player: int) -> int: """Returns the cumulative count of items from a list present in state. Ignores duplicates of the same item.""" return sum(self.prog_items[player][item_name] > 0 for item_name in items) @@ -900,7 +907,7 @@ class Entrance: addresses = None target = None - def __init__(self, player: int, name: str = '', parent: Region = None): + def __init__(self, player: int, name: str = "", parent: Optional[Region] = None) -> None: self.name = name self.parent_region = parent self.player = player @@ -920,9 +927,6 @@ class Entrance: region.entrances.append(self) def __repr__(self): - return self.__str__() - - def __str__(self): multiworld = self.parent_region.multiworld if self.parent_region else None return multiworld.get_name_string_for_object(self) if multiworld else f'{self.name} (Player {self.player})' @@ -1048,7 +1052,7 @@ class Region: self.locations.append(location_type(self.player, location, address, self)) def connect(self, connecting_region: Region, name: Optional[str] = None, - rule: Optional[Callable[[CollectionState], bool]] = None) -> entrance_type: + rule: Optional[Callable[[CollectionState], bool]] = None) -> Entrance: """ Connects this Region to another Region, placing the provided rule on the connection. @@ -1088,9 +1092,6 @@ class Region: rules[connecting_region] if rules and connecting_region in rules else None) def __repr__(self): - return self.__str__() - - def __str__(self): return self.multiworld.get_name_string_for_object(self) if self.multiworld else f'{self.name} (Player {self.player})' @@ -1109,9 +1110,9 @@ class Location: locked: bool = False show_in_spoiler: bool = True progress_type: LocationProgressType = LocationProgressType.DEFAULT - always_allow = staticmethod(lambda state, item: False) + always_allow: Callable[[CollectionState, Item], bool] = staticmethod(lambda state, item: False) access_rule: Callable[[CollectionState], bool] = staticmethod(lambda state: True) - item_rule = staticmethod(lambda item: True) + item_rule: Callable[[Item], bool] = staticmethod(lambda item: True) item: Optional[Item] = None def __init__(self, player: int, name: str = '', address: Optional[int] = None, parent: Optional[Region] = None): @@ -1120,11 +1121,15 @@ class Location: self.address = address self.parent_region = parent - def can_fill(self, state: CollectionState, item: Item, check_access=True) -> bool: - return ((self.always_allow(state, item) and item.name not in state.multiworld.worlds[item.player].options.non_local_items) - or ((self.progress_type != LocationProgressType.EXCLUDED or not (item.advancement or item.useful)) - and self.item_rule(item) - and (not check_access or self.can_reach(state)))) + def can_fill(self, state: CollectionState, item: Item, check_access: bool = True) -> bool: + return (( + self.always_allow(state, item) + and item.name not in state.multiworld.worlds[item.player].options.non_local_items + ) or ( + (self.progress_type != LocationProgressType.EXCLUDED or not (item.advancement or item.useful)) + and self.item_rule(item) + and (not check_access or self.can_reach(state)) + )) def can_reach(self, state: CollectionState) -> bool: # Region.can_reach is just a cache lookup, so placing it first for faster abort on average @@ -1139,9 +1144,6 @@ class Location: self.locked = True def __repr__(self): - return self.__str__() - - def __str__(self): multiworld = self.parent_region.multiworld if self.parent_region and self.parent_region.multiworld else None return multiworld.get_name_string_for_object(self) if multiworld else f'{self.name} (Player {self.player})' @@ -1163,7 +1165,7 @@ class Location: @property def native_item(self) -> bool: """Returns True if the item in this location matches game.""" - return self.item and self.item.game == self.game + return self.item is not None and self.item.game == self.game @property def hint_text(self) -> str: @@ -1246,9 +1248,6 @@ class Item: return hash((self.name, self.player)) def __repr__(self) -> str: - return self.__str__() - - def __str__(self) -> str: if self.location and self.location.parent_region and self.location.parent_region.multiworld: return self.location.parent_region.multiworld.get_name_string_for_object(self) return f"{self.name} (Player {self.player})" @@ -1326,9 +1325,9 @@ class Spoiler: # in the second phase, we cull each sphere such that the game is still beatable, # reducing each range of influence to the bare minimum required inside it - restore_later = {} + restore_later: Dict[Location, Item] = {} for num, sphere in reversed(tuple(enumerate(collection_spheres))): - to_delete = set() + to_delete: Set[Location] = set() for location in sphere: # we remove the item at location and check if game is still beatable logging.debug('Checking if %s (Player %d) is required to beat the game.', location.item.name, @@ -1346,7 +1345,7 @@ class Spoiler: sphere -= to_delete # second phase, sphere 0 - removed_precollected = [] + removed_precollected: List[Item] = [] for item in (i for i in chain.from_iterable(multiworld.precollected_items.values()) if i.advancement): logging.debug('Checking if %s (Player %d) is required to beat the game.', item.name, item.player) multiworld.precollected_items[item.player].remove(item) @@ -1499,9 +1498,9 @@ class Spoiler: if self.paths: outfile.write('\n\nPaths:\n\n') - path_listings = [] + path_listings: List[str] = [] for location, path in sorted(self.paths.items()): - path_lines = [] + path_lines: List[str] = [] for region, exit in path: if exit is not None: path_lines.append("{} -> {}".format(region, exit))