Core: fix some memory leak sources without removing caching (#2400)
* Core: fix some memory leak sources * Core: run gc before detecting memory leaks * Core: restore caching in BaseClasses.MultiWorld * SM: move spheres cache to MultiWorld._sm_spheres to avoid memory leak * Test: add tests for world memory leaks * Test: limit WorldTestBase leak-check to py>=3.11 --------- Co-authored-by: Fabian Dill <fabian.dill@web.de>
This commit is contained in:
parent
d4498948f2
commit
5f5c48e17b
|
@ -325,15 +325,15 @@ class MultiWorld():
|
||||||
def player_ids(self) -> Tuple[int, ...]:
|
def player_ids(self) -> Tuple[int, ...]:
|
||||||
return tuple(range(1, self.players + 1))
|
return tuple(range(1, self.players + 1))
|
||||||
|
|
||||||
@functools.lru_cache()
|
@Utils.cache_self1
|
||||||
def get_game_players(self, game_name: str) -> Tuple[int, ...]:
|
def get_game_players(self, game_name: str) -> Tuple[int, ...]:
|
||||||
return tuple(player for player in self.player_ids if self.game[player] == game_name)
|
return tuple(player for player in self.player_ids if self.game[player] == game_name)
|
||||||
|
|
||||||
@functools.lru_cache()
|
@Utils.cache_self1
|
||||||
def get_game_groups(self, game_name: str) -> Tuple[int, ...]:
|
def get_game_groups(self, game_name: str) -> Tuple[int, ...]:
|
||||||
return tuple(group_id for group_id in self.groups if self.game[group_id] == game_name)
|
return tuple(group_id for group_id in self.groups if self.game[group_id] == game_name)
|
||||||
|
|
||||||
@functools.lru_cache()
|
@Utils.cache_self1
|
||||||
def get_game_worlds(self, game_name: str):
|
def get_game_worlds(self, game_name: str):
|
||||||
return tuple(world for player, world in self.worlds.items() if
|
return tuple(world for player, world in self.worlds.items() if
|
||||||
player not in self.groups and self.game[player] == game_name)
|
player not in self.groups and self.game[player] == game_name)
|
||||||
|
|
17
Generate.py
17
Generate.py
|
@ -7,8 +7,8 @@ import random
|
||||||
import string
|
import string
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from collections import ChainMap, Counter
|
from collections import Counter
|
||||||
from typing import Any, Callable, Dict, Tuple, Union
|
from typing import Any, Dict, Tuple, Union
|
||||||
|
|
||||||
import ModuleUpdate
|
import ModuleUpdate
|
||||||
|
|
||||||
|
@ -225,7 +225,7 @@ def main(args=None, callback=ERmain):
|
||||||
with open(os.path.join(args.outputpath if args.outputpath else ".", f"generate_{seed_name}.yaml"), "wt") as f:
|
with open(os.path.join(args.outputpath if args.outputpath else ".", f"generate_{seed_name}.yaml"), "wt") as f:
|
||||||
yaml.dump(important, f)
|
yaml.dump(important, f)
|
||||||
|
|
||||||
callback(erargs, seed)
|
return callback(erargs, seed)
|
||||||
|
|
||||||
|
|
||||||
def read_weights_yamls(path) -> Tuple[Any, ...]:
|
def read_weights_yamls(path) -> Tuple[Any, ...]:
|
||||||
|
@ -639,6 +639,15 @@ def roll_alttp_settings(ret: argparse.Namespace, weights, plando_options):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import atexit
|
import atexit
|
||||||
confirmation = atexit.register(input, "Press enter to close.")
|
confirmation = atexit.register(input, "Press enter to close.")
|
||||||
main()
|
multiworld = main()
|
||||||
|
if __debug__:
|
||||||
|
import gc
|
||||||
|
import sys
|
||||||
|
import weakref
|
||||||
|
weak = weakref.ref(multiworld)
|
||||||
|
del multiworld
|
||||||
|
gc.collect() # need to collect to deref all hard references
|
||||||
|
assert not weak(), f"MultiWorld object was not de-allocated, it's referenced {sys.getrefcount(weak())} times." \
|
||||||
|
" This would be a memory leak."
|
||||||
# in case of error-free exit should not need confirmation
|
# in case of error-free exit should not need confirmation
|
||||||
atexit.unregister(confirmation)
|
atexit.unregister(confirmation)
|
||||||
|
|
27
Utils.py
27
Utils.py
|
@ -74,6 +74,8 @@ def snes_to_pc(value: int) -> int:
|
||||||
|
|
||||||
|
|
||||||
RetType = typing.TypeVar("RetType")
|
RetType = typing.TypeVar("RetType")
|
||||||
|
S = typing.TypeVar("S")
|
||||||
|
T = typing.TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def cache_argsless(function: typing.Callable[[], RetType]) -> typing.Callable[[], RetType]:
|
def cache_argsless(function: typing.Callable[[], RetType]) -> typing.Callable[[], RetType]:
|
||||||
|
@ -91,6 +93,31 @@ def cache_argsless(function: typing.Callable[[], RetType]) -> typing.Callable[[]
|
||||||
return _wrap
|
return _wrap
|
||||||
|
|
||||||
|
|
||||||
|
def cache_self1(function: typing.Callable[[S, T], RetType]) -> typing.Callable[[S, T], RetType]:
|
||||||
|
"""Specialized cache for self + 1 arg. Does not keep global ref to self and skips building a dict key tuple."""
|
||||||
|
|
||||||
|
assert function.__code__.co_argcount == 2, "Can only cache 2 argument functions with this cache."
|
||||||
|
|
||||||
|
cache_name = f"__cache_{function.__name__}__"
|
||||||
|
|
||||||
|
@functools.wraps(function)
|
||||||
|
def wrap(self: S, arg: T) -> RetType:
|
||||||
|
cache: Optional[Dict[T, RetType]] = typing.cast(Optional[Dict[T, RetType]],
|
||||||
|
getattr(self, cache_name, None))
|
||||||
|
if cache is None:
|
||||||
|
res = function(self, arg)
|
||||||
|
setattr(self, cache_name, {arg: res})
|
||||||
|
return res
|
||||||
|
try:
|
||||||
|
return cache[arg]
|
||||||
|
except KeyError:
|
||||||
|
res = function(self, arg)
|
||||||
|
cache[arg] = res
|
||||||
|
return res
|
||||||
|
|
||||||
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
def is_frozen() -> bool:
|
def is_frozen() -> bool:
|
||||||
return typing.cast(bool, getattr(sys, 'frozen', False))
|
return typing.cast(bool, getattr(sys, 'frozen', False))
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import sys
|
||||||
import typing
|
import typing
|
||||||
import unittest
|
import unittest
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
@ -107,11 +108,36 @@ class WorldTestBase(unittest.TestCase):
|
||||||
game: typing.ClassVar[str] # define game name in subclass, example "Secret of Evermore"
|
game: typing.ClassVar[str] # define game name in subclass, example "Secret of Evermore"
|
||||||
auto_construct: typing.ClassVar[bool] = True
|
auto_construct: typing.ClassVar[bool] = True
|
||||||
""" automatically set up a world for each test in this class """
|
""" automatically set up a world for each test in this class """
|
||||||
|
memory_leak_tested: typing.ClassVar[bool] = False
|
||||||
|
""" remember if memory leak test was already done for this class """
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
if self.auto_construct:
|
if self.auto_construct:
|
||||||
self.world_setup()
|
self.world_setup()
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
if self.__class__.memory_leak_tested or not self.options or not self.constructed or \
|
||||||
|
sys.version_info < (3, 11, 0): # the leak check in tearDown fails in py<3.11 for an unknown reason
|
||||||
|
# only run memory leak test once per class, only for constructed with non-default options
|
||||||
|
# default options will be tested in test/general
|
||||||
|
super().tearDown()
|
||||||
|
return
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import weakref
|
||||||
|
weak = weakref.ref(self.multiworld)
|
||||||
|
for attr_name in dir(self): # delete all direct references to MultiWorld and World
|
||||||
|
attr: object = typing.cast(object, getattr(self, attr_name))
|
||||||
|
if type(attr) is MultiWorld or isinstance(attr, AutoWorld.World):
|
||||||
|
delattr(self, attr_name)
|
||||||
|
state_cache: typing.Optional[typing.Dict[typing.Any, typing.Any]] = getattr(self, "_state_cache", None)
|
||||||
|
if state_cache is not None: # in case of multiple inheritance with TestBase, we need to clear its cache
|
||||||
|
state_cache.clear()
|
||||||
|
gc.collect()
|
||||||
|
self.__class__.memory_leak_tested = True
|
||||||
|
self.assertFalse(weak(), f"World {getattr(self, 'game', '')} leaked MultiWorld object")
|
||||||
|
super().tearDown()
|
||||||
|
|
||||||
def world_setup(self, seed: typing.Optional[int] = None) -> None:
|
def world_setup(self, seed: typing.Optional[int] = None) -> None:
|
||||||
if type(self) is WorldTestBase or \
|
if type(self) is WorldTestBase or \
|
||||||
(hasattr(WorldTestBase, self._testMethodName)
|
(hasattr(WorldTestBase, self._testMethodName)
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from worlds.AutoWorld import AutoWorldRegister
|
||||||
|
from . import setup_solo_multiworld
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorldMemory(unittest.TestCase):
|
||||||
|
def test_leak(self):
|
||||||
|
"""Tests that worlds don't leak references to MultiWorld or themselves with default options."""
|
||||||
|
import gc
|
||||||
|
import weakref
|
||||||
|
for game_name, world_type in AutoWorldRegister.world_types.items():
|
||||||
|
with self.subTest("Game", game_name=game_name):
|
||||||
|
weak = weakref.ref(setup_solo_multiworld(world_type))
|
||||||
|
gc.collect()
|
||||||
|
self.assertFalse(weak(), "World leaked a reference")
|
|
@ -0,0 +1,66 @@
|
||||||
|
# Tests for caches in Utils.py
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from Utils import cache_argsless, cache_self1
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheArgless(unittest.TestCase):
|
||||||
|
def test_cache(self) -> None:
|
||||||
|
@cache_argsless
|
||||||
|
def func_argless() -> object:
|
||||||
|
return object()
|
||||||
|
|
||||||
|
self.assertTrue(func_argless() is func_argless())
|
||||||
|
|
||||||
|
if __debug__: # assert only available with __debug__
|
||||||
|
def test_invalid_decorator(self) -> None:
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
@cache_argsless # type: ignore[arg-type]
|
||||||
|
def func_with_arg(_: Any) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheSelf1(unittest.TestCase):
|
||||||
|
def test_cache(self) -> None:
|
||||||
|
class Cls:
|
||||||
|
@cache_self1
|
||||||
|
def func(self, _: Any) -> object:
|
||||||
|
return object()
|
||||||
|
|
||||||
|
o1 = Cls()
|
||||||
|
o2 = Cls()
|
||||||
|
self.assertTrue(o1.func(1) is o1.func(1))
|
||||||
|
self.assertFalse(o1.func(1) is o1.func(2))
|
||||||
|
self.assertFalse(o1.func(1) is o2.func(1))
|
||||||
|
|
||||||
|
def test_gc(self) -> None:
|
||||||
|
# verify that we don't keep a global reference
|
||||||
|
import gc
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
class Cls:
|
||||||
|
@cache_self1
|
||||||
|
def func(self, _: Any) -> object:
|
||||||
|
return object()
|
||||||
|
|
||||||
|
o = Cls()
|
||||||
|
_ = o.func(o) # keep a hard ref to the result
|
||||||
|
r = weakref.ref(o) # keep weak ref to the cache
|
||||||
|
del o # remove hard ref to the cache
|
||||||
|
gc.collect()
|
||||||
|
self.assertFalse(r()) # weak ref should be dead now
|
||||||
|
|
||||||
|
if __debug__: # assert only available with __debug__
|
||||||
|
def test_no_self(self) -> None:
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
@cache_self1 # type: ignore[arg-type]
|
||||||
|
def func() -> Any:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_too_many_args(self) -> None:
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
@cache_self1 # type: ignore[arg-type]
|
||||||
|
def func(_1: Any, _2: Any, _3: Any) -> Any:
|
||||||
|
pass
|
|
@ -112,15 +112,12 @@ class SMWorld(World):
|
||||||
required_client_version = (0, 2, 6)
|
required_client_version = (0, 2, 6)
|
||||||
|
|
||||||
itemManager: ItemManager
|
itemManager: ItemManager
|
||||||
spheres = None
|
|
||||||
|
|
||||||
Logic.factory('vanilla')
|
Logic.factory('vanilla')
|
||||||
|
|
||||||
def __init__(self, world: MultiWorld, player: int):
|
def __init__(self, world: MultiWorld, player: int):
|
||||||
self.rom_name_available_event = threading.Event()
|
self.rom_name_available_event = threading.Event()
|
||||||
self.locations = {}
|
self.locations = {}
|
||||||
if SMWorld.spheres != None:
|
|
||||||
SMWorld.spheres = None
|
|
||||||
super().__init__(world, player)
|
super().__init__(world, player)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -368,7 +365,7 @@ class SMWorld(World):
|
||||||
locationsDict[first_local_collected_loc.name]),
|
locationsDict[first_local_collected_loc.name]),
|
||||||
itemLoc.item.player,
|
itemLoc.item.player,
|
||||||
True)
|
True)
|
||||||
for itemLoc in SMWorld.spheres if itemLoc.item.player == self.player and (not progression_only or itemLoc.item.advancement)
|
for itemLoc in spheres if itemLoc.item.player == self.player and (not progression_only or itemLoc.item.advancement)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Having a sorted itemLocs from collection order is required for escapeTrigger when Tourian is Disabled.
|
# Having a sorted itemLocs from collection order is required for escapeTrigger when Tourian is Disabled.
|
||||||
|
@ -376,8 +373,10 @@ class SMWorld(World):
|
||||||
# get_spheres could be cached in multiworld?
|
# get_spheres could be cached in multiworld?
|
||||||
# Another possible solution would be to have a globally accessible list of items in the order in which the get placed in push_item
|
# Another possible solution would be to have a globally accessible list of items in the order in which the get placed in push_item
|
||||||
# and use the inversed starting from the first progression item.
|
# and use the inversed starting from the first progression item.
|
||||||
if (SMWorld.spheres == None):
|
spheres: List[Location] = getattr(self.multiworld, "_sm_spheres", None)
|
||||||
SMWorld.spheres = [itemLoc for sphere in self.multiworld.get_spheres() for itemLoc in sorted(sphere, key=lambda location: location.name)]
|
if spheres is None:
|
||||||
|
spheres = [itemLoc for sphere in self.multiworld.get_spheres() for itemLoc in sorted(sphere, key=lambda location: location.name)]
|
||||||
|
setattr(self.multiworld, "_sm_spheres", spheres)
|
||||||
|
|
||||||
self.itemLocs = [
|
self.itemLocs = [
|
||||||
ItemLocation(copy.copy(ItemManager.Items[itemLoc.item.type
|
ItemLocation(copy.copy(ItemManager.Items[itemLoc.item.type
|
||||||
|
|
Loading…
Reference in New Issue