Locality: rewrite for linear memory consumption, from quadratic (#1091)

This commit is contained in:
Fabian Dill 2022-10-17 03:22:02 +02:00 committed by GitHub
parent bb46ee7fc1
commit b533ffb9e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 63 additions and 22 deletions

View File

@ -16,7 +16,7 @@ from worlds.alttp.Regions import is_main_entrance
from Fill import distribute_items_restrictive, flood_items, balance_multiworld_progression, distribute_planned
from worlds.alttp.Shops import SHOP_ID_START, total_shop_slots, FillDisabledShopSlots
from Utils import output_path, get_options, __version__, version_tuple
from worlds.generic.Rules import locality_rules, exclusion_rules, group_locality_rules
from worlds.generic.Rules import locality_rules, exclusion_rules
from worlds import AutoWorld
ordered_areas = (
@ -107,7 +107,7 @@ def main(args, seed=None, baked_server_options: Optional[Dict[str, object]] = No
if world.goal[player] in ["localtriforcehunt", "localganontriforcehunt"]:
world.local_items[player].value.add('Triforce Piece')
# Not possible to place pendants/crystals out side of boss prizes yet.
# Not possible to place pendants/crystals outside boss prizes yet.
world.non_local_items[player].value -= item_name_groups['Pendants']
world.non_local_items[player].value -= item_name_groups['Crystals']
@ -122,9 +122,7 @@ def main(args, seed=None, baked_server_options: Optional[Dict[str, object]] = No
logger.info('Calculating Access Rules.')
if world.players > 1:
for player in world.player_ids:
locality_rules(world, player)
group_locality_rules(world)
locality_rules(world)
else:
world.non_local_items[1].value = set()
world.local_items[1].value = set()

View File

@ -575,8 +575,7 @@ class TestDistributeItemsRestrictive(unittest.TestCase):
multi_world.local_items[player1.id].value = set(names(player1.basic_items))
multi_world.local_items[player2.id].value = set(names(player2.basic_items))
locality_rules(multi_world, player1.id)
locality_rules(multi_world, player2.id)
locality_rules(multi_world)
distribute_items_restrictive(multi_world)

View File

@ -1,3 +1,4 @@
import collections
import typing
from BaseClasses import LocationProgressType, MultiWorld
@ -12,29 +13,72 @@ else:
ItemRule = typing.Callable[[object], bool]
def group_locality_rules(world):
def locality_needed(world: MultiWorld) -> bool:
for player in world.player_ids:
if world.local_items[player].value:
return True
if world.non_local_items[player].value:
return True
# Group
for group_id, group in world.groups.items():
if set(world.player_ids) == set(group["players"]):
continue
if group["local_items"]:
for location in world.get_locations():
if location.player not in group["players"]:
forbid_items_for_player(location, group["local_items"], group_id)
return True
if group["non_local_items"]:
for location in world.get_locations():
if location.player in group["players"]:
forbid_items_for_player(location, group["non_local_items"], group_id)
return True
def locality_rules(world, player: int):
if world.local_items[player].value:
def locality_rules(world: MultiWorld):
if locality_needed(world):
forbid_data: typing.Dict[int, typing.Dict[int, typing.Set[str]]] = \
collections.defaultdict(lambda: collections.defaultdict(set))
def forbid(sender: int, receiver: int, items: typing.Set[str]):
forbid_data[sender][receiver].update(items)
for receiving_player in world.player_ids:
local_items: typing.Set[str] = world.local_items[receiving_player].value
if local_items:
for sending_player in world.player_ids:
if receiving_player != sending_player:
forbid(sending_player, receiving_player, local_items)
non_local_items: typing.Set[str] = world.non_local_items[receiving_player].value
if non_local_items:
forbid(receiving_player, receiving_player, non_local_items)
# Group
for receiving_group_id, receiving_group in world.groups.items():
if set(world.player_ids) == set(receiving_group["players"]):
continue
if receiving_group["local_items"]:
for sending_player in world.player_ids:
if sending_player not in receiving_group["players"]:
forbid(sending_player, receiving_group_id, receiving_group["local_items"])
if receiving_group["non_local_items"]:
for sending_player in world.player_ids:
if sending_player in receiving_group["players"]:
forbid(sending_player, receiving_group_id, receiving_group["non_local_items"])
# create fewer lambda's to save memory and cache misses
func_cache = {}
for location in world.get_locations():
if location.player != player:
forbid_items_for_player(location, world.local_items[player].value, player)
if world.non_local_items[player].value:
for location in world.get_locations():
if location.player == player:
forbid_items_for_player(location, world.non_local_items[player].value, player)
if (location.player, location.item_rule) in func_cache:
location.item_rule = func_cache[location.player, location.item_rule]
# empty rule that just returns True, overwrite
elif location.item_rule is location.__class__.item_rule:
func_cache[location.player, location.item_rule] = location.item_rule = \
lambda i, sending_blockers = forbid_data[location.player], \
old_rule = location.item_rule: \
i.name not in sending_blockers[i.player]
# special rule, needs to also be fulfilled.
else:
func_cache[location.player, location.item_rule] = location.item_rule = \
lambda i, sending_blockers = forbid_data[location.player], \
old_rule = location.item_rule: \
i.name not in sending_blockers[i.player] and old_rule(i)
def exclusion_rules(world: MultiWorld, player: int, exclude_locations: typing.Set[str]) -> None: