diff --git a/BaseClasses.py b/BaseClasses.py index af1f2180..5dcc9daa 100644 --- a/BaseClasses.py +++ b/BaseClasses.py @@ -325,15 +325,15 @@ class MultiWorld(): def player_ids(self) -> Tuple[int, ...]: return tuple(range(1, self.players + 1)) - @functools.lru_cache() + @Utils.cache_self1 def get_game_players(self, game_name: str) -> Tuple[int, ...]: return tuple(player for player in self.player_ids if self.game[player] == game_name) - @functools.lru_cache() + @Utils.cache_self1 def get_game_groups(self, game_name: str) -> Tuple[int, ...]: return tuple(group_id for group_id in self.groups if self.game[group_id] == game_name) - @functools.lru_cache() + @Utils.cache_self1 def get_game_worlds(self, game_name: str): return tuple(world for player, world in self.worlds.items() if player not in self.groups and self.game[player] == game_name) diff --git a/Generate.py b/Generate.py index 34a0084e..8113d8a0 100644 --- a/Generate.py +++ b/Generate.py @@ -7,8 +7,8 @@ import random import string import urllib.parse import urllib.request -from collections import ChainMap, Counter -from typing import Any, Callable, Dict, Tuple, Union +from collections import Counter +from typing import Any, Dict, Tuple, Union import ModuleUpdate @@ -225,7 +225,7 @@ def main(args=None, callback=ERmain): with open(os.path.join(args.outputpath if args.outputpath else ".", f"generate_{seed_name}.yaml"), "wt") as f: yaml.dump(important, f) - callback(erargs, seed) + return callback(erargs, seed) def read_weights_yamls(path) -> Tuple[Any, ...]: @@ -639,6 +639,15 @@ def roll_alttp_settings(ret: argparse.Namespace, weights, plando_options): if __name__ == '__main__': import atexit confirmation = atexit.register(input, "Press enter to close.") - main() + multiworld = main() + if __debug__: + import gc + import sys + import weakref + weak = weakref.ref(multiworld) + del multiworld + gc.collect() # need to collect to deref all hard references + assert not weak(), f"MultiWorld object was not de-allocated, it's referenced {sys.getrefcount(weak())} times." \ + " This would be a memory leak." # in case of error-free exit should not need confirmation atexit.unregister(confirmation) diff --git a/Utils.py b/Utils.py index 4cf8ca22..114c2e81 100644 --- a/Utils.py +++ b/Utils.py @@ -74,6 +74,8 @@ def snes_to_pc(value: int) -> int: RetType = typing.TypeVar("RetType") +S = typing.TypeVar("S") +T = typing.TypeVar("T") def cache_argsless(function: typing.Callable[[], RetType]) -> typing.Callable[[], RetType]: @@ -91,6 +93,31 @@ def cache_argsless(function: typing.Callable[[], RetType]) -> typing.Callable[[] return _wrap +def cache_self1(function: typing.Callable[[S, T], RetType]) -> typing.Callable[[S, T], RetType]: + """Specialized cache for self + 1 arg. Does not keep global ref to self and skips building a dict key tuple.""" + + assert function.__code__.co_argcount == 2, "Can only cache 2 argument functions with this cache." + + cache_name = f"__cache_{function.__name__}__" + + @functools.wraps(function) + def wrap(self: S, arg: T) -> RetType: + cache: Optional[Dict[T, RetType]] = typing.cast(Optional[Dict[T, RetType]], + getattr(self, cache_name, None)) + if cache is None: + res = function(self, arg) + setattr(self, cache_name, {arg: res}) + return res + try: + return cache[arg] + except KeyError: + res = function(self, arg) + cache[arg] = res + return res + + return wrap + + def is_frozen() -> bool: return typing.cast(bool, getattr(sys, 'frozen', False)) diff --git a/test/bases.py b/test/bases.py index 9911a45b..2054c2d1 100644 --- a/test/bases.py +++ b/test/bases.py @@ -1,3 +1,4 @@ +import sys import typing import unittest from argparse import Namespace @@ -107,11 +108,36 @@ class WorldTestBase(unittest.TestCase): game: typing.ClassVar[str] # define game name in subclass, example "Secret of Evermore" auto_construct: typing.ClassVar[bool] = True """ automatically set up a world for each test in this class """ + memory_leak_tested: typing.ClassVar[bool] = False + """ remember if memory leak test was already done for this class """ def setUp(self) -> None: if self.auto_construct: self.world_setup() + def tearDown(self) -> None: + if self.__class__.memory_leak_tested or not self.options or not self.constructed or \ + sys.version_info < (3, 11, 0): # the leak check in tearDown fails in py<3.11 for an unknown reason + # only run memory leak test once per class, only for constructed with non-default options + # default options will be tested in test/general + super().tearDown() + return + + import gc + import weakref + weak = weakref.ref(self.multiworld) + for attr_name in dir(self): # delete all direct references to MultiWorld and World + attr: object = typing.cast(object, getattr(self, attr_name)) + if type(attr) is MultiWorld or isinstance(attr, AutoWorld.World): + delattr(self, attr_name) + state_cache: typing.Optional[typing.Dict[typing.Any, typing.Any]] = getattr(self, "_state_cache", None) + if state_cache is not None: # in case of multiple inheritance with TestBase, we need to clear its cache + state_cache.clear() + gc.collect() + self.__class__.memory_leak_tested = True + self.assertFalse(weak(), f"World {getattr(self, 'game', '')} leaked MultiWorld object") + super().tearDown() + def world_setup(self, seed: typing.Optional[int] = None) -> None: if type(self) is WorldTestBase or \ (hasattr(WorldTestBase, self._testMethodName) diff --git a/test/general/test_memory.py b/test/general/test_memory.py new file mode 100644 index 00000000..e352b9e8 --- /dev/null +++ b/test/general/test_memory.py @@ -0,0 +1,16 @@ +import unittest + +from worlds.AutoWorld import AutoWorldRegister +from . import setup_solo_multiworld + + +class TestWorldMemory(unittest.TestCase): + def test_leak(self): + """Tests that worlds don't leak references to MultiWorld or themselves with default options.""" + import gc + import weakref + for game_name, world_type in AutoWorldRegister.world_types.items(): + with self.subTest("Game", game_name=game_name): + weak = weakref.ref(setup_solo_multiworld(world_type)) + gc.collect() + self.assertFalse(weak(), "World leaked a reference") diff --git a/test/utils/test_caches.py b/test/utils/test_caches.py new file mode 100644 index 00000000..fc681611 --- /dev/null +++ b/test/utils/test_caches.py @@ -0,0 +1,66 @@ +# Tests for caches in Utils.py + +import unittest +from typing import Any + +from Utils import cache_argsless, cache_self1 + + +class TestCacheArgless(unittest.TestCase): + def test_cache(self) -> None: + @cache_argsless + def func_argless() -> object: + return object() + + self.assertTrue(func_argless() is func_argless()) + + if __debug__: # assert only available with __debug__ + def test_invalid_decorator(self) -> None: + with self.assertRaises(Exception): + @cache_argsless # type: ignore[arg-type] + def func_with_arg(_: Any) -> None: + pass + + +class TestCacheSelf1(unittest.TestCase): + def test_cache(self) -> None: + class Cls: + @cache_self1 + def func(self, _: Any) -> object: + return object() + + o1 = Cls() + o2 = Cls() + self.assertTrue(o1.func(1) is o1.func(1)) + self.assertFalse(o1.func(1) is o1.func(2)) + self.assertFalse(o1.func(1) is o2.func(1)) + + def test_gc(self) -> None: + # verify that we don't keep a global reference + import gc + import weakref + + class Cls: + @cache_self1 + def func(self, _: Any) -> object: + return object() + + o = Cls() + _ = o.func(o) # keep a hard ref to the result + r = weakref.ref(o) # keep weak ref to the cache + del o # remove hard ref to the cache + gc.collect() + self.assertFalse(r()) # weak ref should be dead now + + if __debug__: # assert only available with __debug__ + def test_no_self(self) -> None: + with self.assertRaises(Exception): + @cache_self1 # type: ignore[arg-type] + def func() -> Any: + pass + + def test_too_many_args(self) -> None: + with self.assertRaises(Exception): + @cache_self1 # type: ignore[arg-type] + def func(_1: Any, _2: Any, _3: Any) -> Any: + pass diff --git a/worlds/sm/__init__.py b/worlds/sm/__init__.py index 4b4002c1..e85d79d3 100644 --- a/worlds/sm/__init__.py +++ b/worlds/sm/__init__.py @@ -112,15 +112,12 @@ class SMWorld(World): required_client_version = (0, 2, 6) itemManager: ItemManager - spheres = None Logic.factory('vanilla') def __init__(self, world: MultiWorld, player: int): self.rom_name_available_event = threading.Event() self.locations = {} - if SMWorld.spheres != None: - SMWorld.spheres = None super().__init__(world, player) @classmethod @@ -368,7 +365,7 @@ class SMWorld(World): locationsDict[first_local_collected_loc.name]), itemLoc.item.player, True) - for itemLoc in SMWorld.spheres if itemLoc.item.player == self.player and (not progression_only or itemLoc.item.advancement) + for itemLoc in spheres if itemLoc.item.player == self.player and (not progression_only or itemLoc.item.advancement) ] # Having a sorted itemLocs from collection order is required for escapeTrigger when Tourian is Disabled. @@ -376,8 +373,10 @@ class SMWorld(World): # get_spheres could be cached in multiworld? # Another possible solution would be to have a globally accessible list of items in the order in which the get placed in push_item # and use the inversed starting from the first progression item. - if (SMWorld.spheres == None): - SMWorld.spheres = [itemLoc for sphere in self.multiworld.get_spheres() for itemLoc in sorted(sphere, key=lambda location: location.name)] + spheres: List[Location] = getattr(self.multiworld, "_sm_spheres", None) + if spheres is None: + spheres = [itemLoc for sphere in self.multiworld.get_spheres() for itemLoc in sorted(sphere, key=lambda location: location.name)] + setattr(self.multiworld, "_sm_spheres", spheres) self.itemLocs = [ ItemLocation(copy.copy(ItemManager.Items[itemLoc.item.type