Core: make state.prog_items a `Dict[int, Counter[str]]` (#2407)

This commit is contained in:
Aaron Wagener 2023-11-02 00:41:20 -05:00 committed by GitHub
parent 19dc0720ba
commit 5669579374
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 46 additions and 51 deletions

View File

@ -605,7 +605,7 @@ PathValue = Tuple[str, Optional["PathValue"]]
class CollectionState(): class CollectionState():
prog_items: typing.Counter[Tuple[str, int]] prog_items: Dict[int, Counter[str]]
multiworld: MultiWorld multiworld: MultiWorld
reachable_regions: Dict[int, Set[Region]] reachable_regions: Dict[int, Set[Region]]
blocked_connections: Dict[int, Set[Entrance]] blocked_connections: Dict[int, Set[Entrance]]
@ -617,7 +617,7 @@ class CollectionState():
additional_copy_functions: List[Callable[[CollectionState, CollectionState], CollectionState]] = [] additional_copy_functions: List[Callable[[CollectionState, CollectionState], CollectionState]] = []
def __init__(self, parent: MultiWorld): def __init__(self, parent: MultiWorld):
self.prog_items = Counter() self.prog_items = {player: Counter() for player in parent.player_ids}
self.multiworld = parent self.multiworld = parent
self.reachable_regions = {player: set() for player in parent.get_all_ids()} self.reachable_regions = {player: set() for player in parent.get_all_ids()}
self.blocked_connections = {player: set() for player in parent.get_all_ids()} self.blocked_connections = {player: set() for player in parent.get_all_ids()}
@ -665,7 +665,7 @@ class CollectionState():
def copy(self) -> CollectionState: def copy(self) -> CollectionState:
ret = CollectionState(self.multiworld) ret = CollectionState(self.multiworld)
ret.prog_items = self.prog_items.copy() ret.prog_items = copy.deepcopy(self.prog_items)
ret.reachable_regions = {player: copy.copy(self.reachable_regions[player]) for player in ret.reachable_regions = {player: copy.copy(self.reachable_regions[player]) for player in
self.reachable_regions} self.reachable_regions}
ret.blocked_connections = {player: copy.copy(self.blocked_connections[player]) for player in ret.blocked_connections = {player: copy.copy(self.blocked_connections[player]) for player in
@ -709,23 +709,23 @@ class CollectionState():
self.collect(event.item, True, event) self.collect(event.item, True, event)
def has(self, item: str, player: int, count: int = 1) -> bool: def has(self, item: str, player: int, count: int = 1) -> bool:
return self.prog_items[item, player] >= count return self.prog_items[player][item] >= count
def has_all(self, items: Set[str], player: int) -> bool: def has_all(self, items: Set[str], player: int) -> bool:
"""Returns True if each item name of items is in state at least once.""" """Returns True if each item name of items is in state at least once."""
return all(self.prog_items[item, player] for item in items) return all(self.prog_items[player][item] for item in items)
def has_any(self, items: Set[str], player: int) -> bool: def has_any(self, items: Set[str], player: int) -> bool:
"""Returns True if at least one item name of items is in state at least once.""" """Returns True if at least one item name of items is in state at least once."""
return any(self.prog_items[item, player] for item in items) return any(self.prog_items[player][item] for item in items)
def count(self, item: str, player: int) -> int: def count(self, item: str, player: int) -> int:
return self.prog_items[item, player] return self.prog_items[player][item]
def has_group(self, item_name_group: str, player: int, count: int = 1) -> bool: def has_group(self, item_name_group: str, player: int, count: int = 1) -> bool:
found: int = 0 found: int = 0
for item_name in self.multiworld.worlds[player].item_name_groups[item_name_group]: for item_name in self.multiworld.worlds[player].item_name_groups[item_name_group]:
found += self.prog_items[item_name, player] found += self.prog_items[player][item_name]
if found >= count: if found >= count:
return True return True
return False return False
@ -733,11 +733,11 @@ class CollectionState():
def count_group(self, item_name_group: str, player: int) -> int: def count_group(self, item_name_group: str, player: int) -> int:
found: int = 0 found: int = 0
for item_name in self.multiworld.worlds[player].item_name_groups[item_name_group]: for item_name in self.multiworld.worlds[player].item_name_groups[item_name_group]:
found += self.prog_items[item_name, player] found += self.prog_items[player][item_name]
return found return found
def item_count(self, item: str, player: int) -> int: def item_count(self, item: str, player: int) -> int:
return self.prog_items[item, player] return self.prog_items[player][item]
def collect(self, item: Item, event: bool = False, location: Optional[Location] = None) -> bool: def collect(self, item: Item, event: bool = False, location: Optional[Location] = None) -> bool:
if location: if location:
@ -746,7 +746,7 @@ class CollectionState():
changed = self.multiworld.worlds[item.player].collect(self, item) changed = self.multiworld.worlds[item.player].collect(self, item)
if not changed and event: if not changed and event:
self.prog_items[item.name, item.player] += 1 self.prog_items[item.player][item.name] += 1
changed = True changed = True
self.stale[item.player] = True self.stale[item.player] = True

View File

@ -455,8 +455,8 @@ class TestFillRestrictive(unittest.TestCase):
location.place_locked_item(item) location.place_locked_item(item)
multi_world.state.sweep_for_events() multi_world.state.sweep_for_events()
multi_world.state.sweep_for_events() multi_world.state.sweep_for_events()
self.assertTrue(multi_world.state.prog_items[item.name, item.player], "Sweep did not collect - Test flawed") self.assertTrue(multi_world.state.prog_items[item.player][item.name], "Sweep did not collect - Test flawed")
self.assertEqual(multi_world.state.prog_items[item.name, item.player], 1, "Sweep collected multiple times") self.assertEqual(multi_world.state.prog_items[item.player][item.name], 1, "Sweep collected multiple times")
def test_correct_item_instance_removed_from_pool(self): def test_correct_item_instance_removed_from_pool(self):
"""Test that a placed item gets removed from the submitted pool""" """Test that a placed item gets removed from the submitted pool"""

View File

@ -414,16 +414,16 @@ class World(metaclass=AutoWorldRegister):
def collect(self, state: "CollectionState", item: "Item") -> bool: def collect(self, state: "CollectionState", item: "Item") -> bool:
name = self.collect_item(state, item) name = self.collect_item(state, item)
if name: if name:
state.prog_items[name, self.player] += 1 state.prog_items[self.player][name] += 1
return True return True
return False return False
def remove(self, state: "CollectionState", item: "Item") -> bool: def remove(self, state: "CollectionState", item: "Item") -> bool:
name = self.collect_item(state, item, True) name = self.collect_item(state, item, True)
if name: if name:
state.prog_items[name, self.player] -= 1 state.prog_items[self.player][name] -= 1
if state.prog_items[name, self.player] < 1: if state.prog_items[self.player][name] < 1:
del (state.prog_items[name, self.player]) del (state.prog_items[self.player][name])
return True return True
return False return False

View File

@ -31,7 +31,7 @@ def fake_pearl_state(state, player):
if state.has('Moon Pearl', player): if state.has('Moon Pearl', player):
return state return state
fake_state = state.copy() fake_state = state.copy()
fake_state.prog_items['Moon Pearl', player] += 1 fake_state.prog_items[player]['Moon Pearl'] += 1
return fake_state return fake_state

View File

@ -830,4 +830,4 @@ class ALttPLogic(LogicMixin):
return True return True
if self.multiworld.smallkey_shuffle[player] == smallkey_shuffle.option_universal: if self.multiworld.smallkey_shuffle[player] == smallkey_shuffle.option_universal:
return can_buy_unlimited(self, 'Small Key (Universal)', player) return can_buy_unlimited(self, 'Small Key (Universal)', player)
return self.prog_items[item, player] >= count return self.prog_items[player][item] >= count

View File

@ -5,12 +5,7 @@ from ..generic.Rules import set_rule
class ArchipIDLELogic(LogicMixin): class ArchipIDLELogic(LogicMixin):
def _archipidle_location_is_accessible(self, player_id, items_required): def _archipidle_location_is_accessible(self, player_id, items_required):
items_received = 0 return sum(self.prog_items[player_id].values()) >= items_required
for item in self.prog_items:
if item[1] == player_id:
items_received += 1
return items_received >= items_required
def set_rules(world: MultiWorld, player: int): def set_rules(world: MultiWorld, player: int):

View File

@ -12,11 +12,11 @@ def create_event(player, event: str) -> DLCQuestItem:
def has_enough_coin(player: int, coin: int): def has_enough_coin(player: int, coin: int):
return lambda state: state.prog_items[" coins", player] >= coin return lambda state: state.prog_items[player][" coins"] >= coin
def has_enough_coin_freemium(player: int, coin: int): def has_enough_coin_freemium(player: int, coin: int):
return lambda state: state.prog_items[" coins freemium", player] >= coin return lambda state: state.prog_items[player][" coins freemium"] >= coin
def set_rules(world, player, World_Options: Options.DLCQuestOptions): def set_rules(world, player, World_Options: Options.DLCQuestOptions):

View File

@ -92,7 +92,7 @@ class DLCqworld(World):
if change: if change:
suffix = item.coin_suffix suffix = item.coin_suffix
if suffix: if suffix:
state.prog_items[suffix, self.player] += item.coins state.prog_items[self.player][suffix] += item.coins
return change return change
def remove(self, state: CollectionState, item: DLCQuestItem) -> bool: def remove(self, state: CollectionState, item: DLCQuestItem) -> bool:
@ -100,5 +100,5 @@ class DLCqworld(World):
if change: if change:
suffix = item.coin_suffix suffix = item.coin_suffix
if suffix: if suffix:
state.prog_items[suffix, self.player] -= item.coins state.prog_items[self.player][suffix] -= item.coins
return change return change

View File

@ -517,12 +517,12 @@ class HKWorld(World):
change = super(HKWorld, self).collect(state, item) change = super(HKWorld, self).collect(state, item)
if change: if change:
for effect_name, effect_value in item_effects.get(item.name, {}).items(): for effect_name, effect_value in item_effects.get(item.name, {}).items():
state.prog_items[effect_name, item.player] += effect_value state.prog_items[item.player][effect_name] += effect_value
if item.name in {"Left_Mothwing_Cloak", "Right_Mothwing_Cloak"}: if item.name in {"Left_Mothwing_Cloak", "Right_Mothwing_Cloak"}:
if state.prog_items.get(('RIGHTDASH', item.player), 0) and \ if state.prog_items[item.player].get('RIGHTDASH', 0) and \
state.prog_items.get(('LEFTDASH', item.player), 0): state.prog_items[item.player].get('LEFTDASH', 0):
(state.prog_items["RIGHTDASH", item.player], state.prog_items["LEFTDASH", item.player]) = \ (state.prog_items[item.player]["RIGHTDASH"], state.prog_items[item.player]["LEFTDASH"]) = \
([max(state.prog_items["RIGHTDASH", item.player], state.prog_items["LEFTDASH", item.player])] * 2) ([max(state.prog_items[item.player]["RIGHTDASH"], state.prog_items[item.player]["LEFTDASH"])] * 2)
return change return change
def remove(self, state, item: HKItem) -> bool: def remove(self, state, item: HKItem) -> bool:
@ -530,9 +530,9 @@ class HKWorld(World):
if change: if change:
for effect_name, effect_value in item_effects.get(item.name, {}).items(): for effect_name, effect_value in item_effects.get(item.name, {}).items():
if state.prog_items[effect_name, item.player] == effect_value: if state.prog_items[item.player][effect_name] == effect_value:
del state.prog_items[effect_name, item.player] del state.prog_items[item.player][effect_name]
state.prog_items[effect_name, item.player] -= effect_value state.prog_items[item.player][effect_name] -= effect_value
return change return change

View File

@ -124,13 +124,13 @@ class GameStateAdapater:
# Don't allow any money usage if you can't get back wasted rupees # Don't allow any money usage if you can't get back wasted rupees
if item == "RUPEES": if item == "RUPEES":
if can_farm_rupees(self.state, self.player): if can_farm_rupees(self.state, self.player):
return self.state.prog_items["RUPEES", self.player] return self.state.prog_items[self.player]["RUPEES"]
return 0 return 0
elif item.endswith("_USED"): elif item.endswith("_USED"):
return 0 return 0
else: else:
item = ladxr_item_to_la_item_name[item] item = ladxr_item_to_la_item_name[item]
return self.state.prog_items.get((item, self.player), default) return self.state.prog_items[self.player].get(item, default)
class LinksAwakeningEntrance(Entrance): class LinksAwakeningEntrance(Entrance):

View File

@ -513,7 +513,7 @@ class LinksAwakeningWorld(World):
change = super().collect(state, item) change = super().collect(state, item)
if change: if change:
rupees = self.rupees.get(item.name, 0) rupees = self.rupees.get(item.name, 0)
state.prog_items["RUPEES", item.player] += rupees state.prog_items[item.player]["RUPEES"] += rupees
return change return change
@ -521,6 +521,6 @@ class LinksAwakeningWorld(World):
change = super().remove(state, item) change = super().remove(state, item)
if change: if change:
rupees = self.rupees.get(item.name, 0) rupees = self.rupees.get(item.name, 0)
state.prog_items["RUPEES", item.player] -= rupees state.prog_items[item.player]["RUPEES"] -= rupees
return change return change

View File

@ -188,6 +188,6 @@ class MessengerWorld(World):
shard_count = int(item.name.strip("Time Shard ()")) shard_count = int(item.name.strip("Time Shard ()"))
if remove: if remove:
shard_count = -shard_count shard_count = -shard_count
state.prog_items["Shards", self.player] += shard_count state.prog_items[self.player]["Shards"] += shard_count
return super().collect_item(state, item, remove) return super().collect_item(state, item, remove)

View File

@ -1260,16 +1260,16 @@ class OOTWorld(World):
def collect(self, state: CollectionState, item: OOTItem) -> bool: def collect(self, state: CollectionState, item: OOTItem) -> bool:
if item.advancement and item.special and item.special.get('alias', False): if item.advancement and item.special and item.special.get('alias', False):
alt_item_name, count = item.special.get('alias') alt_item_name, count = item.special.get('alias')
state.prog_items[alt_item_name, self.player] += count state.prog_items[self.player][alt_item_name] += count
return True return True
return super().collect(state, item) return super().collect(state, item)
def remove(self, state: CollectionState, item: OOTItem) -> bool: def remove(self, state: CollectionState, item: OOTItem) -> bool:
if item.advancement and item.special and item.special.get('alias', False): if item.advancement and item.special and item.special.get('alias', False):
alt_item_name, count = item.special.get('alias') alt_item_name, count = item.special.get('alias')
state.prog_items[alt_item_name, self.player] -= count state.prog_items[self.player][alt_item_name] -= count
if state.prog_items[alt_item_name, self.player] < 1: if state.prog_items[self.player][alt_item_name] < 1:
del (state.prog_items[alt_item_name, self.player]) del (state.prog_items[self.player][alt_item_name])
return True return True
return super().remove(state, item) return super().remove(state, item)

View File

@ -470,7 +470,7 @@ class SMZ3World(World):
def collect(self, state: CollectionState, item: Item) -> bool: def collect(self, state: CollectionState, item: Item) -> bool:
state.smz3state[self.player].Add([TotalSMZ3Item.Item(TotalSMZ3Item.ItemType[item.name], self.smz3World if hasattr(self, "smz3World") else None)]) state.smz3state[self.player].Add([TotalSMZ3Item.Item(TotalSMZ3Item.ItemType[item.name], self.smz3World if hasattr(self, "smz3World") else None)])
if item.advancement: if item.advancement:
state.prog_items[item.name, item.player] += 1 state.prog_items[item.player][item.name] += 1
return True # indicate that a logical state change has occured return True # indicate that a logical state change has occured
return False return False
@ -478,9 +478,9 @@ class SMZ3World(World):
name = self.collect_item(state, item, True) name = self.collect_item(state, item, True)
if name: if name:
state.smz3state[item.player].Remove([TotalSMZ3Item.Item(TotalSMZ3Item.ItemType[item.name], self.smz3World if hasattr(self, "smz3World") else None)]) state.smz3state[item.player].Remove([TotalSMZ3Item.Item(TotalSMZ3Item.ItemType[item.name], self.smz3World if hasattr(self, "smz3World") else None)])
state.prog_items[name, item.player] -= 1 state.prog_items[item.player][item.name] -= 1
if state.prog_items[name, item.player] < 1: if state.prog_items[item.player][item.name] < 1:
del (state.prog_items[name, item.player]) del (state.prog_items[item.player][item.name])
return True return True
return False return False

View File

@ -24,7 +24,7 @@ class TestProgressiveToolsLogic(SVTestBase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.multiworld.state.prog_items = Counter() self.multiworld.state.prog_items = {1: Counter()}
def test_sturgeon(self): def test_sturgeon(self):
self.assertFalse(self.world.logic.has("Sturgeon")(self.multiworld.state)) self.assertFalse(self.world.logic.has("Sturgeon")(self.multiworld.state))