Core: some typing and cleaning in `BaseClasses.py` (#3391)

* Core: some typing and cleaning in `BaseClasses.py`

* more backwards `__repr__`

* double-quote string

* remove some end-of-line whitespace
This commit is contained in:
Doug Hoskisson 2024-08-23 17:05:30 -07:00 committed by GitHub
parent 64b654d42e
commit 43cb9611fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 45 additions and 46 deletions

View File

@ -11,8 +11,10 @@ from argparse import Namespace
from collections import Counter, deque
from collections.abc import Collection, MutableSequence
from enum import IntEnum, IntFlag
from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, NamedTuple, Optional, Set, Tuple, \
TypedDict, Union, Type, ClassVar
from typing import (AbstractSet, Any, Callable, ClassVar, Dict, Iterable, Iterator, List, Mapping, NamedTuple,
Optional, Protocol, Set, Tuple, Union, Type)
from typing_extensions import NotRequired, TypedDict
import NetUtils
import Options
@ -22,16 +24,16 @@ if typing.TYPE_CHECKING:
from worlds import AutoWorld
class Group(TypedDict, total=False):
class Group(TypedDict):
name: str
game: str
world: "AutoWorld.World"
players: Set[int]
item_pool: Set[str]
replacement_items: Dict[int, Optional[str]]
local_items: Set[str]
non_local_items: Set[str]
link_replacement: bool
players: AbstractSet[int]
item_pool: NotRequired[Set[str]]
replacement_items: NotRequired[Dict[int, Optional[str]]]
local_items: NotRequired[Set[str]]
non_local_items: NotRequired[Set[str]]
link_replacement: NotRequired[bool]
class ThreadBarrierProxy:
@ -48,6 +50,11 @@ class ThreadBarrierProxy:
"Please use multiworld.per_slot_randoms[player] or randomize ahead of output.")
class HasNameAndPlayer(Protocol):
name: str
player: int
class MultiWorld():
debug_types = False
player_name: Dict[int, str]
@ -156,7 +163,7 @@ class MultiWorld():
self.start_inventory_from_pool: Dict[int, Options.StartInventoryPool] = {}
for player in range(1, players + 1):
def set_player_attr(attr, val):
def set_player_attr(attr: str, val) -> None:
self.__dict__.setdefault(attr, {})[player] = val
set_player_attr('plando_items', [])
set_player_attr('plando_texts', {})
@ -165,13 +172,13 @@ class MultiWorld():
set_player_attr('completion_condition', lambda state: True)
self.worlds = {}
self.per_slot_randoms = Utils.DeprecateDict("Using per_slot_randoms is now deprecated. Please use the "
"world's random object instead (usually self.random)")
"world's random object instead (usually self.random)")
self.plando_options = PlandoOptions.none
def get_all_ids(self) -> Tuple[int, ...]:
return self.player_ids + tuple(self.groups)
def add_group(self, name: str, game: str, players: Set[int] = frozenset()) -> Tuple[int, Group]:
def add_group(self, name: str, game: str, players: AbstractSet[int] = frozenset()) -> Tuple[int, Group]:
"""Create a group with name and return the assigned player ID and group.
If a group of this name already exists, the set of players is extended instead of creating a new one."""
from worlds import AutoWorld
@ -195,7 +202,7 @@ class MultiWorld():
return new_id, new_group
def get_player_groups(self, player) -> Set[int]:
def get_player_groups(self, player: int) -> Set[int]:
return {group_id for group_id, group in self.groups.items() if player in group["players"]}
def set_seed(self, seed: Optional[int] = None, secure: bool = False, name: Optional[str] = None):
@ -258,7 +265,7 @@ class MultiWorld():
"link_replacement": replacement_prio.index(item_link["link_replacement"]),
}
for name, item_link in item_links.items():
for _name, item_link in item_links.items():
current_item_name_groups = AutoWorld.AutoWorldRegister.world_types[item_link["game"]].item_name_groups
pool = set()
local_items = set()
@ -388,7 +395,7 @@ class MultiWorld():
return tuple(world for player, world in self.worlds.items() if
player not in self.groups and self.game[player] == game_name)
def get_name_string_for_object(self, obj) -> str:
def get_name_string_for_object(self, obj: HasNameAndPlayer) -> str:
return obj.name if self.players == 1 else f'{obj.name} ({self.get_player_name(obj.player)})'
def get_player_name(self, player: int) -> str:
@ -439,7 +446,7 @@ class MultiWorld():
def get_items(self) -> List[Item]:
return [loc.item for loc in self.get_filled_locations()] + self.itempool
def find_item_locations(self, item, player: int, resolve_group_locations: bool = False) -> List[Location]:
def find_item_locations(self, item: str, player: int, resolve_group_locations: bool = False) -> List[Location]:
if resolve_group_locations:
player_groups = self.get_player_groups(player)
return [location for location in self.get_locations() if
@ -448,7 +455,7 @@ class MultiWorld():
return [location for location in self.get_locations() if
location.item and location.item.name == item and location.item.player == player]
def find_item(self, item, player: int) -> Location:
def find_item(self, item: str, player: int) -> Location:
return next(location for location in self.get_locations() if
location.item and location.item.name == item and location.item.player == player)
@ -806,7 +813,7 @@ class CollectionState():
if found >= count:
return True
return False
def has_from_list_unique(self, items: Iterable[str], player: int, count: int) -> bool:
"""Returns True if the state contains at least `count` items matching any of the item names from a list.
Ignores duplicates of the same item."""
@ -821,7 +828,7 @@ class CollectionState():
def count_from_list(self, items: Iterable[str], player: int) -> int:
"""Returns the cumulative count of items from a list present in state."""
return sum(self.prog_items[player][item_name] for item_name in items)
def count_from_list_unique(self, items: Iterable[str], player: int) -> int:
"""Returns the cumulative count of items from a list present in state. Ignores duplicates of the same item."""
return sum(self.prog_items[player][item_name] > 0 for item_name in items)
@ -900,7 +907,7 @@ class Entrance:
addresses = None
target = None
def __init__(self, player: int, name: str = '', parent: Region = None):
def __init__(self, player: int, name: str = "", parent: Optional[Region] = None) -> None:
self.name = name
self.parent_region = parent
self.player = player
@ -920,9 +927,6 @@ class Entrance:
region.entrances.append(self)
def __repr__(self):
return self.__str__()
def __str__(self):
multiworld = self.parent_region.multiworld if self.parent_region else None
return multiworld.get_name_string_for_object(self) if multiworld else f'{self.name} (Player {self.player})'
@ -1048,7 +1052,7 @@ class Region:
self.locations.append(location_type(self.player, location, address, self))
def connect(self, connecting_region: Region, name: Optional[str] = None,
rule: Optional[Callable[[CollectionState], bool]] = None) -> entrance_type:
rule: Optional[Callable[[CollectionState], bool]] = None) -> Entrance:
"""
Connects this Region to another Region, placing the provided rule on the connection.
@ -1088,9 +1092,6 @@ class Region:
rules[connecting_region] if rules and connecting_region in rules else None)
def __repr__(self):
return self.__str__()
def __str__(self):
return self.multiworld.get_name_string_for_object(self) if self.multiworld else f'{self.name} (Player {self.player})'
@ -1109,9 +1110,9 @@ class Location:
locked: bool = False
show_in_spoiler: bool = True
progress_type: LocationProgressType = LocationProgressType.DEFAULT
always_allow = staticmethod(lambda state, item: False)
always_allow: Callable[[CollectionState, Item], bool] = staticmethod(lambda state, item: False)
access_rule: Callable[[CollectionState], bool] = staticmethod(lambda state: True)
item_rule = staticmethod(lambda item: True)
item_rule: Callable[[Item], bool] = staticmethod(lambda item: True)
item: Optional[Item] = None
def __init__(self, player: int, name: str = '', address: Optional[int] = None, parent: Optional[Region] = None):
@ -1120,11 +1121,15 @@ class Location:
self.address = address
self.parent_region = parent
def can_fill(self, state: CollectionState, item: Item, check_access=True) -> bool:
return ((self.always_allow(state, item) and item.name not in state.multiworld.worlds[item.player].options.non_local_items)
or ((self.progress_type != LocationProgressType.EXCLUDED or not (item.advancement or item.useful))
and self.item_rule(item)
and (not check_access or self.can_reach(state))))
def can_fill(self, state: CollectionState, item: Item, check_access: bool = True) -> bool:
return ((
self.always_allow(state, item)
and item.name not in state.multiworld.worlds[item.player].options.non_local_items
) or (
(self.progress_type != LocationProgressType.EXCLUDED or not (item.advancement or item.useful))
and self.item_rule(item)
and (not check_access or self.can_reach(state))
))
def can_reach(self, state: CollectionState) -> bool:
# Region.can_reach is just a cache lookup, so placing it first for faster abort on average
@ -1139,9 +1144,6 @@ class Location:
self.locked = True
def __repr__(self):
return self.__str__()
def __str__(self):
multiworld = self.parent_region.multiworld if self.parent_region and self.parent_region.multiworld else None
return multiworld.get_name_string_for_object(self) if multiworld else f'{self.name} (Player {self.player})'
@ -1163,7 +1165,7 @@ class Location:
@property
def native_item(self) -> bool:
"""Returns True if the item in this location matches game."""
return self.item and self.item.game == self.game
return self.item is not None and self.item.game == self.game
@property
def hint_text(self) -> str:
@ -1246,9 +1248,6 @@ class Item:
return hash((self.name, self.player))
def __repr__(self) -> str:
return self.__str__()
def __str__(self) -> str:
if self.location and self.location.parent_region and self.location.parent_region.multiworld:
return self.location.parent_region.multiworld.get_name_string_for_object(self)
return f"{self.name} (Player {self.player})"
@ -1326,9 +1325,9 @@ class Spoiler:
# in the second phase, we cull each sphere such that the game is still beatable,
# reducing each range of influence to the bare minimum required inside it
restore_later = {}
restore_later: Dict[Location, Item] = {}
for num, sphere in reversed(tuple(enumerate(collection_spheres))):
to_delete = set()
to_delete: Set[Location] = set()
for location in sphere:
# we remove the item at location and check if game is still beatable
logging.debug('Checking if %s (Player %d) is required to beat the game.', location.item.name,
@ -1346,7 +1345,7 @@ class Spoiler:
sphere -= to_delete
# second phase, sphere 0
removed_precollected = []
removed_precollected: List[Item] = []
for item in (i for i in chain.from_iterable(multiworld.precollected_items.values()) if i.advancement):
logging.debug('Checking if %s (Player %d) is required to beat the game.', item.name, item.player)
multiworld.precollected_items[item.player].remove(item)
@ -1499,9 +1498,9 @@ class Spoiler:
if self.paths:
outfile.write('\n\nPaths:\n\n')
path_listings = []
path_listings: List[str] = []
for location, path in sorted(self.paths.items()):
path_lines = []
path_lines: List[str] = []
for region, exit in path:
if exit is not None:
path_lines.append("{} -> {}".format(region, exit))