From e66a2a7c30a2e9d42d2d751be8d159acb04fb6d1 Mon Sep 17 00:00:00 2001 From: Fabian Dill Date: Sun, 10 Oct 2021 16:50:01 +0200 Subject: [PATCH] Core: change precollected_items to dict-style Core: make sure there are enough threads available during generate_output to prevent deadlocks if event waiting is used --- BaseClasses.py | 10 ++++++---- Main.py | 19 +++++++++---------- test/inverted_owg/TestInvertedOWG.py | 2 +- test/owg/TestVanillaOWG.py | 2 +- worlds/alttp/Rom.py | 4 +--- worlds/oot/__init__.py | 16 +++++++++------- 6 files changed, 27 insertions(+), 26 deletions(-) diff --git a/BaseClasses.py b/BaseClasses.py index 423ffbd5..e818f629 100644 --- a/BaseClasses.py +++ b/BaseClasses.py @@ -27,6 +27,7 @@ class MultiWorld(): plando_connections: List worlds: Dict[int, Any] is_race: bool = False + precollected_items: Dict[int, List[Item]] class AttributeProxy(): def __init__(self, rule): @@ -46,7 +47,7 @@ class MultiWorld(): self.itempool = [] self.seed = None self.seed_name: str = "Unavailable" - self.precollected_items = [] + self.precollected_items = {player: [] for player in self.player_ids} self.state = CollectionState(self) self._cached_entrances = None self._cached_locations = None @@ -266,7 +267,7 @@ class MultiWorld(): def push_precollected(self, item: Item): item.world = self - self.precollected_items.append(item) + self.precollected_items[item.player].append(item) self.state.collect(item, True) def push_item(self, location: Location, item: Item, collect: bool = True): @@ -473,8 +474,9 @@ class CollectionState(object): self.path = {} self.locations_checked = set() self.stale = {player: True for player in range(1, parent.players + 1)} - for item in parent.precollected_items: - self.collect(item, True) + for items in parent.precollected_items.values(): + for item in items: + self.collect(item, True) def update_reachable_regions(self, player: int): from worlds.alttp.EntranceShuffle import indirect_connections diff --git a/Main.py b/Main.py index 638fecb2..a4a718a8 100644 --- a/Main.py +++ b/Main.py @@ -1,4 +1,4 @@ -from itertools import zip_longest +from itertools import zip_longest, chain import logging import os import time @@ -159,16 +159,15 @@ def main(args, seed=None): output = tempfile.TemporaryDirectory() with output as temp_dir: - with concurrent.futures.ThreadPoolExecutor() as pool: + with concurrent.futures.ThreadPoolExecutor(world.players+2) as pool: check_accessibility_task = pool.submit(world.fulfills_accessibility) - output_file_futures = [] - + output_file_futures = [pool.submit(AutoWorld.call_stage, world, "generate_output", temp_dir)] for player in world.player_ids: # skip starting a thread for methods that say "pass". if AutoWorld.World.generate_output.__code__ is not world.worlds[player].generate_output.__code__: output_file_futures.append(pool.submit(AutoWorld.call_single, world, "generate_output", player, temp_dir)) - output_file_futures.append(pool.submit(AutoWorld.call_stage, world, "generate_output", temp_dir)) + def get_entrance_to_region(region: Region): for entrance in region.entrances: @@ -246,9 +245,8 @@ def main(args, seed=None): for slot in world.player_ids: client_versions[slot] = world.worlds[slot].get_required_client_version() games[slot] = world.game[slot] - precollected_items = {player: [] for player in range(1, world.players + 1)} - for item in world.precollected_items: - precollected_items[item.player].append(item.code) + precollected_items = {player: [item.code for item in world_precollected] + for player, world_precollected in world.precollected_items.items()} precollected_hints = {player: set() for player in range(1, world.players + 1)} # for now special case Factorio tech_tree_information sending_visible_players = set() @@ -397,7 +395,7 @@ def create_playthrough(world): # second phase, sphere 0 removed_precollected = [] - for item in (i for i in world.precollected_items if i.advancement): + for item in (i for i in chain(world.precollected_items.values()) if i.advancement): logging.debug('Checking if %s (Player %d) is required to beat the game.', item.name, item.player) world.precollected_items.remove(item) world.state.remove(item) @@ -463,7 +461,8 @@ def create_playthrough(world): get_path(state, world.get_region('Inverted Big Bomb Shop', player)) # we can finally output our playthrough - world.spoiler.playthrough = {"0": sorted([str(item) for item in world.precollected_items if item.advancement])} + world.spoiler.playthrough = {"0": sorted([str(item) for item in chain(world.precollected_items.values()) + if item.advancement])} for i, sphere in enumerate(collection_spheres): world.spoiler.playthrough[str(i + 1)] = {str(location): str(location.item) for location in sorted(sphere)} diff --git a/test/inverted_owg/TestInvertedOWG.py b/test/inverted_owg/TestInvertedOWG.py index 486c3cb9..7192fcb0 100644 --- a/test/inverted_owg/TestInvertedOWG.py +++ b/test/inverted_owg/TestInvertedOWG.py @@ -35,7 +35,7 @@ class TestInvertedOWG(TestBase): self.world.itempool.extend(ItemFactory(['Green Pendant', 'Red Pendant', 'Blue Pendant', 'Beat Agahnim 1', 'Beat Agahnim 2', 'Crystal 1', 'Crystal 2', 'Crystal 3', 'Crystal 4', 'Crystal 5', 'Crystal 6', 'Crystal 7'], 1)) self.world.get_location('Agahnim 1', 1).item = None self.world.get_location('Agahnim 2', 1).item = None - self.world.precollected_items.clear() + self.world.precollected_items[1].clear() self.world.itempool.append(ItemFactory('Pegasus Boots', 1)) mark_light_world_regions(self.world, 1) self.world.worlds[1].set_rules() diff --git a/test/owg/TestVanillaOWG.py b/test/owg/TestVanillaOWG.py index a4e584f8..68b10732 100644 --- a/test/owg/TestVanillaOWG.py +++ b/test/owg/TestVanillaOWG.py @@ -34,7 +34,7 @@ class TestVanillaOWG(TestBase): self.world.itempool.extend(ItemFactory(['Green Pendant', 'Red Pendant', 'Blue Pendant', 'Beat Agahnim 1', 'Beat Agahnim 2', 'Crystal 1', 'Crystal 2', 'Crystal 3', 'Crystal 4', 'Crystal 5', 'Crystal 6', 'Crystal 7'], 1)) self.world.get_location('Agahnim 1', 1).item = None self.world.get_location('Agahnim 2', 1).item = None - self.world.precollected_items.clear() + self.world.precollected_items[1].clear() self.world.itempool.append(ItemFactory('Pegasus Boots', 1)) mark_dark_world_regions(self.world, 1) self.world.worlds[1].set_rules() \ No newline at end of file diff --git a/worlds/alttp/Rom.py b/worlds/alttp/Rom.py index c3367267..568ef893 100644 --- a/worlds/alttp/Rom.py +++ b/worlds/alttp/Rom.py @@ -1315,9 +1315,7 @@ def patch_rom(world, rom, player, enemized): equip[0x37B] = 1 equip[0x36E] = 0x80 - for item in world.precollected_items: - if item.player != player: - continue + for item in world.precollected_items[player]: if item.name in {'Bow', 'Silver Bow', 'Silver Arrows', 'Progressive Bow', 'Progressive Bow (Alt)', 'Titans Mitts', 'Power Glove', 'Progressive Glove', diff --git a/worlds/oot/__init__.py b/worlds/oot/__init__.py index 5e7affb0..e4dc8d84 100644 --- a/worlds/oot/__init__.py +++ b/worlds/oot/__init__.py @@ -36,7 +36,6 @@ location_id_offset = 67000 # OoT's generate_output doesn't benefit from more than 2 threads, instead it uses a lot of memory. i_o_limiter = threading.Semaphore(2) -hint_data_available = threading.Event() class OOTWorld(World): @@ -88,6 +87,10 @@ class OOTWorld(World): return super().__new__(cls) + def __init__(self, world, player): + self.hint_data_available = threading.Event() + super(OOTWorld, self).__init__(world, player) + def generate_early(self): # Player name MUST be at most 16 bytes ascii-encoded, otherwise won't write to ROM correctly if len(bytes(self.world.get_player_name(self.player), 'ascii')) > 16: @@ -261,7 +264,7 @@ class OOTWorld(World): # Both two-handed swords can be required in glitch logic, so only consider them nonprogression in glitchless self.nonadvancement_items.add('Biggoron Sword') self.nonadvancement_items.add('Giants Knife') - + def load_regions_from_json(self, file_path): region_json = read_json(file_path) @@ -456,9 +459,7 @@ class OOTWorld(World): junk_pool = get_junk_pool(self) removed_items = [] # Determine starting items - for item in self.world.precollected_items: - if item.player != self.player: - continue + for item in self.world.precollected_items[self.player]: if item.name in self.remove_from_start_inventory: self.remove_from_start_inventory.remove(item.name) removed_items.append(item.name) @@ -703,7 +704,7 @@ class OOTWorld(World): def generate_output(self, output_directory: str): if self.hints != 'none': - hint_data_available.wait() + self.hint_data_available.wait() with i_o_limiter: # Make ice traps appear as other random items @@ -782,7 +783,8 @@ class OOTWorld(World): except Exception as e: raise e finally: - hint_data_available.set() + for autoworld in world.get_game_worlds("Ocarina of Time"): + autoworld.hint_data_available.set() def modify_multidata(self, multidata: dict): for item_name in self.remove_from_start_inventory: