Core: fix some memory leak sources without removing caching (#2400)
* Core: fix some memory leak sources * Core: run gc before detecting memory leaks * Core: restore caching in BaseClasses.MultiWorld * SM: move spheres cache to MultiWorld._sm_spheres to avoid memory leak * Test: add tests for world memory leaks * Test: limit WorldTestBase leak-check to py>=3.11 --------- Co-authored-by: Fabian Dill <fabian.dill@web.de>
This commit is contained in:
parent
d4498948f2
commit
5f5c48e17b
|
@ -325,15 +325,15 @@ class MultiWorld():
|
|||
def player_ids(self) -> Tuple[int, ...]:
|
||||
return tuple(range(1, self.players + 1))
|
||||
|
||||
@functools.lru_cache()
|
||||
@Utils.cache_self1
|
||||
def get_game_players(self, game_name: str) -> Tuple[int, ...]:
|
||||
return tuple(player for player in self.player_ids if self.game[player] == game_name)
|
||||
|
||||
@functools.lru_cache()
|
||||
@Utils.cache_self1
|
||||
def get_game_groups(self, game_name: str) -> Tuple[int, ...]:
|
||||
return tuple(group_id for group_id in self.groups if self.game[group_id] == game_name)
|
||||
|
||||
@functools.lru_cache()
|
||||
@Utils.cache_self1
|
||||
def get_game_worlds(self, game_name: str):
|
||||
return tuple(world for player, world in self.worlds.items() if
|
||||
player not in self.groups and self.game[player] == game_name)
|
||||
|
|
17
Generate.py
17
Generate.py
|
@ -7,8 +7,8 @@ import random
|
|||
import string
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from collections import ChainMap, Counter
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import ModuleUpdate
|
||||
|
||||
|
@ -225,7 +225,7 @@ def main(args=None, callback=ERmain):
|
|||
with open(os.path.join(args.outputpath if args.outputpath else ".", f"generate_{seed_name}.yaml"), "wt") as f:
|
||||
yaml.dump(important, f)
|
||||
|
||||
callback(erargs, seed)
|
||||
return callback(erargs, seed)
|
||||
|
||||
|
||||
def read_weights_yamls(path) -> Tuple[Any, ...]:
|
||||
|
@ -639,6 +639,15 @@ def roll_alttp_settings(ret: argparse.Namespace, weights, plando_options):
|
|||
if __name__ == '__main__':
|
||||
import atexit
|
||||
confirmation = atexit.register(input, "Press enter to close.")
|
||||
main()
|
||||
multiworld = main()
|
||||
if __debug__:
|
||||
import gc
|
||||
import sys
|
||||
import weakref
|
||||
weak = weakref.ref(multiworld)
|
||||
del multiworld
|
||||
gc.collect() # need to collect to deref all hard references
|
||||
assert not weak(), f"MultiWorld object was not de-allocated, it's referenced {sys.getrefcount(weak())} times." \
|
||||
" This would be a memory leak."
|
||||
# in case of error-free exit should not need confirmation
|
||||
atexit.unregister(confirmation)
|
||||
|
|
27
Utils.py
27
Utils.py
|
@ -74,6 +74,8 @@ def snes_to_pc(value: int) -> int:
|
|||
|
||||
|
||||
RetType = typing.TypeVar("RetType")
|
||||
S = typing.TypeVar("S")
|
||||
T = typing.TypeVar("T")
|
||||
|
||||
|
||||
def cache_argsless(function: typing.Callable[[], RetType]) -> typing.Callable[[], RetType]:
|
||||
|
@ -91,6 +93,31 @@ def cache_argsless(function: typing.Callable[[], RetType]) -> typing.Callable[[]
|
|||
return _wrap
|
||||
|
||||
|
||||
def cache_self1(function: typing.Callable[[S, T], RetType]) -> typing.Callable[[S, T], RetType]:
|
||||
"""Specialized cache for self + 1 arg. Does not keep global ref to self and skips building a dict key tuple."""
|
||||
|
||||
assert function.__code__.co_argcount == 2, "Can only cache 2 argument functions with this cache."
|
||||
|
||||
cache_name = f"__cache_{function.__name__}__"
|
||||
|
||||
@functools.wraps(function)
|
||||
def wrap(self: S, arg: T) -> RetType:
|
||||
cache: Optional[Dict[T, RetType]] = typing.cast(Optional[Dict[T, RetType]],
|
||||
getattr(self, cache_name, None))
|
||||
if cache is None:
|
||||
res = function(self, arg)
|
||||
setattr(self, cache_name, {arg: res})
|
||||
return res
|
||||
try:
|
||||
return cache[arg]
|
||||
except KeyError:
|
||||
res = function(self, arg)
|
||||
cache[arg] = res
|
||||
return res
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
def is_frozen() -> bool:
|
||||
return typing.cast(bool, getattr(sys, 'frozen', False))
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import sys
|
||||
import typing
|
||||
import unittest
|
||||
from argparse import Namespace
|
||||
|
@ -107,11 +108,36 @@ class WorldTestBase(unittest.TestCase):
|
|||
game: typing.ClassVar[str] # define game name in subclass, example "Secret of Evermore"
|
||||
auto_construct: typing.ClassVar[bool] = True
|
||||
""" automatically set up a world for each test in this class """
|
||||
memory_leak_tested: typing.ClassVar[bool] = False
|
||||
""" remember if memory leak test was already done for this class """
|
||||
|
||||
def setUp(self) -> None:
|
||||
if self.auto_construct:
|
||||
self.world_setup()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
if self.__class__.memory_leak_tested or not self.options or not self.constructed or \
|
||||
sys.version_info < (3, 11, 0): # the leak check in tearDown fails in py<3.11 for an unknown reason
|
||||
# only run memory leak test once per class, only for constructed with non-default options
|
||||
# default options will be tested in test/general
|
||||
super().tearDown()
|
||||
return
|
||||
|
||||
import gc
|
||||
import weakref
|
||||
weak = weakref.ref(self.multiworld)
|
||||
for attr_name in dir(self): # delete all direct references to MultiWorld and World
|
||||
attr: object = typing.cast(object, getattr(self, attr_name))
|
||||
if type(attr) is MultiWorld or isinstance(attr, AutoWorld.World):
|
||||
delattr(self, attr_name)
|
||||
state_cache: typing.Optional[typing.Dict[typing.Any, typing.Any]] = getattr(self, "_state_cache", None)
|
||||
if state_cache is not None: # in case of multiple inheritance with TestBase, we need to clear its cache
|
||||
state_cache.clear()
|
||||
gc.collect()
|
||||
self.__class__.memory_leak_tested = True
|
||||
self.assertFalse(weak(), f"World {getattr(self, 'game', '')} leaked MultiWorld object")
|
||||
super().tearDown()
|
||||
|
||||
def world_setup(self, seed: typing.Optional[int] = None) -> None:
|
||||
if type(self) is WorldTestBase or \
|
||||
(hasattr(WorldTestBase, self._testMethodName)
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
import unittest
|
||||
|
||||
from worlds.AutoWorld import AutoWorldRegister
|
||||
from . import setup_solo_multiworld
|
||||
|
||||
|
||||
class TestWorldMemory(unittest.TestCase):
|
||||
def test_leak(self):
|
||||
"""Tests that worlds don't leak references to MultiWorld or themselves with default options."""
|
||||
import gc
|
||||
import weakref
|
||||
for game_name, world_type in AutoWorldRegister.world_types.items():
|
||||
with self.subTest("Game", game_name=game_name):
|
||||
weak = weakref.ref(setup_solo_multiworld(world_type))
|
||||
gc.collect()
|
||||
self.assertFalse(weak(), "World leaked a reference")
|
|
@ -0,0 +1,66 @@
|
|||
# Tests for caches in Utils.py
|
||||
|
||||
import unittest
|
||||
from typing import Any
|
||||
|
||||
from Utils import cache_argsless, cache_self1
|
||||
|
||||
|
||||
class TestCacheArgless(unittest.TestCase):
|
||||
def test_cache(self) -> None:
|
||||
@cache_argsless
|
||||
def func_argless() -> object:
|
||||
return object()
|
||||
|
||||
self.assertTrue(func_argless() is func_argless())
|
||||
|
||||
if __debug__: # assert only available with __debug__
|
||||
def test_invalid_decorator(self) -> None:
|
||||
with self.assertRaises(Exception):
|
||||
@cache_argsless # type: ignore[arg-type]
|
||||
def func_with_arg(_: Any) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class TestCacheSelf1(unittest.TestCase):
|
||||
def test_cache(self) -> None:
|
||||
class Cls:
|
||||
@cache_self1
|
||||
def func(self, _: Any) -> object:
|
||||
return object()
|
||||
|
||||
o1 = Cls()
|
||||
o2 = Cls()
|
||||
self.assertTrue(o1.func(1) is o1.func(1))
|
||||
self.assertFalse(o1.func(1) is o1.func(2))
|
||||
self.assertFalse(o1.func(1) is o2.func(1))
|
||||
|
||||
def test_gc(self) -> None:
|
||||
# verify that we don't keep a global reference
|
||||
import gc
|
||||
import weakref
|
||||
|
||||
class Cls:
|
||||
@cache_self1
|
||||
def func(self, _: Any) -> object:
|
||||
return object()
|
||||
|
||||
o = Cls()
|
||||
_ = o.func(o) # keep a hard ref to the result
|
||||
r = weakref.ref(o) # keep weak ref to the cache
|
||||
del o # remove hard ref to the cache
|
||||
gc.collect()
|
||||
self.assertFalse(r()) # weak ref should be dead now
|
||||
|
||||
if __debug__: # assert only available with __debug__
|
||||
def test_no_self(self) -> None:
|
||||
with self.assertRaises(Exception):
|
||||
@cache_self1 # type: ignore[arg-type]
|
||||
def func() -> Any:
|
||||
pass
|
||||
|
||||
def test_too_many_args(self) -> None:
|
||||
with self.assertRaises(Exception):
|
||||
@cache_self1 # type: ignore[arg-type]
|
||||
def func(_1: Any, _2: Any, _3: Any) -> Any:
|
||||
pass
|
|
@ -112,15 +112,12 @@ class SMWorld(World):
|
|||
required_client_version = (0, 2, 6)
|
||||
|
||||
itemManager: ItemManager
|
||||
spheres = None
|
||||
|
||||
Logic.factory('vanilla')
|
||||
|
||||
def __init__(self, world: MultiWorld, player: int):
|
||||
self.rom_name_available_event = threading.Event()
|
||||
self.locations = {}
|
||||
if SMWorld.spheres != None:
|
||||
SMWorld.spheres = None
|
||||
super().__init__(world, player)
|
||||
|
||||
@classmethod
|
||||
|
@ -368,7 +365,7 @@ class SMWorld(World):
|
|||
locationsDict[first_local_collected_loc.name]),
|
||||
itemLoc.item.player,
|
||||
True)
|
||||
for itemLoc in SMWorld.spheres if itemLoc.item.player == self.player and (not progression_only or itemLoc.item.advancement)
|
||||
for itemLoc in spheres if itemLoc.item.player == self.player and (not progression_only or itemLoc.item.advancement)
|
||||
]
|
||||
|
||||
# Having a sorted itemLocs from collection order is required for escapeTrigger when Tourian is Disabled.
|
||||
|
@ -376,8 +373,10 @@ class SMWorld(World):
|
|||
# get_spheres could be cached in multiworld?
|
||||
# Another possible solution would be to have a globally accessible list of items in the order in which the get placed in push_item
|
||||
# and use the inversed starting from the first progression item.
|
||||
if (SMWorld.spheres == None):
|
||||
SMWorld.spheres = [itemLoc for sphere in self.multiworld.get_spheres() for itemLoc in sorted(sphere, key=lambda location: location.name)]
|
||||
spheres: List[Location] = getattr(self.multiworld, "_sm_spheres", None)
|
||||
if spheres is None:
|
||||
spheres = [itemLoc for sphere in self.multiworld.get_spheres() for itemLoc in sorted(sphere, key=lambda location: location.name)]
|
||||
setattr(self.multiworld, "_sm_spheres", spheres)
|
||||
|
||||
self.itemLocs = [
|
||||
ItemLocation(copy.copy(ItemManager.Items[itemLoc.item.type
|
||||
|
|
Loading…
Reference in New Issue