import unittest
from enum import IntEnum
from itertools import groupby

from BaseClasses import Region, EntranceType, MultiWorld, Entrance
from entrance_rando import disconnect_entrance_for_randomization, randomize_entrances, EntranceRandomizationError, \
    ERPlacementState, EntranceLookup, bake_target_group_lookup
from Options import Accessibility
from test.general import generate_test_multiworld, generate_locations, generate_items
from worlds.generic.Rules import set_rule


class ERTestGroups(IntEnum):
    """Randomization groups used by the test grid, one per side of a region."""
    LEFT = 1
    RIGHT = 2
    TOP = 3
    BOTTOM = 4


# Maps each exit group to the group on the opposite side, so e.g. a TOP exit may only be paired with a
# BOTTOM ER target (test_group_constraints_satisfied checks exactly this).
directionally_matched_group_lookup = {
    ERTestGroups.LEFT: [ERTestGroups.RIGHT],
    ERTestGroups.RIGHT: [ERTestGroups.LEFT],
    ERTestGroups.TOP: [ERTestGroups.BOTTOM],
    ERTestGroups.BOTTOM: [ERTestGroups.TOP],
}


def generate_entrance_pair(region: Region, name_suffix: str, group: int):
    """
    Adds a 2-way exit and a 2-way ER target to ``region``.

    Both are named ``region.name + name_suffix`` and both get randomization group ``group``. The exit is
    created without being connected, so randomize_entrances is responsible for wiring it up.
    """
    lx = region.create_exit(region.name + name_suffix)
    lx.randomization_group = group
    lx.randomization_type = EntranceType.TWO_WAY
    le = region.create_er_target(region.name + name_suffix)
    le.randomization_group = group
    le.randomization_type = EntranceType.TWO_WAY


def generate_disconnected_region_grid(multiworld: MultiWorld, grid_side_length: int, region_size: int = 0,
                                      region_type: type[Region] = Region):
    """
    Generates a square grid of disconnected regions for ER testing.

    Regions are named ``region0`` .. ``region{n*n - 1}`` in row-major order, so region0 is the top-left
    cell and the last region is the bottom-right cell. The player's "Menu" region is connected to region0.

    For every neighbour a region has in the grid, it gets a 2-way exit/ER-target pair on that side
    (``_left``, ``_right``, ``_top``, ``_bottom``), grouped by the matching ERTestGroups member. No
    region-to-region connections are made; that is left to the entrance randomizer.

    :param multiworld: multiworld to add the regions to (player 1)
    :param grid_side_length: number of regions along each side of the grid
    :param region_size: passed to generate_locations to populate each region with locations
    :param region_type: Region subclass to instantiate, e.g. to customize the entrance type
    """
    for row in range(grid_side_length):
        for col in range(grid_side_length):
            index = row * grid_side_length + col
            name = f"region{index}"
            region = region_type(name, 1, multiworld)
            multiworld.regions.append(region)
            generate_locations(region_size, 1, region=region, tag=f"_{name}")

            if row == 0 and col == 0:
                multiworld.get_region("Menu", 1).connect(region)
            if col != 0:
                generate_entrance_pair(region, "_left", ERTestGroups.LEFT)
            if col != grid_side_length - 1:
                generate_entrance_pair(region, "_right", ERTestGroups.RIGHT)
            if row != 0:
                generate_entrance_pair(region, "_top", ERTestGroups.TOP)
            if row != grid_side_length - 1:
                generate_entrance_pair(region, "_bottom", ERTestGroups.BOTTOM)


class TestEntranceLookup(unittest.TestCase):
    @staticmethod
    def _build_grid_lookup() -> EntranceLookup:
        """Builds a coupled EntranceLookup containing every ER target of a fresh 5x5 test grid."""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)
        lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True)
        # ER targets are the entrances that have no parent region yet
        er_targets = [entrance for region in multiworld.get_regions(1) for entrance in region.entrances
                      if not entrance.parent_region]
        for entrance in er_targets:
            lookup.add(entrance)
        return lookup

    @staticmethod
    def _group_runs(targets) -> list:
        """Collapses consecutive targets with the same randomization group into one entry per run."""
        return [group for group, _ in groupby(target.randomization_group for target in targets)]

    def test_shuffled_targets(self):
        """tests that get_targets shuffles targets between groups when requested"""
        lookup = self._build_grid_lookup()
        retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM], False, False)
        group_order = self._group_runs(retrieved_targets)
        # technically possible that group order may not be shuffled, by some small chance, on some seeds. but generally
        # a shuffled list should alternate more frequently which is the desired behavior here
        self.assertGreater(len(group_order), 2)

    def test_ordered_targets(self):
        """tests that get_targets does not shuffle targets between groups when requested"""
        lookup = self._build_grid_lookup()
        retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM], False, True)
        group_order = self._group_runs(retrieved_targets)
        self.assertEqual([ERTestGroups.TOP, ERTestGroups.BOTTOM], group_order)


class TestBakeTargetGroupLookup(unittest.TestCase):
    def test_lookup_generation(self):
        """tests that bake_target_group_lookup applies the given function to every group present in the world"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)
        world = multiworld.worlds[1]
        expected = {
            ERTestGroups.LEFT: [-ERTestGroups.LEFT],
            ERTestGroups.RIGHT: [-ERTestGroups.RIGHT],
            ERTestGroups.TOP: [-ERTestGroups.TOP],
            ERTestGroups.BOTTOM: [-ERTestGroups.BOTTOM],
        }
        actual = bake_target_group_lookup(world, lambda g: [-g])
        self.assertEqual(expected, actual)


class TestDisconnectForRandomization(unittest.TestCase):
    def test_disconnect_default_2way(self):
        """tests that a 2-way exit leaves an ER target on its own (parent) region, named after the exit"""
        multiworld = generate_test_multiworld()
        r1 = Region("r1", 1, multiworld)
        r2 = Region("r2", 1, multiworld)
        e = r1.create_exit("e")
        e.randomization_type = EntranceType.TWO_WAY
        e.randomization_group = 1
        e.connect(r2)

        disconnect_entrance_for_randomization(e)

        self.assertIsNone(e.connected_region)
        self.assertEqual([], r2.entrances)

        self.assertEqual(1, len(r1.exits))
        self.assertEqual(e, r1.exits[0])

        self.assertEqual(1, len(r1.entrances))
        self.assertIsNone(r1.entrances[0].parent_region)
        self.assertEqual("e", r1.entrances[0].name)
        self.assertEqual(EntranceType.TWO_WAY, r1.entrances[0].randomization_type)
        self.assertEqual(1, r1.entrances[0].randomization_group)

    def test_disconnect_default_1way(self):
        """tests that a 1-way exit leaves an ER target on the region it led to, named after that region"""
        multiworld = generate_test_multiworld()
        r1 = Region("r1", 1, multiworld)
        r2 = Region("r2", 1, multiworld)
        e = r1.create_exit("e")
        e.randomization_type = EntranceType.ONE_WAY
        e.randomization_group = 1
        e.connect(r2)

        disconnect_entrance_for_randomization(e)

        self.assertIsNone(e.connected_region)
        self.assertEqual([], r1.entrances)

        self.assertEqual(1, len(r1.exits))
        self.assertEqual(e, r1.exits[0])

        self.assertEqual(1, len(r2.entrances))
        self.assertIsNone(r2.entrances[0].parent_region)
        self.assertEqual("r2", r2.entrances[0].name)
        self.assertEqual(EntranceType.ONE_WAY, r2.entrances[0].randomization_type)
        self.assertEqual(1, r2.entrances[0].randomization_group)

    def test_disconnect_uses_alternate_group(self):
        """tests that an explicitly passed group overrides the exit's group on the created ER target"""
        multiworld = generate_test_multiworld()
        r1 = Region("r1", 1, multiworld)
        r2 = Region("r2", 1, multiworld)
        e = r1.create_exit("e")
        e.randomization_type = EntranceType.ONE_WAY
        e.randomization_group = 1
        e.connect(r2)

        disconnect_entrance_for_randomization(e, 2)

        self.assertIsNone(e.connected_region)
        self.assertEqual([], r1.entrances)

        self.assertEqual(1, len(r1.exits))
        self.assertEqual(e, r1.exits[0])

        self.assertEqual(1, len(r2.entrances))
        self.assertIsNone(r2.entrances[0].parent_region)
        self.assertEqual("r2", r2.entrances[0].name)
        self.assertEqual(EntranceType.ONE_WAY, r2.entrances[0].randomization_type)
        self.assertEqual(2, r2.entrances[0].randomization_group)


class TestRandomizeEntrances(unittest.TestCase):
    def test_determinism(self):
        """tests that the same output is produced for the same input"""
        multiworld1 = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld1, 5)
        multiworld2 = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld2, 5)

        result1 = randomize_entrances(multiworld1.worlds[1], False, directionally_matched_group_lookup)
        result2 = randomize_entrances(multiworld2.worlds[1], False, directionally_matched_group_lookup)
        self.assertEqual(result1.pairings, result2.pairings)
        # zip() silently stops at the shorter list, so make sure neither run produced extra placements
        self.assertEqual(len(result1.placements), len(result2.placements))
        for e1, e2 in zip(result1.placements, result2.placements):
            self.assertEqual(e1.name, e2.name)
            self.assertEqual(e1.parent_region.name, e2.parent_region.name)
            self.assertEqual(e1.connected_region.name, e2.connected_region.name)

    def test_all_entrances_placed(self):
        """tests that all entrances and exits were placed, all regions are connected, and no dangling edges exist"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)

        result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)

        self.assertEqual([], [entrance for region in multiworld.get_regions() for entrance in region.entrances
                              if not entrance.parent_region])
        self.assertEqual([], [exit_ for region in multiworld.get_regions() for exit_ in region.exits
                              if not exit_.connected_region])
        # 5x5 grid + menu
        self.assertEqual(26, len(result.placed_regions))
        # 20 horizontal + 20 vertical adjacencies, each with one exit in either direction
        self.assertEqual(80, len(result.pairings))
        self.assertEqual(80, len(result.placements))

    def test_coupling(self):
        """tests that in coupled mode, all 2 way transitions have an inverse"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)
        seen_placement_count = 0

        def verify_coupled(_: ERPlacementState, placed_entrances: list[Entrance]):
            nonlocal seen_placement_count
            seen_placement_count += len(placed_entrances)
            self.assertEqual(2, len(placed_entrances))
            self.assertEqual(placed_entrances[0].parent_region, placed_entrances[1].connected_region)
            self.assertEqual(placed_entrances[1].parent_region, placed_entrances[0].connected_region)

        result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_lookup,
                                     on_connect=verify_coupled)
        # if we didn't visit every placement the verification on_connect doesn't really mean much
        self.assertEqual(len(result.placements), seen_placement_count)

    def test_uncoupled(self):
        """tests that in uncoupled mode, no transitions have an (intentional) inverse"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)
        seen_placement_count = 0

        def verify_uncoupled(_: ERPlacementState, placed_entrances: list[Entrance]):
            nonlocal seen_placement_count
            seen_placement_count += len(placed_entrances)
            self.assertEqual(1, len(placed_entrances))

        result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup,
                                     on_connect=verify_uncoupled)
        # if we didn't visit every placement the verification on_connect doesn't really mean much
        self.assertEqual(len(result.placements), seen_placement_count)

    def test_oneway_twoway_pairing(self):
        """tests that 1 ways are only paired to 1 ways and 2 ways are only paired to 2 ways"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)
        region26 = Region("region26", 1, multiworld)
        multiworld.regions.append(region26)
        for index, region in enumerate(["region4", "region20", "region24"]):
            x = multiworld.get_region(region, 1).create_exit(f"{region}_bottom_1way")
            x.randomization_type = EntranceType.ONE_WAY
            x.randomization_group = ERTestGroups.BOTTOM
            e = region26.create_er_target(f"region26_top_1way{index}")
            e.randomization_type = EntranceType.ONE_WAY
            e.randomization_group = ERTestGroups.TOP

        result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)
        for exit_name, entrance_name in result.pairings:
            # we have labeled our entrances in such a way that all the 1 way entrances have 1way in the name,
            # so test for that since the ER target will have been discarded. checking both directions ensures
            # 2 way exits were not paired to 1 way targets either.
            self.assertEqual("1way" in exit_name, "1way" in entrance_name,
                             f"{exit_name} was paired with {entrance_name}")

    def test_group_constraints_satisfied(self):
        """tests that all grouping constraints are satisfied"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)

        result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)
        for exit_name, entrance_name in result.pairings:
            # we have labeled our entrances in such a way that all the entrances contain their group in the name
            # so test for that since the ER target will have been discarded
            if "top" in exit_name:
                self.assertIn("bottom", entrance_name)
            if "bottom" in exit_name:
                self.assertIn("top", entrance_name)
            if "left" in exit_name:
                self.assertIn("right", entrance_name)
            if "right" in exit_name:
                self.assertIn("left", entrance_name)

    def test_minimal_entrance_rando(self):
        """tests that entrance randomization can complete with minimal accessibility and unreachable exits"""
        multiworld = generate_test_multiworld()
        multiworld.worlds[1].options.accessibility = Accessibility.from_any(Accessibility.option_minimal)
        multiworld.completion_condition[1] = lambda state: state.can_reach("region24", player=1)
        generate_disconnected_region_grid(multiworld, 5, 1)
        prog_items = generate_items(10, 1, True)
        multiworld.itempool += prog_items
        filler_items = generate_items(15, 1, False)
        multiworld.itempool += filler_items
        e = multiworld.get_entrance("region1_right", 1)
        set_rule(e, lambda state: False)

        randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)

        self.assertEqual([], [entrance for region in multiworld.get_regions() for entrance in region.entrances
                              if not entrance.parent_region])
        self.assertEqual([], [exit_ for region in multiworld.get_regions() for exit_ in region.exits
                              if not exit_.connected_region])

    def test_restrictive_region_requirement_does_not_fail(self):
        """tests that ER succeeds when most exits are locked behind reaching one specific region"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 2, 1)

        region = Region("region4", 1, multiworld)
        multiworld.regions.append(region)
        generate_entrance_pair(multiworld.get_region("region0", 1), "_right2", ERTestGroups.RIGHT)
        generate_entrance_pair(region, "_left", ERTestGroups.LEFT)

        # every exit outside region0 is locked until region4 is reachable
        blocked_exits = ["region1_left", "region1_bottom",
                         "region2_top", "region2_right",
                         "region3_left", "region3_top"]
        for exit_name in blocked_exits:
            blocked_exit = multiworld.get_entrance(exit_name, 1)
            blocked_exit.access_rule = lambda state: state.can_reach_region("region4", 1)
            multiworld.register_indirect_condition(region, blocked_exit)

        result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_lookup)
        # verifying that we did in fact place region4 adjacent to region0 to unblock all the other connections
        # (and implicitly, that ER didn't fail)
        self.assertTrue(("region0_right", "region4_left") in result.pairings
                        or ("region0_right2", "region4_left") in result.pairings)

    def test_fails_when_mismatched_entrance_and_exit_count(self):
        """tests that entrance randomization fast-fails if the input exit and entrance count do not match"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)
        multiworld.get_region("region1", 1).create_exit("extra")

        self.assertRaises(EntranceRandomizationError, randomize_entrances, multiworld.worlds[1], False,
                          directionally_matched_group_lookup)

    def test_fails_when_some_unreachable_exit(self):
        """tests that entrance randomization fails if an exit is never reachable (non-minimal accessibility)"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)
        e = multiworld.get_entrance("region1_right", 1)
        set_rule(e, lambda state: False)

        self.assertRaises(EntranceRandomizationError, randomize_entrances, multiworld.worlds[1], False,
                          directionally_matched_group_lookup)

    def test_fails_when_some_unconnectable_exit(self):
        """tests that entrance randomization fails if an exit can't be made into a valid placement (non-minimal)"""
        class CustomEntrance(Entrance):
            def can_connect_to(self, other: Entrance, dead_end: bool, er_state: "ERPlacementState") -> bool:
                if other.name == "region1_right":
                    return False
                # defer to the default rules for everything else; falling off the end would return None (falsy),
                # which would block *every* connection rather than just the one under test
                return super().can_connect_to(other, dead_end, er_state)

        class CustomRegion(Region):
            entrance_type = CustomEntrance

        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5, region_type=CustomRegion)

        self.assertRaises(EntranceRandomizationError, randomize_entrances, multiworld.worlds[1], False,
                          directionally_matched_group_lookup)

    def test_minimal_er_fails_when_not_enough_locations_to_fit_progression(self):
        """
        tests that entrance randomization fails in minimal accessibility if there are not enough locations
        available to place all progression items locally
        """
        multiworld = generate_test_multiworld()
        multiworld.worlds[1].options.accessibility = Accessibility.from_any(Accessibility.option_minimal)
        multiworld.completion_condition[1] = lambda state: state.can_reach("region24", player=1)
        generate_disconnected_region_grid(multiworld, 5, 1)
        prog_items = generate_items(30, 1, True)
        multiworld.itempool += prog_items
        e = multiworld.get_entrance("region1_right", 1)
        set_rule(e, lambda state: False)

        self.assertRaises(EntranceRandomizationError, randomize_entrances, multiworld.worlds[1], False,
                          directionally_matched_group_lookup)