diff --git a/test/bases.py b/test/bases.py index 7ce12cc7..2d4111d1 100644 --- a/test/bases.py +++ b/test/bases.py @@ -7,7 +7,7 @@ from argparse import Namespace from Generate import get_seed_name from test.general import gen_steps from worlds import AutoWorld -from worlds.AutoWorld import call_all +from worlds.AutoWorld import World, call_all from BaseClasses import Location, MultiWorld, CollectionState, ItemClassification, Item from worlds.alttp.Items import ItemFactory @@ -105,9 +105,15 @@ class TestBase(unittest.TestCase): class WorldTestBase(unittest.TestCase): options: typing.Dict[str, typing.Any] = {} + """Define options that should be used when setting up this TestBase.""" multiworld: MultiWorld + """The constructed MultiWorld instance after setup.""" + world: World + """The constructed World instance after setup.""" + player: typing.ClassVar[int] = 1 - game: typing.ClassVar[str] # define game name in subclass, example "Secret of Evermore" + game: typing.ClassVar[str] + """Define game name in subclass, example "Secret of Evermore".""" auto_construct: typing.ClassVar[bool] = True """ automatically set up a world for each test in this class """ memory_leak_tested: typing.ClassVar[bool] = False @@ -150,8 +156,8 @@ class WorldTestBase(unittest.TestCase): if not hasattr(self, "game"): raise NotImplementedError("didn't define game name") self.multiworld = MultiWorld(1) - self.multiworld.game[1] = self.game - self.multiworld.player_name = {1: "Tester"} + self.multiworld.game[self.player] = self.game + self.multiworld.player_name = {self.player: "Tester"} self.multiworld.set_seed(seed) self.multiworld.state = CollectionState(self.multiworld) random.seed(self.multiworld.seed) @@ -159,9 +165,10 @@ class WorldTestBase(unittest.TestCase): args = Namespace() for name, option in AutoWorld.AutoWorldRegister.world_types[self.game].options_dataclass.type_hints.items(): setattr(args, name, { - 1: option.from_any(self.options.get(name, getattr(option, "default"))) + 1: option.from_any(self.options.get(name, option.default)) }) self.multiworld.set_options(args) + self.world = self.multiworld.worlds[self.player] for step in gen_steps: call_all(self.multiworld, step) @@ -220,19 +227,19 @@ class WorldTestBase(unittest.TestCase): def can_reach_location(self, location: str) -> bool: """Determines if the current state can reach the provided location name""" - return self.multiworld.state.can_reach(location, "Location", 1) + return self.multiworld.state.can_reach(location, "Location", self.player) def can_reach_entrance(self, entrance: str) -> bool: """Determines if the current state can reach the provided entrance name""" - return self.multiworld.state.can_reach(entrance, "Entrance", 1) + return self.multiworld.state.can_reach(entrance, "Entrance", self.player) def can_reach_region(self, region: str) -> bool: """Determines if the current state can reach the provided region name""" - return self.multiworld.state.can_reach(region, "Region", 1) + return self.multiworld.state.can_reach(region, "Region", self.player) def count(self, item_name: str) -> int: """Returns the amount of an item currently in state""" - return self.multiworld.state.count(item_name, 1) + return self.multiworld.state.count(item_name, self.player) def assertAccessDependency(self, locations: typing.List[str], @@ -246,10 +253,11 @@ class WorldTestBase(unittest.TestCase): self.collect_all_but(all_items, state) if only_check_listed: for location in locations: - self.assertFalse(state.can_reach(location, "Location", 1), f"{location} is reachable without {all_items}") + self.assertFalse(state.can_reach(location, "Location", self.player), + f"{location} is reachable without {all_items}") else: for location in self.multiworld.get_locations(): - loc_reachable = state.can_reach(location, "Location", 1) + loc_reachable = state.can_reach(location, "Location", self.player) self.assertEqual(loc_reachable, location.name not in locations, f"{location.name} is reachable without {all_items}" if loc_reachable else f"{location.name} is not reachable without {all_items}") @@ -258,7 +266,7 @@ class WorldTestBase(unittest.TestCase): for item in items: state.collect(item) for location in locations: - self.assertTrue(state.can_reach(location, "Location", 1), + self.assertTrue(state.can_reach(location, "Location", self.player), f"{location} not reachable with {item_names}") for item in items: state.remove(item) @@ -285,7 +293,7 @@ class WorldTestBase(unittest.TestCase): if not (self.run_default_tests and self.constructed): return with self.subTest("Game", game=self.game): - excluded = self.multiworld.worlds[1].options.exclude_locations.value + excluded = self.multiworld.worlds[self.player].options.exclude_locations.value state = self.multiworld.get_all_state(False) for location in self.multiworld.get_locations(): if location.name not in excluded: @@ -302,7 +310,7 @@ class WorldTestBase(unittest.TestCase): return with self.subTest("Game", game=self.game): state = CollectionState(self.multiworld) - locations = self.multiworld.get_reachable_locations(state, 1) + locations = self.multiworld.get_reachable_locations(state, self.player) self.assertGreater(len(locations), 0, "Need to be able to reach at least one location to get started.") @@ -328,7 +336,7 @@ class WorldTestBase(unittest.TestCase): for location in sphere: if location.item: state.collect(location.item, True, location) - return self.multiworld.has_beaten_game(state, 1) + return self.multiworld.has_beaten_game(state, self.player) with self.subTest("Game", game=self.game, seed=self.multiworld.seed): distribute_items_restrictive(self.multiworld) diff --git a/worlds/messenger/test/__init__.py b/worlds/messenger/test/__init__.py index 7ab1e117..f3fcd4ae 100644 --- a/worlds/messenger/test/__init__.py +++ b/worlds/messenger/test/__init__.py @@ -1,6 +1,7 @@ from test.TestBase import WorldTestBase +from .. import MessengerWorld class MessengerTestBase(WorldTestBase): game = "The Messenger" - player: int = 1 + world: MessengerWorld diff --git a/worlds/messenger/test/test_locations.py b/worlds/messenger/test/test_locations.py index 0c330be4..627d58c2 100644 --- a/worlds/messenger/test/test_locations.py +++ b/worlds/messenger/test/test_locations.py @@ -12,5 +12,5 @@ class LocationsTest(MessengerTestBase): return False def test_locations_exist(self) -> None: - for location in self.multiworld.worlds[1].location_name_to_id: + for location in self.world.location_name_to_id: self.assertIsInstance(self.multiworld.get_location(location, self.player), MessengerLocation) diff --git a/worlds/messenger/test/test_shop.py b/worlds/messenger/test/test_shop.py index afb1b32b..ee7e82d6 100644 --- a/worlds/messenger/test/test_shop.py +++ b/worlds/messenger/test/test_shop.py @@ -17,7 +17,7 @@ class ShopCostTest(MessengerTestBase): self.assertFalse(self.can_reach_location(loc)) def test_shop_prices(self) -> None: - prices: Dict[str, int] = self.multiworld.worlds[self.player].shop_prices + prices: Dict[str, int] = self.world.shop_prices for loc, price in prices.items(): with self.subTest("prices", loc=loc): self.assertLessEqual(price, self.multiworld.get_location(f"The Shop - {loc}", self.player).cost) @@ -51,7 +51,7 @@ class ShopCostMinTest(ShopCostTest): } def test_shop_rules(self) -> None: - if self.multiworld.worlds[self.player].total_shards: + if self.world.total_shards: super().test_shop_rules() else: for loc in SHOP_ITEMS: @@ -85,7 +85,7 @@ class PlandoTest(MessengerTestBase): with self.subTest("has cost", loc=loc): self.assertFalse(self.can_reach_location(loc)) - prices = self.multiworld.worlds[self.player].shop_prices + prices = self.world.shop_prices for loc, price in prices.items(): with self.subTest("prices", loc=loc): if loc == "Karuta Plates": @@ -98,7 +98,7 @@ class PlandoTest(MessengerTestBase): self.assertTrue(loc.replace("The Shop - ", "") in SHOP_ITEMS) self.assertEqual(len(prices), len(SHOP_ITEMS)) - figures = self.multiworld.worlds[self.player].figurine_prices + figures = self.world.figurine_prices for loc, price in figures.items(): with self.subTest("figure prices", loc=loc): if loc == "Barmath'azel Figurine": diff --git a/worlds/messenger/test/test_shop_chest.py b/worlds/messenger/test/test_shop_chest.py index a34fa0fb..f2030c63 100644 --- a/worlds/messenger/test/test_shop_chest.py +++ b/worlds/messenger/test/test_shop_chest.py @@ -41,8 +41,8 @@ class HalfSealsRequired(MessengerTestBase): def test_seals_amount(self) -> None: """Should have 45 power seals in the item pool and half that required""" self.assertEqual(self.multiworld.total_seals[self.player], 45) - self.assertEqual(self.multiworld.worlds[self.player].total_seals, 45) - self.assertEqual(self.multiworld.worlds[self.player].required_seals, 22) + self.assertEqual(self.world.total_seals, 45) + self.assertEqual(self.world.required_seals, 22) total_seals = [seal for seal in self.multiworld.itempool if seal.name == "Power Seal"] required_seals = [seal for seal in total_seals if seal.classification == ItemClassification.progression_skip_balancing] @@ -60,8 +60,8 @@ class ThirtyThirtySeals(MessengerTestBase): def test_seals_amount(self) -> None: """Should have 30 power seals in the pool and 33 percent of that required.""" self.assertEqual(self.multiworld.total_seals[self.player], 30) - self.assertEqual(self.multiworld.worlds[self.player].total_seals, 30) - self.assertEqual(self.multiworld.worlds[self.player].required_seals, 10) + self.assertEqual(self.world.total_seals, 30) + self.assertEqual(self.world.required_seals, 10) total_seals = [seal for seal in self.multiworld.itempool if seal.name == "Power Seal"] required_seals = [seal for seal in total_seals if seal.classification == ItemClassification.progression_skip_balancing] @@ -78,7 +78,7 @@ class MaxSealsNoShards(MessengerTestBase): def test_seals_amount(self) -> None: """Should set total seals to 70 since shards aren't shuffled.""" self.assertEqual(self.multiworld.total_seals[self.player], 85) - self.assertEqual(self.multiworld.worlds[self.player].total_seals, 70) + self.assertEqual(self.world.total_seals, 70) class MaxSealsWithShards(MessengerTestBase): @@ -91,8 +91,8 @@ class MaxSealsWithShards(MessengerTestBase): def test_seals_amount(self) -> None: """Should have 85 seals in the pool with all required and be a valid seed.""" self.assertEqual(self.multiworld.total_seals[self.player], 85) - self.assertEqual(self.multiworld.worlds[self.player].total_seals, 85) - self.assertEqual(self.multiworld.worlds[self.player].required_seals, 85) + self.assertEqual(self.world.total_seals, 85) + self.assertEqual(self.world.required_seals, 85) total_seals = [seal for seal in self.multiworld.itempool if seal.name == "Power Seal"] required_seals = [seal for seal in total_seals if seal.classification == ItemClassification.progression_skip_balancing]