Core: make state.prog_items a `Dict[int, Counter[str]]` (#2407)

This commit is contained in:
Aaron Wagener 2023-11-02 00:41:20 -05:00 committed by GitHub
parent 19dc0720ba
commit 5669579374
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 46 additions and 51 deletions

View File

@ -605,7 +605,7 @@ PathValue = Tuple[str, Optional["PathValue"]]
class CollectionState():
prog_items: typing.Counter[Tuple[str, int]]
prog_items: Dict[int, Counter[str]]
multiworld: MultiWorld
reachable_regions: Dict[int, Set[Region]]
blocked_connections: Dict[int, Set[Entrance]]
@ -617,7 +617,7 @@ class CollectionState():
additional_copy_functions: List[Callable[[CollectionState, CollectionState], CollectionState]] = []
def __init__(self, parent: MultiWorld):
self.prog_items = Counter()
self.prog_items = {player: Counter() for player in parent.player_ids}
self.multiworld = parent
self.reachable_regions = {player: set() for player in parent.get_all_ids()}
self.blocked_connections = {player: set() for player in parent.get_all_ids()}
@ -665,7 +665,7 @@ class CollectionState():
def copy(self) -> CollectionState:
ret = CollectionState(self.multiworld)
ret.prog_items = self.prog_items.copy()
ret.prog_items = copy.deepcopy(self.prog_items)
ret.reachable_regions = {player: copy.copy(self.reachable_regions[player]) for player in
self.reachable_regions}
ret.blocked_connections = {player: copy.copy(self.blocked_connections[player]) for player in
@ -709,23 +709,23 @@ class CollectionState():
self.collect(event.item, True, event)
def has(self, item: str, player: int, count: int = 1) -> bool:
return self.prog_items[item, player] >= count
return self.prog_items[player][item] >= count
def has_all(self, items: Set[str], player: int) -> bool:
"""Returns True if each item name of items is in state at least once."""
return all(self.prog_items[item, player] for item in items)
return all(self.prog_items[player][item] for item in items)
def has_any(self, items: Set[str], player: int) -> bool:
"""Returns True if at least one item name of items is in state at least once."""
return any(self.prog_items[item, player] for item in items)
return any(self.prog_items[player][item] for item in items)
def count(self, item: str, player: int) -> int:
return self.prog_items[item, player]
return self.prog_items[player][item]
def has_group(self, item_name_group: str, player: int, count: int = 1) -> bool:
found: int = 0
for item_name in self.multiworld.worlds[player].item_name_groups[item_name_group]:
found += self.prog_items[item_name, player]
found += self.prog_items[player][item_name]
if found >= count:
return True
return False
@ -733,11 +733,11 @@ class CollectionState():
def count_group(self, item_name_group: str, player: int) -> int:
found: int = 0
for item_name in self.multiworld.worlds[player].item_name_groups[item_name_group]:
found += self.prog_items[item_name, player]
found += self.prog_items[player][item_name]
return found
def item_count(self, item: str, player: int) -> int:
return self.prog_items[item, player]
return self.prog_items[player][item]
def collect(self, item: Item, event: bool = False, location: Optional[Location] = None) -> bool:
if location:
@ -746,7 +746,7 @@ class CollectionState():
changed = self.multiworld.worlds[item.player].collect(self, item)
if not changed and event:
self.prog_items[item.name, item.player] += 1
self.prog_items[item.player][item.name] += 1
changed = True
self.stale[item.player] = True

View File

@ -455,8 +455,8 @@ class TestFillRestrictive(unittest.TestCase):
location.place_locked_item(item)
multi_world.state.sweep_for_events()
multi_world.state.sweep_for_events()
self.assertTrue(multi_world.state.prog_items[item.name, item.player], "Sweep did not collect - Test flawed")
self.assertEqual(multi_world.state.prog_items[item.name, item.player], 1, "Sweep collected multiple times")
self.assertTrue(multi_world.state.prog_items[item.player][item.name], "Sweep did not collect - Test flawed")
self.assertEqual(multi_world.state.prog_items[item.player][item.name], 1, "Sweep collected multiple times")
def test_correct_item_instance_removed_from_pool(self):
"""Test that a placed item gets removed from the submitted pool"""

View File

@ -414,16 +414,16 @@ class World(metaclass=AutoWorldRegister):
def collect(self, state: "CollectionState", item: "Item") -> bool:
name = self.collect_item(state, item)
if name:
state.prog_items[name, self.player] += 1
state.prog_items[self.player][name] += 1
return True
return False
def remove(self, state: "CollectionState", item: "Item") -> bool:
name = self.collect_item(state, item, True)
if name:
state.prog_items[name, self.player] -= 1
if state.prog_items[name, self.player] < 1:
del (state.prog_items[name, self.player])
state.prog_items[self.player][name] -= 1
if state.prog_items[self.player][name] < 1:
del (state.prog_items[self.player][name])
return True
return False

View File

@ -31,7 +31,7 @@ def fake_pearl_state(state, player):
if state.has('Moon Pearl', player):
return state
fake_state = state.copy()
fake_state.prog_items['Moon Pearl', player] += 1
fake_state.prog_items[player]['Moon Pearl'] += 1
return fake_state

View File

@ -830,4 +830,4 @@ class ALttPLogic(LogicMixin):
return True
if self.multiworld.smallkey_shuffle[player] == smallkey_shuffle.option_universal:
return can_buy_unlimited(self, 'Small Key (Universal)', player)
return self.prog_items[item, player] >= count
return self.prog_items[player][item] >= count

View File

@ -5,12 +5,7 @@ from ..generic.Rules import set_rule
class ArchipIDLELogic(LogicMixin):
def _archipidle_location_is_accessible(self, player_id, items_required):
items_received = 0
for item in self.prog_items:
if item[1] == player_id:
items_received += 1
return items_received >= items_required
return sum(self.prog_items[player_id].values()) >= items_required
def set_rules(world: MultiWorld, player: int):

View File

@ -12,11 +12,11 @@ def create_event(player, event: str) -> DLCQuestItem:
def has_enough_coin(player: int, coin: int):
return lambda state: state.prog_items[" coins", player] >= coin
return lambda state: state.prog_items[player][" coins"] >= coin
def has_enough_coin_freemium(player: int, coin: int):
return lambda state: state.prog_items[" coins freemium", player] >= coin
return lambda state: state.prog_items[player][" coins freemium"] >= coin
def set_rules(world, player, World_Options: Options.DLCQuestOptions):

View File

@ -92,7 +92,7 @@ class DLCqworld(World):
if change:
suffix = item.coin_suffix
if suffix:
state.prog_items[suffix, self.player] += item.coins
state.prog_items[self.player][suffix] += item.coins
return change
def remove(self, state: CollectionState, item: DLCQuestItem) -> bool:
@ -100,5 +100,5 @@ class DLCqworld(World):
if change:
suffix = item.coin_suffix
if suffix:
state.prog_items[suffix, self.player] -= item.coins
state.prog_items[self.player][suffix] -= item.coins
return change

View File

@ -517,12 +517,12 @@ class HKWorld(World):
change = super(HKWorld, self).collect(state, item)
if change:
for effect_name, effect_value in item_effects.get(item.name, {}).items():
state.prog_items[effect_name, item.player] += effect_value
state.prog_items[item.player][effect_name] += effect_value
if item.name in {"Left_Mothwing_Cloak", "Right_Mothwing_Cloak"}:
if state.prog_items.get(('RIGHTDASH', item.player), 0) and \
state.prog_items.get(('LEFTDASH', item.player), 0):
(state.prog_items["RIGHTDASH", item.player], state.prog_items["LEFTDASH", item.player]) = \
([max(state.prog_items["RIGHTDASH", item.player], state.prog_items["LEFTDASH", item.player])] * 2)
if state.prog_items[item.player].get('RIGHTDASH', 0) and \
state.prog_items[item.player].get('LEFTDASH', 0):
(state.prog_items[item.player]["RIGHTDASH"], state.prog_items[item.player]["LEFTDASH"]) = \
([max(state.prog_items[item.player]["RIGHTDASH"], state.prog_items[item.player]["LEFTDASH"])] * 2)
return change
def remove(self, state, item: HKItem) -> bool:
@ -530,9 +530,9 @@ class HKWorld(World):
if change:
for effect_name, effect_value in item_effects.get(item.name, {}).items():
if state.prog_items[effect_name, item.player] == effect_value:
del state.prog_items[effect_name, item.player]
state.prog_items[effect_name, item.player] -= effect_value
if state.prog_items[item.player][effect_name] == effect_value:
del state.prog_items[item.player][effect_name]
state.prog_items[item.player][effect_name] -= effect_value
return change

View File

@ -124,13 +124,13 @@ class GameStateAdapater:
# Don't allow any money usage if you can't get back wasted rupees
if item == "RUPEES":
if can_farm_rupees(self.state, self.player):
return self.state.prog_items["RUPEES", self.player]
return self.state.prog_items[self.player]["RUPEES"]
return 0
elif item.endswith("_USED"):
return 0
else:
item = ladxr_item_to_la_item_name[item]
return self.state.prog_items.get((item, self.player), default)
return self.state.prog_items[self.player].get(item, default)
class LinksAwakeningEntrance(Entrance):

View File

@ -513,7 +513,7 @@ class LinksAwakeningWorld(World):
change = super().collect(state, item)
if change:
rupees = self.rupees.get(item.name, 0)
state.prog_items["RUPEES", item.player] += rupees
state.prog_items[item.player]["RUPEES"] += rupees
return change
@ -521,6 +521,6 @@ class LinksAwakeningWorld(World):
change = super().remove(state, item)
if change:
rupees = self.rupees.get(item.name, 0)
state.prog_items["RUPEES", item.player] -= rupees
state.prog_items[item.player]["RUPEES"] -= rupees
return change

View File

@ -188,6 +188,6 @@ class MessengerWorld(World):
shard_count = int(item.name.strip("Time Shard ()"))
if remove:
shard_count = -shard_count
state.prog_items["Shards", self.player] += shard_count
state.prog_items[self.player]["Shards"] += shard_count
return super().collect_item(state, item, remove)

View File

@ -1260,16 +1260,16 @@ class OOTWorld(World):
def collect(self, state: CollectionState, item: OOTItem) -> bool:
if item.advancement and item.special and item.special.get('alias', False):
alt_item_name, count = item.special.get('alias')
state.prog_items[alt_item_name, self.player] += count
state.prog_items[self.player][alt_item_name] += count
return True
return super().collect(state, item)
def remove(self, state: CollectionState, item: OOTItem) -> bool:
if item.advancement and item.special and item.special.get('alias', False):
alt_item_name, count = item.special.get('alias')
state.prog_items[alt_item_name, self.player] -= count
if state.prog_items[alt_item_name, self.player] < 1:
del (state.prog_items[alt_item_name, self.player])
state.prog_items[self.player][alt_item_name] -= count
if state.prog_items[self.player][alt_item_name] < 1:
del (state.prog_items[self.player][alt_item_name])
return True
return super().remove(state, item)

View File

@ -470,7 +470,7 @@ class SMZ3World(World):
def collect(self, state: CollectionState, item: Item) -> bool:
state.smz3state[self.player].Add([TotalSMZ3Item.Item(TotalSMZ3Item.ItemType[item.name], self.smz3World if hasattr(self, "smz3World") else None)])
if item.advancement:
state.prog_items[item.name, item.player] += 1
state.prog_items[item.player][item.name] += 1
return True # indicate that a logical state change has occured
return False
@ -478,9 +478,9 @@ class SMZ3World(World):
name = self.collect_item(state, item, True)
if name:
state.smz3state[item.player].Remove([TotalSMZ3Item.Item(TotalSMZ3Item.ItemType[item.name], self.smz3World if hasattr(self, "smz3World") else None)])
state.prog_items[name, item.player] -= 1
if state.prog_items[name, item.player] < 1:
del (state.prog_items[name, item.player])
state.prog_items[item.player][item.name] -= 1
if state.prog_items[item.player][item.name] < 1:
del (state.prog_items[item.player][item.name])
return True
return False

View File

@ -24,7 +24,7 @@ class TestProgressiveToolsLogic(SVTestBase):
def setUp(self):
super().setUp()
self.multiworld.state.prog_items = Counter()
self.multiworld.state.prog_items = {1: Counter()}
def test_sturgeon(self):
self.assertFalse(self.world.logic.has("Sturgeon")(self.multiworld.state))