From b533ffb9e8e8d93795a2962ec610fbf474a59626 Mon Sep 17 00:00:00 2001 From: Fabian Dill Date: Mon, 17 Oct 2022 03:22:02 +0200 Subject: [PATCH] Locality: rewrite for linear memory consumption, from quadratic (#1091) --- Main.py | 8 ++--- test/general/TestFill.py | 3 +- worlds/generic/Rules.py | 74 ++++++++++++++++++++++++++++++++-------- 3 files changed, 63 insertions(+), 22 deletions(-) diff --git a/Main.py b/Main.py index b26dcf79..63f5b8a8 100644 --- a/Main.py +++ b/Main.py @@ -16,7 +16,7 @@ from worlds.alttp.Regions import is_main_entrance from Fill import distribute_items_restrictive, flood_items, balance_multiworld_progression, distribute_planned from worlds.alttp.Shops import SHOP_ID_START, total_shop_slots, FillDisabledShopSlots from Utils import output_path, get_options, __version__, version_tuple -from worlds.generic.Rules import locality_rules, exclusion_rules, group_locality_rules +from worlds.generic.Rules import locality_rules, exclusion_rules from worlds import AutoWorld ordered_areas = ( @@ -107,7 +107,7 @@ def main(args, seed=None, baked_server_options: Optional[Dict[str, object]] = No if world.goal[player] in ["localtriforcehunt", "localganontriforcehunt"]: world.local_items[player].value.add('Triforce Piece') - # Not possible to place pendants/crystals out side of boss prizes yet. + # Not possible to place pendants/crystals outside boss prizes yet. world.non_local_items[player].value -= item_name_groups['Pendants'] world.non_local_items[player].value -= item_name_groups['Crystals'] @@ -122,9 +122,7 @@ def main(args, seed=None, baked_server_options: Optional[Dict[str, object]] = No logger.info('Calculating Access Rules.') if world.players > 1: - for player in world.player_ids: - locality_rules(world, player) - group_locality_rules(world) + locality_rules(world) else: world.non_local_items[1].value = set() world.local_items[1].value = set() diff --git a/test/general/TestFill.py b/test/general/TestFill.py index 8ce5b3b2..1893f6bd 100644 --- a/test/general/TestFill.py +++ b/test/general/TestFill.py @@ -575,8 +575,7 @@ class TestDistributeItemsRestrictive(unittest.TestCase): multi_world.local_items[player1.id].value = set(names(player1.basic_items)) multi_world.local_items[player2.id].value = set(names(player2.basic_items)) - locality_rules(multi_world, player1.id) - locality_rules(multi_world, player2.id) + locality_rules(multi_world) distribute_items_restrictive(multi_world) diff --git a/worlds/generic/Rules.py b/worlds/generic/Rules.py index 6f70e1b5..e69a34c5 100644 --- a/worlds/generic/Rules.py +++ b/worlds/generic/Rules.py @@ -1,3 +1,4 @@ +import collections import typing from BaseClasses import LocationProgressType, MultiWorld @@ -12,29 +13,72 @@ else: ItemRule = typing.Callable[[object], bool] -def group_locality_rules(world): +def locality_needed(world: MultiWorld) -> bool: + for player in world.player_ids: + if world.local_items[player].value: + return True + if world.non_local_items[player].value: + return True + + # Group for group_id, group in world.groups.items(): if set(world.player_ids) == set(group["players"]): continue if group["local_items"]: - for location in world.get_locations(): - if location.player not in group["players"]: - forbid_items_for_player(location, group["local_items"], group_id) + return True if group["non_local_items"]: - for location in world.get_locations(): - if location.player in group["players"]: - forbid_items_for_player(location, group["non_local_items"], group_id) + return True -def locality_rules(world, player: int): - if world.local_items[player].value: +def locality_rules(world: MultiWorld): + if locality_needed(world): + + forbid_data: typing.Dict[int, typing.Dict[int, typing.Set[str]]] = \ + collections.defaultdict(lambda: collections.defaultdict(set)) + + def forbid(sender: int, receiver: int, items: typing.Set[str]): + forbid_data[sender][receiver].update(items) + + for receiving_player in world.player_ids: + local_items: typing.Set[str] = world.local_items[receiving_player].value + if local_items: + for sending_player in world.player_ids: + if receiving_player != sending_player: + forbid(sending_player, receiving_player, local_items) + non_local_items: typing.Set[str] = world.non_local_items[receiving_player].value + if non_local_items: + forbid(receiving_player, receiving_player, non_local_items) + + # Group + for receiving_group_id, receiving_group in world.groups.items(): + if set(world.player_ids) == set(receiving_group["players"]): + continue + if receiving_group["local_items"]: + for sending_player in world.player_ids: + if sending_player not in receiving_group["players"]: + forbid(sending_player, receiving_group_id, receiving_group["local_items"]) + if receiving_group["non_local_items"]: + for sending_player in world.player_ids: + if sending_player in receiving_group["players"]: + forbid(sending_player, receiving_group_id, receiving_group["non_local_items"]) + + # create fewer lambda's to save memory and cache misses + func_cache = {} for location in world.get_locations(): - if location.player != player: - forbid_items_for_player(location, world.local_items[player].value, player) - if world.non_local_items[player].value: - for location in world.get_locations(): - if location.player == player: - forbid_items_for_player(location, world.non_local_items[player].value, player) + if (location.player, location.item_rule) in func_cache: + location.item_rule = func_cache[location.player, location.item_rule] + # empty rule that just returns True, overwrite + elif location.item_rule is location.__class__.item_rule: + func_cache[location.player, location.item_rule] = location.item_rule = \ + lambda i, sending_blockers = forbid_data[location.player], \ + old_rule = location.item_rule: \ + i.name not in sending_blockers[i.player] + # special rule, needs to also be fulfilled. + else: + func_cache[location.player, location.item_rule] = location.item_rule = \ + lambda i, sending_blockers = forbid_data[location.player], \ + old_rule = location.item_rule: \ + i.name not in sending_blockers[i.player] and old_rule(i) def exclusion_rules(world: MultiWorld, player: int, exclude_locations: typing.Set[str]) -> None: