Core: fix some memory leak sources without removing caching (#2400)

* Core: fix some memory leak sources

* Core: run gc before detecting memory leaks

* Core: restore caching in BaseClasses.MultiWorld

* SM: move spheres cache to MultiWorld._sm_spheres to avoid memory leak

* Test: add tests for world memory leaks

* Test: limit WorldTestBase leak-check to py>=3.11

---------

Co-authored-by: Fabian Dill <fabian.dill@web.de>
This commit is contained in:
black-sliver 2023-10-31 02:08:56 +01:00 committed by GitHub
parent d4498948f2
commit 5f5c48e17b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 156 additions and 13 deletions

View File

@ -325,15 +325,15 @@ class MultiWorld():
def player_ids(self) -> Tuple[int, ...]:
return tuple(range(1, self.players + 1))
@functools.lru_cache()
@Utils.cache_self1
def get_game_players(self, game_name: str) -> Tuple[int, ...]:
return tuple(player for player in self.player_ids if self.game[player] == game_name)
@functools.lru_cache()
@Utils.cache_self1
def get_game_groups(self, game_name: str) -> Tuple[int, ...]:
return tuple(group_id for group_id in self.groups if self.game[group_id] == game_name)
@functools.lru_cache()
@Utils.cache_self1
def get_game_worlds(self, game_name: str):
return tuple(world for player, world in self.worlds.items() if
player not in self.groups and self.game[player] == game_name)

View File

@ -7,8 +7,8 @@ import random
import string
import urllib.parse
import urllib.request
from collections import ChainMap, Counter
from typing import Any, Callable, Dict, Tuple, Union
from collections import Counter
from typing import Any, Dict, Tuple, Union
import ModuleUpdate
@ -225,7 +225,7 @@ def main(args=None, callback=ERmain):
with open(os.path.join(args.outputpath if args.outputpath else ".", f"generate_{seed_name}.yaml"), "wt") as f:
yaml.dump(important, f)
callback(erargs, seed)
return callback(erargs, seed)
def read_weights_yamls(path) -> Tuple[Any, ...]:
@ -639,6 +639,15 @@ def roll_alttp_settings(ret: argparse.Namespace, weights, plando_options):
if __name__ == '__main__':
import atexit
confirmation = atexit.register(input, "Press enter to close.")
main()
multiworld = main()
if __debug__:
import gc
import sys
import weakref
weak = weakref.ref(multiworld)
del multiworld
gc.collect() # need to collect to deref all hard references
assert not weak(), f"MultiWorld object was not de-allocated, it's referenced {sys.getrefcount(weak())} times." \
" This would be a memory leak."
# in case of error-free exit should not need confirmation
atexit.unregister(confirmation)

View File

@ -74,6 +74,8 @@ def snes_to_pc(value: int) -> int:
RetType = typing.TypeVar("RetType")
S = typing.TypeVar("S")
T = typing.TypeVar("T")
def cache_argsless(function: typing.Callable[[], RetType]) -> typing.Callable[[], RetType]:
@ -91,6 +93,31 @@ def cache_argsless(function: typing.Callable[[], RetType]) -> typing.Callable[[]
return _wrap
def cache_self1(function: typing.Callable[[S, T], RetType]) -> typing.Callable[[S, T], RetType]:
"""Specialized cache for self + 1 arg. Does not keep global ref to self and skips building a dict key tuple."""
assert function.__code__.co_argcount == 2, "Can only cache 2 argument functions with this cache."
cache_name = f"__cache_{function.__name__}__"
@functools.wraps(function)
def wrap(self: S, arg: T) -> RetType:
cache: Optional[Dict[T, RetType]] = typing.cast(Optional[Dict[T, RetType]],
getattr(self, cache_name, None))
if cache is None:
res = function(self, arg)
setattr(self, cache_name, {arg: res})
return res
try:
return cache[arg]
except KeyError:
res = function(self, arg)
cache[arg] = res
return res
return wrap
def is_frozen() -> bool:
return typing.cast(bool, getattr(sys, 'frozen', False))

View File

@ -1,3 +1,4 @@
import sys
import typing
import unittest
from argparse import Namespace
@ -107,11 +108,36 @@ class WorldTestBase(unittest.TestCase):
game: typing.ClassVar[str] # define game name in subclass, example "Secret of Evermore"
auto_construct: typing.ClassVar[bool] = True
""" automatically set up a world for each test in this class """
memory_leak_tested: typing.ClassVar[bool] = False
""" remember if memory leak test was already done for this class """
def setUp(self) -> None:
if self.auto_construct:
self.world_setup()
def tearDown(self) -> None:
if self.__class__.memory_leak_tested or not self.options or not self.constructed or \
sys.version_info < (3, 11, 0): # the leak check in tearDown fails in py<3.11 for an unknown reason
# only run memory leak test once per class, only for constructed with non-default options
# default options will be tested in test/general
super().tearDown()
return
import gc
import weakref
weak = weakref.ref(self.multiworld)
for attr_name in dir(self): # delete all direct references to MultiWorld and World
attr: object = typing.cast(object, getattr(self, attr_name))
if type(attr) is MultiWorld or isinstance(attr, AutoWorld.World):
delattr(self, attr_name)
state_cache: typing.Optional[typing.Dict[typing.Any, typing.Any]] = getattr(self, "_state_cache", None)
if state_cache is not None: # in case of multiple inheritance with TestBase, we need to clear its cache
state_cache.clear()
gc.collect()
self.__class__.memory_leak_tested = True
self.assertFalse(weak(), f"World {getattr(self, 'game', '')} leaked MultiWorld object")
super().tearDown()
def world_setup(self, seed: typing.Optional[int] = None) -> None:
if type(self) is WorldTestBase or \
(hasattr(WorldTestBase, self._testMethodName)

View File

@ -0,0 +1,16 @@
import unittest
from worlds.AutoWorld import AutoWorldRegister
from . import setup_solo_multiworld
class TestWorldMemory(unittest.TestCase):
def test_leak(self):
"""Tests that worlds don't leak references to MultiWorld or themselves with default options."""
import gc
import weakref
for game_name, world_type in AutoWorldRegister.world_types.items():
with self.subTest("Game", game_name=game_name):
weak = weakref.ref(setup_solo_multiworld(world_type))
gc.collect()
self.assertFalse(weak(), "World leaked a reference")

66
test/utils/test_caches.py Normal file
View File

@ -0,0 +1,66 @@
# Tests for caches in Utils.py
import unittest
from typing import Any
from Utils import cache_argsless, cache_self1
class TestCacheArgless(unittest.TestCase):
def test_cache(self) -> None:
@cache_argsless
def func_argless() -> object:
return object()
self.assertTrue(func_argless() is func_argless())
if __debug__: # assert only available with __debug__
def test_invalid_decorator(self) -> None:
with self.assertRaises(Exception):
@cache_argsless # type: ignore[arg-type]
def func_with_arg(_: Any) -> None:
pass
class TestCacheSelf1(unittest.TestCase):
def test_cache(self) -> None:
class Cls:
@cache_self1
def func(self, _: Any) -> object:
return object()
o1 = Cls()
o2 = Cls()
self.assertTrue(o1.func(1) is o1.func(1))
self.assertFalse(o1.func(1) is o1.func(2))
self.assertFalse(o1.func(1) is o2.func(1))
def test_gc(self) -> None:
# verify that we don't keep a global reference
import gc
import weakref
class Cls:
@cache_self1
def func(self, _: Any) -> object:
return object()
o = Cls()
_ = o.func(o) # keep a hard ref to the result
r = weakref.ref(o) # keep weak ref to the cache
del o # remove hard ref to the cache
gc.collect()
self.assertFalse(r()) # weak ref should be dead now
if __debug__: # assert only available with __debug__
def test_no_self(self) -> None:
with self.assertRaises(Exception):
@cache_self1 # type: ignore[arg-type]
def func() -> Any:
pass
def test_too_many_args(self) -> None:
with self.assertRaises(Exception):
@cache_self1 # type: ignore[arg-type]
def func(_1: Any, _2: Any, _3: Any) -> Any:
pass

View File

@ -112,15 +112,12 @@ class SMWorld(World):
required_client_version = (0, 2, 6)
itemManager: ItemManager
spheres = None
Logic.factory('vanilla')
def __init__(self, world: MultiWorld, player: int):
self.rom_name_available_event = threading.Event()
self.locations = {}
if SMWorld.spheres != None:
SMWorld.spheres = None
super().__init__(world, player)
@classmethod
@ -368,7 +365,7 @@ class SMWorld(World):
locationsDict[first_local_collected_loc.name]),
itemLoc.item.player,
True)
for itemLoc in SMWorld.spheres if itemLoc.item.player == self.player and (not progression_only or itemLoc.item.advancement)
for itemLoc in spheres if itemLoc.item.player == self.player and (not progression_only or itemLoc.item.advancement)
]
# Having a sorted itemLocs from collection order is required for escapeTrigger when Tourian is Disabled.
@ -376,8 +373,10 @@ class SMWorld(World):
# get_spheres could be cached in multiworld?
# Another possible solution would be to have a globally accessible list of items in the order in which the get placed in push_item
# and use the inversed starting from the first progression item.
if (SMWorld.spheres == None):
SMWorld.spheres = [itemLoc for sphere in self.multiworld.get_spheres() for itemLoc in sorted(sphere, key=lambda location: location.name)]
spheres: List[Location] = getattr(self.multiworld, "_sm_spheres", None)
if spheres is None:
spheres = [itemLoc for sphere in self.multiworld.get_spheres() for itemLoc in sorted(sphere, key=lambda location: location.name)]
setattr(self.multiworld, "_sm_spheres", spheres)
self.itemLocs = [
ItemLocation(copy.copy(ItemManager.Items[itemLoc.item.type