add some typing info to CollectionState (#468)
This commit is contained in:
		
							parent
							
								
									9ecd320c8c
								
							
						
					
					
						commit
						578451fcfa
					
				| 
						 | 
					@ -7,6 +7,7 @@ import json
 | 
				
			||||||
import functools
 | 
					import functools
 | 
				
			||||||
from collections import OrderedDict, Counter, deque
 | 
					from collections import OrderedDict, Counter, deque
 | 
				
			||||||
from typing import List, Dict, Optional, Set, Iterable, Union, Any, Tuple, TypedDict, Callable
 | 
					from typing import List, Dict, Optional, Set, Iterable, Union, Any, Tuple, TypedDict, Callable
 | 
				
			||||||
 | 
					import typing  # this can go away when Python 3.8 support is dropped
 | 
				
			||||||
import secrets
 | 
					import secrets
 | 
				
			||||||
import random
 | 
					import random
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -563,9 +564,20 @@ class MultiWorld():
 | 
				
			||||||
        return False
 | 
					        return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					PathValue = Tuple[str, Optional["PathValue"]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CollectionState():
 | 
					class CollectionState():
 | 
				
			||||||
    additional_init_functions: List[Callable] = []
 | 
					    prog_items: typing.Counter[Tuple[str, int]]
 | 
				
			||||||
    additional_copy_functions: List[Callable] = []
 | 
					    world: MultiWorld
 | 
				
			||||||
 | 
					    reachable_regions: Dict[int, Set[Region]]
 | 
				
			||||||
 | 
					    blocked_connections: Dict[int, Set[Entrance]]
 | 
				
			||||||
 | 
					    events: Set[Location]
 | 
				
			||||||
 | 
					    path: Dict[Union[Region, Entrance], PathValue]
 | 
				
			||||||
 | 
					    locations_checked: Set[Location]
 | 
				
			||||||
 | 
					    stale: Dict[int, bool]
 | 
				
			||||||
 | 
					    additional_init_functions: List[Callable[[CollectionState, MultiWorld], None]] = []
 | 
				
			||||||
 | 
					    additional_copy_functions: List[Callable[[CollectionState, CollectionState], CollectionState]] = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, parent: MultiWorld):
 | 
					    def __init__(self, parent: MultiWorld):
 | 
				
			||||||
        self.prog_items = Counter()
 | 
					        self.prog_items = Counter()
 | 
				
			||||||
| 
						 | 
					@ -603,6 +615,7 @@ class CollectionState():
 | 
				
			||||||
            if new_region in rrp:
 | 
					            if new_region in rrp:
 | 
				
			||||||
                bc.remove(connection)
 | 
					                bc.remove(connection)
 | 
				
			||||||
            elif connection.can_reach(self):
 | 
					            elif connection.can_reach(self):
 | 
				
			||||||
 | 
					                assert new_region, "tried to search through an Entrance with no Region"
 | 
				
			||||||
                rrp.add(new_region)
 | 
					                rrp.add(new_region)
 | 
				
			||||||
                bc.remove(connection)
 | 
					                bc.remove(connection)
 | 
				
			||||||
                bc.update(new_region.exits)
 | 
					                bc.update(new_region.exits)
 | 
				
			||||||
| 
						 | 
					@ -633,7 +646,8 @@ class CollectionState():
 | 
				
			||||||
                  spot: Union[Location, Entrance, Region, str],
 | 
					                  spot: Union[Location, Entrance, Region, str],
 | 
				
			||||||
                  resolution_hint: Optional[str] = None,
 | 
					                  resolution_hint: Optional[str] = None,
 | 
				
			||||||
                  player: Optional[int] = None) -> bool:
 | 
					                  player: Optional[int] = None) -> bool:
 | 
				
			||||||
        if not hasattr(spot, "can_reach"):
 | 
					        if isinstance(spot, str):
 | 
				
			||||||
 | 
					            assert isinstance(player, int), "can_reach: player is required if spot is str"
 | 
				
			||||||
            # try to resolve a name
 | 
					            # try to resolve a name
 | 
				
			||||||
            if resolution_hint == 'Location':
 | 
					            if resolution_hint == 'Location':
 | 
				
			||||||
                spot = self.world.get_location(spot, player)
 | 
					                spot = self.world.get_location(spot, player)
 | 
				
			||||||
| 
						 | 
					@ -644,7 +658,7 @@ class CollectionState():
 | 
				
			||||||
                spot = self.world.get_region(spot, player)
 | 
					                spot = self.world.get_region(spot, player)
 | 
				
			||||||
        return spot.can_reach(self)
 | 
					        return spot.can_reach(self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def sweep_for_events(self, key_only: bool = False, locations: Set[Location] = None):
 | 
					    def sweep_for_events(self, key_only: bool = False, locations: Optional[Iterable[Location]] = None) -> None:
 | 
				
			||||||
        if locations is None:
 | 
					        if locations is None:
 | 
				
			||||||
            locations = self.world.get_filled_locations()
 | 
					            locations = self.world.get_filled_locations()
 | 
				
			||||||
        new_locations = True
 | 
					        new_locations = True
 | 
				
			||||||
| 
						 | 
					@ -656,6 +670,7 @@ class CollectionState():
 | 
				
			||||||
            new_locations = reachable_events - self.events
 | 
					            new_locations = reachable_events - self.events
 | 
				
			||||||
            for event in new_locations:
 | 
					            for event in new_locations:
 | 
				
			||||||
                self.events.add(event)
 | 
					                self.events.add(event)
 | 
				
			||||||
 | 
					                assert isinstance(event.item, Item), "tried to collect Event with no Item"
 | 
				
			||||||
                self.collect(event.item, True, event)
 | 
					                self.collect(event.item, True, event)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def has(self, item: str, player: int, count: int = 1) -> bool:
 | 
					    def has(self, item: str, player: int, count: int = 1) -> bool:
 | 
				
			||||||
| 
						 | 
					@ -670,7 +685,7 @@ class CollectionState():
 | 
				
			||||||
    def count(self, item: str, player: int) -> int:
 | 
					    def count(self, item: str, player: int) -> int:
 | 
				
			||||||
        return self.prog_items[item, player]
 | 
					        return self.prog_items[item, player]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def has_group(self, item_name_group: str, player: int, count: int = 1):
 | 
					    def has_group(self, item_name_group: str, player: int, count: int = 1) -> bool:
 | 
				
			||||||
        found: int = 0
 | 
					        found: int = 0
 | 
				
			||||||
        for item_name in self.world.worlds[player].item_name_groups[item_name_group]:
 | 
					        for item_name in self.world.worlds[player].item_name_groups[item_name_group]:
 | 
				
			||||||
            found += self.prog_items[item_name, player]
 | 
					            found += self.prog_items[item_name, player]
 | 
				
			||||||
| 
						 | 
					@ -678,7 +693,7 @@ class CollectionState():
 | 
				
			||||||
                return True
 | 
					                return True
 | 
				
			||||||
        return False
 | 
					        return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def count_group(self, item_name_group: str, player: int):
 | 
					    def count_group(self, item_name_group: str, player: int) -> int:
 | 
				
			||||||
        found: int = 0
 | 
					        found: int = 0
 | 
				
			||||||
        for item_name in self.world.worlds[player].item_name_groups[item_name_group]:
 | 
					        for item_name in self.world.worlds[player].item_name_groups[item_name_group]:
 | 
				
			||||||
            found += self.prog_items[item_name, player]
 | 
					            found += self.prog_items[item_name, player]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue