add some typing info to CollectionState (#468)
This commit is contained in:
parent
9ecd320c8c
commit
578451fcfa
|
@ -7,6 +7,7 @@ import json
|
|||
import functools
|
||||
from collections import OrderedDict, Counter, deque
|
||||
from typing import List, Dict, Optional, Set, Iterable, Union, Any, Tuple, TypedDict, Callable
|
||||
import typing # this can go away when Python 3.8 support is dropped
|
||||
import secrets
|
||||
import random
|
||||
|
||||
|
@ -563,9 +564,20 @@ class MultiWorld():
|
|||
return False
|
||||
|
||||
|
||||
PathValue = Tuple[str, Optional["PathValue"]]
|
||||
|
||||
|
||||
class CollectionState():
|
||||
additional_init_functions: List[Callable] = []
|
||||
additional_copy_functions: List[Callable] = []
|
||||
prog_items: typing.Counter[Tuple[str, int]]
|
||||
world: MultiWorld
|
||||
reachable_regions: Dict[int, Set[Region]]
|
||||
blocked_connections: Dict[int, Set[Entrance]]
|
||||
events: Set[Location]
|
||||
path: Dict[Union[Region, Entrance], PathValue]
|
||||
locations_checked: Set[Location]
|
||||
stale: Dict[int, bool]
|
||||
additional_init_functions: List[Callable[[CollectionState, MultiWorld], None]] = []
|
||||
additional_copy_functions: List[Callable[[CollectionState, CollectionState], CollectionState]] = []
|
||||
|
||||
def __init__(self, parent: MultiWorld):
|
||||
self.prog_items = Counter()
|
||||
|
@ -603,6 +615,7 @@ class CollectionState():
|
|||
if new_region in rrp:
|
||||
bc.remove(connection)
|
||||
elif connection.can_reach(self):
|
||||
assert new_region, "tried to search through an Entrance with no Region"
|
||||
rrp.add(new_region)
|
||||
bc.remove(connection)
|
||||
bc.update(new_region.exits)
|
||||
|
@ -633,7 +646,8 @@ class CollectionState():
|
|||
spot: Union[Location, Entrance, Region, str],
|
||||
resolution_hint: Optional[str] = None,
|
||||
player: Optional[int] = None) -> bool:
|
||||
if not hasattr(spot, "can_reach"):
|
||||
if isinstance(spot, str):
|
||||
assert isinstance(player, int), "can_reach: player is required if spot is str"
|
||||
# try to resolve a name
|
||||
if resolution_hint == 'Location':
|
||||
spot = self.world.get_location(spot, player)
|
||||
|
@ -644,7 +658,7 @@ class CollectionState():
|
|||
spot = self.world.get_region(spot, player)
|
||||
return spot.can_reach(self)
|
||||
|
||||
def sweep_for_events(self, key_only: bool = False, locations: Set[Location] = None):
|
||||
def sweep_for_events(self, key_only: bool = False, locations: Optional[Iterable[Location]] = None) -> None:
|
||||
if locations is None:
|
||||
locations = self.world.get_filled_locations()
|
||||
new_locations = True
|
||||
|
@ -656,6 +670,7 @@ class CollectionState():
|
|||
new_locations = reachable_events - self.events
|
||||
for event in new_locations:
|
||||
self.events.add(event)
|
||||
assert isinstance(event.item, Item), "tried to collect Event with no Item"
|
||||
self.collect(event.item, True, event)
|
||||
|
||||
def has(self, item: str, player: int, count: int = 1) -> bool:
|
||||
|
@ -670,7 +685,7 @@ class CollectionState():
|
|||
def count(self, item: str, player: int) -> int:
|
||||
return self.prog_items[item, player]
|
||||
|
||||
def has_group(self, item_name_group: str, player: int, count: int = 1):
|
||||
def has_group(self, item_name_group: str, player: int, count: int = 1) -> bool:
|
||||
found: int = 0
|
||||
for item_name in self.world.worlds[player].item_name_groups[item_name_group]:
|
||||
found += self.prog_items[item_name, player]
|
||||
|
@ -678,7 +693,7 @@ class CollectionState():
|
|||
return True
|
||||
return False
|
||||
|
||||
def count_group(self, item_name_group: str, player: int):
|
||||
def count_group(self, item_name_group: str, player: int) -> int:
|
||||
found: int = 0
|
||||
for item_name in self.world.worlds[player].item_name_groups[item_name_group]:
|
||||
found += self.prog_items[item_name, player]
|
||||
|
|
Loading…
Reference in New Issue