Switch to simpler caching system

This should speed up generating the seeds the currently take
the longest. Seems to have no impact on the average case.
This commit is contained in:
Kevin Cathcart 2019-07-08 22:48:16 -04:00
parent 54c53ea07e
commit d44d194de7
3 changed files with 29 additions and 56 deletions

View File

@ -172,7 +172,6 @@ class World(object):
'Small Key (Swamp Palace)', 'Big Key (Ice Palace)'] + ['Small Key (Ice Palace)'] * 2 + ['Big Key (Misery Mire)', 'Big Key (Turtle Rock)', 'Big Key (Ganons Tower)'] + ['Small Key (Misery Mire)'] * 3 + ['Small Key (Turtle Rock)'] * 4 + ['Small Key (Ganons Tower)'] * 4): 'Small Key (Swamp Palace)', 'Big Key (Ice Palace)'] + ['Small Key (Ice Palace)'] * 2 + ['Big Key (Misery Mire)', 'Big Key (Turtle Rock)', 'Big Key (Ganons Tower)'] + ['Small Key (Misery Mire)'] * 3 + ['Small Key (Turtle Rock)'] * 4 + ['Small Key (Ganons Tower)'] * 4):
soft_collect(item) soft_collect(item)
ret.sweep_for_events() ret.sweep_for_events()
ret.clear_cached_unreachable()
return ret return ret
def get_items(self): def get_items(self):
@ -233,7 +232,6 @@ class World(object):
def unlocks_new_location(self, item): def unlocks_new_location(self, item):
temp_state = self.state.copy() temp_state = self.state.copy()
temp_state.clear_cached_unreachable()
temp_state.collect(item, True) temp_state.collect(item, True)
for location in self.get_unfilled_locations(): for location in self.get_unfilled_locations():
@ -320,75 +318,50 @@ class CollectionState(object):
def __init__(self, parent): def __init__(self, parent):
self.prog_items = [] self.prog_items = []
self.world = parent self.world = parent
self.region_cache = {} self.reachable_regions = set()
self.location_cache = {}
self.entrance_cache = {}
self.recursion_count = 0
self.events = [] self.events = []
self.path = {} self.path = {}
self.locations_checked = set() self.locations_checked = set()
self.stale = True
def update_reachable_regions(self):
self.stale=False
def clear_cached_unreachable(self): new_regions = True
# we only need to invalidate results which were False, places we could reach before we can still reach after adding more items reachable_regions_count = len(self.reachable_regions)
self.region_cache = {k: v for k, v in self.region_cache.items() if v} while new_regions:
self.location_cache = {k: v for k, v in self.location_cache.items() if v} possible = [region for region in self.world.regions if region not in self.reachable_regions]
self.entrance_cache = {k: v for k, v in self.entrance_cache.items() if v} for candidate in possible:
if candidate.can_reach_private(self):
self.reachable_regions.add(candidate)
new_regions = len(self.reachable_regions) > reachable_regions_count
reachable_regions_count = len(self.reachable_regions)
def copy(self): def copy(self):
ret = CollectionState(self.world) ret = CollectionState(self.world)
ret.prog_items = copy.copy(self.prog_items) ret.prog_items = copy.copy(self.prog_items)
ret.region_cache = copy.copy(self.region_cache) ret.reachable_regions = copy.copy(self.reachable_regions)
ret.location_cache = copy.copy(self.location_cache)
ret.entrance_cache = copy.copy(self.entrance_cache)
ret.events = copy.copy(self.events) ret.events = copy.copy(self.events)
ret.path = copy.copy(self.path) ret.path = copy.copy(self.path)
ret.locations_checked = copy.copy(self.locations_checked) ret.locations_checked = copy.copy(self.locations_checked)
ret.stale = True
return ret return ret
def can_reach(self, spot, resolution_hint=None): def can_reach(self, spot, resolution_hint=None):
try: try:
spot_type = spot.spot_type spot_type = spot.spot_type
if spot_type == 'Location':
correct_cache = self.location_cache
elif spot_type == 'Region':
correct_cache = self.region_cache
elif spot_type == 'Entrance':
correct_cache = self.entrance_cache
else:
raise AttributeError
except AttributeError: except AttributeError:
# try to resolve a name # try to resolve a name
if resolution_hint == 'Location': if resolution_hint == 'Location':
spot = self.world.get_location(spot) spot = self.world.get_location(spot)
correct_cache = self.location_cache
elif resolution_hint == 'Entrance': elif resolution_hint == 'Entrance':
spot = self.world.get_entrance(spot) spot = self.world.get_entrance(spot)
correct_cache = self.entrance_cache
else: else:
# default to Region # default to Region
spot = self.world.get_region(spot) spot = self.world.get_region(spot)
correct_cache = self.region_cache
if spot.recursion_count > 0: return spot.can_reach(self)
return False
if spot not in correct_cache:
# for the purpose of evaluating results, recursion is resolved by always denying recursive access (as that ia what we are trying to figure out right now in the first place
spot.recursion_count += 1
self.recursion_count += 1
can_reach = spot.can_reach(self)
spot.recursion_count -= 1
self.recursion_count -= 1
# we only store qualified false results (i.e. ones not inside a hypothetical)
if not can_reach:
if self.recursion_count == 0:
correct_cache[spot] = can_reach
else:
correct_cache[spot] = can_reach
return can_reach
return correct_cache[spot]
def sweep_for_events(self, key_only=False): def sweep_for_events(self, key_only=False):
# this may need improvement # this may need improvement
@ -566,12 +539,12 @@ class CollectionState(object):
elif event or item.advancement: elif event or item.advancement:
self.prog_items.append(item.name) self.prog_items.append(item.name)
changed = True changed = True
self.stale = True
if changed: if changed:
self.clear_cached_unreachable()
if not event: if not event:
self.sweep_for_events() self.sweep_for_events()
self.clear_cached_unreachable()
def remove(self, item): def remove(self, item):
if item.advancement: if item.advancement:
@ -603,10 +576,8 @@ class CollectionState(object):
return return
# invalidate caches, nothing can be trusted anymore now # invalidate caches, nothing can be trusted anymore now
self.region_cache = {} self.reachable_regions = set()
self.location_cache = {} self.stale = True
self.entrance_cache = {}
self.recursion_count = 0
def __getattr__(self, item): def __getattr__(self, item):
if item.startswith('can_reach_'): if item.startswith('can_reach_'):
@ -647,6 +618,11 @@ class Region(object):
self.recursion_count = 0 self.recursion_count = 0
def can_reach(self, state): def can_reach(self, state):
if state.stale:
state.update_reachable_regions()
return self in state.reachable_regions
def can_reach_private(self, state):
for entrance in self.entrances: for entrance in self.entrances:
if state.can_reach(entrance): if state.can_reach(entrance):
if not self in state.path: if not self in state.path:
@ -683,7 +659,7 @@ class Entrance(object):
self.access_rule = lambda state: True self.access_rule = lambda state: True
def can_reach(self, state): def can_reach(self, state):
if self.access_rule(state) and state.can_reach(self.parent_region): if state.can_reach(self.parent_region) and self.access_rule(state):
if not self in state.path: if not self in state.path:
state.path[self] = (self.name, state.path.get(self.parent_region, (self.parent_region.name, None))) state.path[self] = (self.name, state.path.get(self.parent_region, (self.parent_region.name, None)))
return True return True
@ -768,7 +744,7 @@ class Location(object):
return self.always_allow(state, item) or (self.parent_region.can_fill(item) and self.item_rule(item) and (not check_access or self.can_reach(state))) return self.always_allow(state, item) or (self.parent_region.can_fill(item) and self.item_rule(item) and (not check_access or self.can_reach(state)))
def can_reach(self, state): def can_reach(self, state):
if self.access_rule(state) and state.can_reach(self.parent_region): if state.can_reach(self.parent_region) and self.access_rule(state):
return True return True
return False return False

View File

@ -71,7 +71,6 @@ def fill_dungeons(world):
world.push_item(bk_location, big_key, False) world.push_item(bk_location, big_key, False)
bk_location.event = True bk_location.event = True
dungeon_locations.remove(bk_location) dungeon_locations.remove(bk_location)
all_state.clear_cached_unreachable()
big_key = None big_key = None
# next place small keys # next place small keys
@ -97,7 +96,6 @@ def fill_dungeons(world):
world.push_item(sk_location, small_key, False) world.push_item(sk_location, small_key, False)
sk_location.event = True sk_location.event = True
dungeon_locations.remove(sk_location) dungeon_locations.remove(sk_location)
all_state.clear_cached_unreachable()
if small_keys: if small_keys:
# key placement not finished, loop again # key placement not finished, loop again
@ -109,7 +107,6 @@ def fill_dungeons(world):
di_location = dungeon_locations.pop() di_location = dungeon_locations.pop()
world.push_item(di_location, dungeon_item, False) world.push_item(di_location, dungeon_item, False)
world.state.clear_cached_unreachable()
def get_dungeon_item_pool(world): def get_dungeon_item_pool(world):
return [item for dungeon in world.dungeons for item in dungeon.all_items if item.key or world.place_dungeon_items] return [item for dungeon in world.dungeons for item in dungeon.all_items if item.key or world.place_dungeon_items]
@ -142,7 +139,6 @@ def fill_dungeons_restrictive(world, shuffled_locations):
fill_restrictive(world, all_state_base, shuffled_locations, dungeon_items) fill_restrictive(world, all_state_base, shuffled_locations, dungeon_items)
world.state.clear_cached_unreachable()
dungeon_music_addresses = {'Eastern Palace - Prize': [0x1559A], dungeon_music_addresses = {'Eastern Palace - Prize': [0x1559A],

View File

@ -200,6 +200,7 @@ def copy_world(world):
# copy progress items in state # copy progress items in state
ret.state.prog_items = list(world.state.prog_items) ret.state.prog_items = list(world.state.prog_items)
ret.state.stale = True
set_rules(ret) set_rules(ret)