add some typing info to CollectionState (#468)
This commit is contained in:
		
							parent
							
								
									9ecd320c8c
								
							
						
					
					
						commit
						578451fcfa
					
				| 
						 | 
				
			
			@ -7,6 +7,7 @@ import json
 | 
			
		|||
import functools
 | 
			
		||||
from collections import OrderedDict, Counter, deque
 | 
			
		||||
from typing import List, Dict, Optional, Set, Iterable, Union, Any, Tuple, TypedDict, Callable
 | 
			
		||||
import typing  # this can go away when Python 3.8 support is dropped
 | 
			
		||||
import secrets
 | 
			
		||||
import random
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -563,9 +564,20 @@ class MultiWorld():
 | 
			
		|||
        return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
PathValue = Tuple[str, Optional["PathValue"]]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CollectionState():
 | 
			
		||||
    additional_init_functions: List[Callable] = []
 | 
			
		||||
    additional_copy_functions: List[Callable] = []
 | 
			
		||||
    prog_items: typing.Counter[Tuple[str, int]]
 | 
			
		||||
    world: MultiWorld
 | 
			
		||||
    reachable_regions: Dict[int, Set[Region]]
 | 
			
		||||
    blocked_connections: Dict[int, Set[Entrance]]
 | 
			
		||||
    events: Set[Location]
 | 
			
		||||
    path: Dict[Union[Region, Entrance], PathValue]
 | 
			
		||||
    locations_checked: Set[Location]
 | 
			
		||||
    stale: Dict[int, bool]
 | 
			
		||||
    additional_init_functions: List[Callable[[CollectionState, MultiWorld], None]] = []
 | 
			
		||||
    additional_copy_functions: List[Callable[[CollectionState, CollectionState], CollectionState]] = []
 | 
			
		||||
 | 
			
		||||
    def __init__(self, parent: MultiWorld):
 | 
			
		||||
        self.prog_items = Counter()
 | 
			
		||||
| 
						 | 
				
			
			@ -603,6 +615,7 @@ class CollectionState():
 | 
			
		|||
            if new_region in rrp:
 | 
			
		||||
                bc.remove(connection)
 | 
			
		||||
            elif connection.can_reach(self):
 | 
			
		||||
                assert new_region, "tried to search through an Entrance with no Region"
 | 
			
		||||
                rrp.add(new_region)
 | 
			
		||||
                bc.remove(connection)
 | 
			
		||||
                bc.update(new_region.exits)
 | 
			
		||||
| 
						 | 
				
			
			@ -633,7 +646,8 @@ class CollectionState():
 | 
			
		|||
                  spot: Union[Location, Entrance, Region, str],
 | 
			
		||||
                  resolution_hint: Optional[str] = None,
 | 
			
		||||
                  player: Optional[int] = None) -> bool:
 | 
			
		||||
        if not hasattr(spot, "can_reach"):
 | 
			
		||||
        if isinstance(spot, str):
 | 
			
		||||
            assert isinstance(player, int), "can_reach: player is required if spot is str"
 | 
			
		||||
            # try to resolve a name
 | 
			
		||||
            if resolution_hint == 'Location':
 | 
			
		||||
                spot = self.world.get_location(spot, player)
 | 
			
		||||
| 
						 | 
				
			
			@ -644,7 +658,7 @@ class CollectionState():
 | 
			
		|||
                spot = self.world.get_region(spot, player)
 | 
			
		||||
        return spot.can_reach(self)
 | 
			
		||||
 | 
			
		||||
    def sweep_for_events(self, key_only: bool = False, locations: Set[Location] = None):
 | 
			
		||||
    def sweep_for_events(self, key_only: bool = False, locations: Optional[Iterable[Location]] = None) -> None:
 | 
			
		||||
        if locations is None:
 | 
			
		||||
            locations = self.world.get_filled_locations()
 | 
			
		||||
        new_locations = True
 | 
			
		||||
| 
						 | 
				
			
			@ -656,6 +670,7 @@ class CollectionState():
 | 
			
		|||
            new_locations = reachable_events - self.events
 | 
			
		||||
            for event in new_locations:
 | 
			
		||||
                self.events.add(event)
 | 
			
		||||
                assert isinstance(event.item, Item), "tried to collect Event with no Item"
 | 
			
		||||
                self.collect(event.item, True, event)
 | 
			
		||||
 | 
			
		||||
    def has(self, item: str, player: int, count: int = 1) -> bool:
 | 
			
		||||
| 
						 | 
				
			
			@ -670,7 +685,7 @@ class CollectionState():
 | 
			
		|||
    def count(self, item: str, player: int) -> int:
 | 
			
		||||
        return self.prog_items[item, player]
 | 
			
		||||
 | 
			
		||||
    def has_group(self, item_name_group: str, player: int, count: int = 1):
 | 
			
		||||
    def has_group(self, item_name_group: str, player: int, count: int = 1) -> bool:
 | 
			
		||||
        found: int = 0
 | 
			
		||||
        for item_name in self.world.worlds[player].item_name_groups[item_name_group]:
 | 
			
		||||
            found += self.prog_items[item_name, player]
 | 
			
		||||
| 
						 | 
				
			
			@ -678,7 +693,7 @@ class CollectionState():
 | 
			
		|||
                return True
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    def count_group(self, item_name_group: str, player: int):
 | 
			
		||||
    def count_group(self, item_name_group: str, player: int) -> int:
 | 
			
		||||
        found: int = 0
 | 
			
		||||
        for item_name in self.world.worlds[player].item_name_groups[item_name_group]:
 | 
			
		||||
            found += self.prog_items[item_name, player]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue