388 lines
18 KiB
Python
388 lines
18 KiB
Python
import unittest
|
|
from enum import IntEnum
|
|
|
|
from BaseClasses import Region, EntranceType, MultiWorld, Entrance
|
|
from entrance_rando import disconnect_entrance_for_randomization, randomize_entrances, EntranceRandomizationError, \
|
|
ERPlacementState, EntranceLookup, bake_target_group_lookup
|
|
from Options import Accessibility
|
|
from test.general import generate_test_multiworld, generate_locations, generate_items
|
|
from worlds.generic.Rules import set_rule
|
|
|
|
|
|
class ERTestGroups(IntEnum):
    """Randomization groups used by these tests, one per side of a grid region."""

    LEFT = 1
    RIGHT = 2
    TOP = 3
    BOTTOM = 4
|
|
|
|
|
|
# Group lookup passed to randomize_entrances by the tests below: each group is
# mapped to the single opposite group (left <-> right, top <-> bottom).
directionally_matched_group_lookup = {
    group: [opposite_group]
    for group, opposite_group in (
        (ERTestGroups.LEFT, ERTestGroups.RIGHT),
        (ERTestGroups.RIGHT, ERTestGroups.LEFT),
        (ERTestGroups.TOP, ERTestGroups.BOTTOM),
        (ERTestGroups.BOTTOM, ERTestGroups.TOP),
    )
}
|
|
|
|
|
|
def generate_entrance_pair(region: Region, name_suffix: str, group: int):
    """
    Adds a two-way exit and a two-way ER target to the region, both named `region.name + name_suffix` and both
    assigned to the given randomization group
    """
    # the exit is created and configured first, then the ER target, each under the same name
    for create_endpoint in (region.create_exit, region.create_er_target):
        endpoint = create_endpoint(region.name + name_suffix)
        endpoint.randomization_group = group
        endpoint.randomization_type = EntranceType.TWO_WAY
|
|
|
|
|
|
def generate_disconnected_region_grid(multiworld: MultiWorld, grid_side_length: int, region_size: int = 0,
                                      region_type: type[Region] = Region):
    """
    Builds a square grid of regions for ER testing. Regions are named region0, region1, ... in row-major order and
    belong to player 1. Menu is connected to the first (top-left) region, and every region receives one two-way
    exit/ER target pair per side that faces another grid cell; none of those pairs are connected to each other.

    :param multiworld: the multiworld the regions are appended to
    :param grid_side_length: number of rows and of columns in the grid
    :param region_size: forwarded to generate_locations for every region (presumably the location count per region)
    :param region_type: the class used to instantiate each region
    """
    last = grid_side_length - 1
    for row in range(grid_side_length):
        for col in range(grid_side_length):
            index = row * grid_side_length + col
            name = f"region{index}"
            region = region_type(name, 1, multiworld)
            multiworld.regions.append(region)
            generate_locations(region_size, 1, region=region, tag=f"_{name}")

            if (row, col) == (0, 0):
                multiworld.get_region("Menu", 1).connect(region)

            # a side only gets an entrance pair when it does not lie on the outer border of the grid
            sides = (
                (col != 0, "_left", ERTestGroups.LEFT),
                (col != last, "_right", ERTestGroups.RIGHT),
                (row != 0, "_top", ERTestGroups.TOP),
                (row != last, "_bottom", ERTestGroups.BOTTOM),
            )
            for faces_neighbour, suffix, group in sides:
                if faces_neighbour:
                    generate_entrance_pair(region, suffix, group)
|
|
|
|
|
|
class TestEntranceLookup(unittest.TestCase):
    """Tests for the order in which EntranceLookup.get_targets returns targets from several groups."""

    @staticmethod
    def _lookup_for_grid():
        """Returns a coupled EntranceLookup containing every ER target of a fresh 5x5 region grid."""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)

        lookup = EntranceLookup(multiworld.worlds[1].random, coupled=True)
        # ER targets are the entrances which have no parent region; collect them all before adding any
        dangling_targets = []
        for region in multiworld.get_regions(1):
            for entrance in region.entrances:
                if not entrance.parent_region:
                    dangling_targets.append(entrance)
        for target in dangling_targets:
            lookup.add(target)
        return lookup

    @staticmethod
    def _group_sequence(targets):
        """Returns the randomization groups of the targets with consecutive repeats collapsed into one entry."""
        sequence = []
        previous_group = None
        for target in targets:
            current_group = target.randomization_group
            if current_group != previous_group:
                sequence.append(current_group)
                previous_group = current_group
        return sequence

    def test_shuffled_targets(self):
        """tests that get_targets shuffles targets between groups when requested"""
        lookup = self._lookup_for_grid()

        retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM], False, False)
        group_order = self._group_sequence(retrieved_targets)
        # technically possible that group order may not be shuffled, by some small chance, on some seeds. but generally
        # a shuffled list should alternate more frequently which is the desired behavior here
        self.assertGreater(len(group_order), 2)

    def test_ordered_targets(self):
        """tests that get_targets does not shuffle targets between groups when requested"""
        lookup = self._lookup_for_grid()

        retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM], False, True)
        group_order = self._group_sequence(retrieved_targets)
        # with group order preserved, all TOP targets must come out before any BOTTOM target
        self.assertEqual([ERTestGroups.TOP, ERTestGroups.BOTTOM], group_order)
|
|
|
|
|
|
class TestBakeTargetGroupLookup(unittest.TestCase):
    """Tests for bake_target_group_lookup."""

    def test_lookup_generation(self):
        """tests that the baked lookup holds the callback's result for each of the four groups used by the grid"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)

        baked_lookup = bake_target_group_lookup(multiworld.worlds[1], lambda g: [-g])

        # the callback negates the group, so every group must map to a list holding only its negation
        negated_groups = {group: [-group] for group in ERTestGroups}
        self.assertEqual(negated_groups, baked_lookup)
|
|
|
|
|
|
class TestDisconnectForRandomization(unittest.TestCase):
    """Tests for the state disconnect_entrance_for_randomization leaves behind on the exit and both regions."""

    def _connected_exit(self, randomization_type):
        """
        Creates regions r1 and r2 and connects them with an exit named "e" that has the given randomization type
        and randomization group 1. Returns the tuple (r1, r2, exit).
        """
        # stored on the test case so the multiworld stays referenced while the test runs
        self._multiworld = generate_test_multiworld()
        origin = Region("r1", 1, self._multiworld)
        destination = Region("r2", 1, self._multiworld)
        exit_ = origin.create_exit("e")
        exit_.randomization_type = randomization_type
        exit_.randomization_group = 1
        exit_.connect(destination)
        return origin, destination, exit_

    def _assert_disconnected(self, exit_, origin, target_region, empty_region,
                             target_name, target_type, target_group):
        """
        Asserts that the exit is dangling but still owned by origin, that empty_region has no entrances, and that
        target_region holds exactly one parentless ER target with the given name, type and group.
        """
        self.assertIsNone(exit_.connected_region)
        self.assertEqual([], empty_region.entrances)

        self.assertEqual(1, len(origin.exits))
        self.assertEqual(exit_, origin.exits[0])

        self.assertEqual(1, len(target_region.entrances))
        er_target = target_region.entrances[0]
        self.assertIsNone(er_target.parent_region)
        self.assertEqual(target_name, er_target.name)
        self.assertEqual(target_type, er_target.randomization_type)
        self.assertEqual(target_group, er_target.randomization_group)

    def test_disconnect_default_2way(self):
        """tests that a 2 way exit leaves an ER target named after the exit on the exit's own region"""
        r1, r2, e = self._connected_exit(EntranceType.TWO_WAY)

        disconnect_entrance_for_randomization(e)

        self._assert_disconnected(e, r1, target_region=r1, empty_region=r2,
                                  target_name="e", target_type=EntranceType.TWO_WAY, target_group=1)

    def test_disconnect_default_1way(self):
        """tests that a 1 way exit leaves an ER target named after the connected region on that region"""
        r1, r2, e = self._connected_exit(EntranceType.ONE_WAY)

        disconnect_entrance_for_randomization(e)

        self._assert_disconnected(e, r1, target_region=r2, empty_region=r1,
                                  target_name="r2", target_type=EntranceType.ONE_WAY, target_group=1)

    def test_disconnect_uses_alternate_group(self):
        """tests that an explicitly passed group is applied to the ER target instead of the exit's group"""
        r1, r2, e = self._connected_exit(EntranceType.ONE_WAY)

        disconnect_entrance_for_randomization(e, 2)

        self._assert_disconnected(e, r1, target_region=r2, empty_region=r1,
                                  target_name="r2", target_type=EntranceType.ONE_WAY, target_group=2)
|
|
|
|
|
|
class TestRandomizeEntrances(unittest.TestCase):
    """Tests for randomize_entrances, run against the disconnected region grid built by this module."""

    def _assert_no_dangling_edges(self, multiworld):
        """Asserts that the multiworld contains neither an entrance without a parent nor an unconnected exit."""
        self.assertEqual([], [entrance for region in multiworld.get_regions()
                              for entrance in region.entrances if not entrance.parent_region])
        self.assertEqual([], [exit_ for region in multiworld.get_regions()
                              for exit_ in region.exits if not exit_.connected_region])

    def test_determinism(self):
        """tests that the same output is produced for the same input"""
        multiworld1 = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld1, 5)
        multiworld2 = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld2, 5)

        result1 = randomize_entrances(multiworld1.worlds[1], False, directionally_matched_group_lookup)
        result2 = randomize_entrances(multiworld2.worlds[1], False, directionally_matched_group_lookup)
        self.assertEqual(result1.pairings, result2.pairings)
        # zip() stops at the shorter input, so a length mismatch has to be caught separately
        self.assertEqual(len(result1.placements), len(result2.placements))
        for e1, e2 in zip(result1.placements, result2.placements):
            self.assertEqual(e1.name, e2.name)
            # compare across the two runs; comparing e1 against itself could never fail
            self.assertEqual(e1.parent_region.name, e2.parent_region.name)
            self.assertEqual(e1.connected_region.name, e2.connected_region.name)

    def test_all_entrances_placed(self):
        """tests that all entrances and exits were placed, all regions are connected, and no dangling edges exist"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)

        result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)

        self._assert_no_dangling_edges(multiworld)
        # 5x5 grid + menu
        self.assertEqual(26, len(result.placed_regions))
        # 40 adjacent region pairs in a 5x5 grid, each contributing one exit on either side
        self.assertEqual(80, len(result.pairings))
        self.assertEqual(80, len(result.placements))

    def test_coupling(self):
        """tests that in coupled mode, all 2 way transitions have an inverse"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)
        seen_placement_count = 0

        def verify_coupled(_: ERPlacementState, placed_entrances: list[Entrance]):
            nonlocal seen_placement_count
            seen_placement_count += len(placed_entrances)
            self.assertEqual(2, len(placed_entrances))
            self.assertEqual(placed_entrances[0].parent_region, placed_entrances[1].connected_region)
            self.assertEqual(placed_entrances[1].parent_region, placed_entrances[0].connected_region)

        result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_lookup,
                                     on_connect=verify_coupled)
        # if we didn't visit every placement the verification on_connect doesn't really mean much
        self.assertEqual(len(result.placements), seen_placement_count)

    def test_uncoupled(self):
        """tests that in uncoupled mode, no transitions have an (intentional) inverse"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)
        seen_placement_count = 0

        def verify_uncoupled(state: ERPlacementState, placed_entrances: list[Entrance]):
            nonlocal seen_placement_count
            seen_placement_count += len(placed_entrances)
            self.assertEqual(1, len(placed_entrances))

        result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup,
                                     on_connect=verify_uncoupled)
        # if we didn't visit every placement the verification on_connect doesn't really mean much
        self.assertEqual(len(result.placements), seen_placement_count)

    def test_oneway_twoway_pairing(self):
        """tests that 1 ways are only paired to 1 ways and 2 ways are only paired to 2 ways"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)
        region26 = Region("region26", 1, multiworld)
        multiworld.regions.append(region26)
        for index, region in enumerate(["region4", "region20", "region24"]):
            x = multiworld.get_region(region, 1).create_exit(f"{region}_bottom_1way")
            x.randomization_type = EntranceType.ONE_WAY
            x.randomization_group = ERTestGroups.BOTTOM
            e = region26.create_er_target(f"region26_top_1way{index}")
            e.randomization_type = EntranceType.ONE_WAY
            e.randomization_group = ERTestGroups.TOP

        result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)
        for exit_name, entrance_name in result.pairings:
            # we have labeled our entrances in such a way that all the 1 way entrances have 1way in the name,
            # so test for that since the ER target will have been discarded
            if "1way" in exit_name:
                self.assertIn("1way", entrance_name)
            else:
                # the converse half of the docstring: a 2 way exit must not receive a 1 way target
                self.assertNotIn("1way", entrance_name)

    def test_group_constraints_satisfied(self):
        """tests that all grouping constraints are satisfied"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)

        result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)
        opposite_sides = (("top", "bottom"), ("bottom", "top"), ("left", "right"), ("right", "left"))
        for exit_name, entrance_name in result.pairings:
            # we have labeled our entrances in such a way that all the entrances contain their group in the name
            # so test for that since the ER target will have been discarded
            for side, opposite in opposite_sides:
                if side in exit_name:
                    self.assertIn(opposite, entrance_name)

    def test_minimal_entrance_rando(self):
        """tests that entrance randomization can complete with minimal accessibility and unreachable exits"""
        multiworld = generate_test_multiworld()
        multiworld.worlds[1].options.accessibility = Accessibility.from_any(Accessibility.option_minimal)
        multiworld.completion_condition[1] = lambda state: state.can_reach("region24", player=1)
        generate_disconnected_region_grid(multiworld, 5, 1)
        prog_items = generate_items(10, 1, True)
        multiworld.itempool += prog_items
        filler_items = generate_items(15, 1, False)
        multiworld.itempool += filler_items
        # make one exit permanently impassable
        e = multiworld.get_entrance("region1_right", 1)
        set_rule(e, lambda state: False)

        randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)

        self._assert_no_dangling_edges(multiworld)

    def test_restrictive_region_requirement_does_not_fail(self):
        """tests that ER succeeds when every exit outside region0 requires reaching one extra region first"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 2, 1)

        region = Region("region4", 1, multiworld)
        multiworld.regions.append(region)
        generate_entrance_pair(multiworld.get_region("region0", 1), "_right2", ERTestGroups.RIGHT)
        generate_entrance_pair(region, "_left", ERTestGroups.LEFT)

        # every exit of region1, region2 and region3 in the 2x2 grid
        blocked_exits = ["region1_left", "region1_bottom",
                         "region2_top", "region2_right",
                         "region3_left", "region3_top"]
        for exit_name in blocked_exits:
            blocked_exit = multiworld.get_entrance(exit_name, 1)
            blocked_exit.access_rule = lambda state: state.can_reach_region("region4", 1)
            multiworld.register_indirect_condition(region, blocked_exit)

        result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_lookup)
        # verifying that we did in fact place region4 adjacent to region0 to unblock all the other connections
        # (and implicitly, that ER didn't fail)
        self.assertTrue(("region0_right", "region4_left") in result.pairings
                        or ("region0_right2", "region4_left") in result.pairings)

    def test_fails_when_mismatched_entrance_and_exit_count(self):
        """tests that entrance randomization fast-fails if the input exit and entrance count do not match"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)
        # an exit without a matching ER target
        multiworld.get_region("region1", 1).create_exit("extra")

        with self.assertRaises(EntranceRandomizationError):
            randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)

    def test_fails_when_some_unreachable_exit(self):
        """tests that entrance randomization fails if an exit is never reachable (non-minimal accessibility)"""
        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5)
        e = multiworld.get_entrance("region1_right", 1)
        set_rule(e, lambda state: False)

        with self.assertRaises(EntranceRandomizationError):
            randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)

    def test_fails_when_some_unconnectable_exit(self):
        """tests that entrance randomization fails if an exit can't be made into a valid placement (non-minimal)"""
        class CustomEntrance(Entrance):
            def can_connect_to(self, other: Entrance, dead_end: bool, er_state: "ERPlacementState") -> bool:
                if other.name == "region1_right":
                    return False
                # defer to the default rules for every other target; falling through would return None,
                # which is falsy and would forbid all connections rather than just this one
                return super().can_connect_to(other, dead_end, er_state)

        class CustomRegion(Region):
            entrance_type = CustomEntrance

        multiworld = generate_test_multiworld()
        generate_disconnected_region_grid(multiworld, 5, region_type=CustomRegion)

        with self.assertRaises(EntranceRandomizationError):
            randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)

    def test_minimal_er_fails_when_not_enough_locations_to_fit_progression(self):
        """
        tests that entrance randomization fails in minimal accessibility if there are not enough locations
        available to place all progression items locally
        """
        multiworld = generate_test_multiworld()
        multiworld.worlds[1].options.accessibility = Accessibility.from_any(Accessibility.option_minimal)
        multiworld.completion_condition[1] = lambda state: state.can_reach("region24", player=1)
        generate_disconnected_region_grid(multiworld, 5, 1)
        prog_items = generate_items(30, 1, True)
        multiworld.itempool += prog_items
        e = multiworld.get_entrance("region1_right", 1)
        set_rule(e, lambda state: False)

        with self.assertRaises(EntranceRandomizationError):
            randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)
|