From 3fb0b57d19b9c223107308c84605e05e6b16e1cf Mon Sep 17 00:00:00 2001 From: black-sliver <59490463+black-sliver@users.noreply.github.com> Date: Tue, 10 Dec 2024 20:09:36 +0100 Subject: [PATCH] Core: fix exceptions coming from LocationStore (#4358) * Speedups: add instructions for ASAN * Speedups: move typevars out of classes * Speedups, NetUtils: raise correct exceptions * Speedups: double-check malloc * Tests: more LocationStore tests --- NetUtils.py | 2 ++ _speedups.pyx | 43 ++++++++++++++++++---------- _speedups.pyxbld | 18 ++++++++---- test/netutils/test_location_store.py | 41 ++++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 20 deletions(-) diff --git a/NetUtils.py b/NetUtils.py index ec6ff3eb..196a030f 100644 --- a/NetUtils.py +++ b/NetUtils.py @@ -410,6 +410,8 @@ class _LocationStore(dict, typing.MutableMapping[int, typing.Dict[int, typing.Tu checked = state[team, slot] if not checked: # This optimizes the case where everyone connects to a fresh game at the same time. + if slot not in self: + raise KeyError(slot) return [] return [location_id for location_id in self[slot] if diff --git a/_speedups.pyx b/_speedups.pyx index dc039e33..2ad1a295 100644 --- a/_speedups.pyx +++ b/_speedups.pyx @@ -69,6 +69,14 @@ cdef struct IndexEntry: size_t count +if TYPE_CHECKING: + State = Dict[Tuple[int, int], Set[int]] +else: + State = Union[Tuple[int, int], Set[int], defaultdict] + +T = TypeVar('T') + + @cython.auto_pickle(False) cdef class LocationStore: """Compact store for locations and their items in a MultiServer""" @@ -137,10 +145,16 @@ cdef class LocationStore: warnings.warn("Game has no locations") # allocate the arrays and invalidate index (0xff...) - self.entries = self._mem.alloc(count, sizeof(LocationEntry)) + if count: + # leaving entries as NULL if there are none, makes potential memory errors more visible + self.entries = self._mem.alloc(count, sizeof(LocationEntry)) self.sender_index = self._mem.alloc(max_sender + 1, sizeof(IndexEntry)) self._raw_proxies = self._mem.alloc(max_sender + 1, sizeof(PyObject*)) + assert (not self.entries) == (not count) + assert self.sender_index + assert self._raw_proxies + # build entries and index cdef size_t i = 0 for sender, locations in sorted(locations_dict.items()): @@ -190,8 +204,6 @@ cdef class LocationStore: raise KeyError(key) return self._raw_proxies[key] - T = TypeVar('T') - def get(self, key: int, default: T) -> Union[PlayerLocationProxy, T]: # calling into self.__getitem__ here is slow, but this is not used in MultiServer try: @@ -246,12 +258,11 @@ cdef class LocationStore: all_locations[sender].add(entry.location) return all_locations - if TYPE_CHECKING: - State = Dict[Tuple[int, int], Set[int]] - else: - State = Union[Tuple[int, int], Set[int], defaultdict] - def get_checked(self, state: State, team: int, slot: int) -> List[int]: + cdef ap_player_t sender = slot + if sender < 0 or sender >= self.sender_index_size: + raise KeyError(slot) + # This used to validate checks actually exist. A remnant from the past. # If the order of locations becomes relevant at some point, we could not do sorted(set), so leaving it. cdef set checked = state[team, slot] @@ -263,7 +274,6 @@ cdef class LocationStore: # Unless the set is close to empty, it's cheaper to use the python set directly, so we do that. cdef LocationEntry* entry - cdef ap_player_t sender = slot cdef size_t start = self.sender_index[sender].start cdef size_t count = self.sender_index[sender].count return [entry.location for @@ -273,9 +283,11 @@ cdef class LocationStore: def get_missing(self, state: State, team: int, slot: int) -> List[int]: cdef LocationEntry* entry cdef ap_player_t sender = slot + if sender < 0 or sender >= self.sender_index_size: + raise KeyError(slot) + cdef set checked = state[team, slot] cdef size_t start = self.sender_index[sender].start cdef size_t count = self.sender_index[sender].count - cdef set checked = state[team, slot] if not len(checked): # Skip `in` if none have been checked. # This optimizes the case where everyone connects to a fresh game at the same time. @@ -290,9 +302,11 @@ cdef class LocationStore: def get_remaining(self, state: State, team: int, slot: int) -> List[Tuple[int, int]]: cdef LocationEntry* entry cdef ap_player_t sender = slot + if sender < 0 or sender >= self.sender_index_size: + raise KeyError(slot) + cdef set checked = state[team, slot] cdef size_t start = self.sender_index[sender].start cdef size_t count = self.sender_index[sender].count - cdef set checked = state[team, slot] return sorted([(entry.receiver, entry.item) for entry in self.entries[start:start+count] if entry.location not in checked]) @@ -328,7 +342,8 @@ cdef class PlayerLocationProxy: cdef LocationEntry* entry = NULL # binary search cdef size_t l = self._store.sender_index[self._player].start - cdef size_t r = l + self._store.sender_index[self._player].count + cdef size_t e = l + self._store.sender_index[self._player].count + cdef size_t r = e cdef size_t m while l < r: m = (l + r) // 2 @@ -337,7 +352,7 @@ cdef class PlayerLocationProxy: l = m + 1 else: r = m - if entry: # count != 0 + if l < e: entry = self._store.entries + l if entry.location == loc: return entry @@ -349,8 +364,6 @@ cdef class PlayerLocationProxy: return entry.item, entry.receiver, entry.flags raise KeyError(f"No location {key} for player {self._player}") - T = TypeVar('T') - def get(self, key: int, default: T) -> Union[Tuple[int, int, int], T]: cdef LocationEntry* entry = self._get(key) if entry: diff --git a/_speedups.pyxbld b/_speedups.pyxbld index 974eaed0..98f97346 100644 --- a/_speedups.pyxbld +++ b/_speedups.pyxbld @@ -3,8 +3,16 @@ import os def make_ext(modname, pyxfilename): from distutils.extension import Extension - return Extension(name=modname, - sources=[pyxfilename], - depends=["intset.h"], - include_dirs=[os.getcwd()], - language="c") + return Extension( + name=modname, + sources=[pyxfilename], + depends=["intset.h"], + include_dirs=[os.getcwd()], + language="c", + # to enable ASAN and debug build: + # extra_compile_args=["-fsanitize=address", "-UNDEBUG", "-Og", "-g"], + # extra_objects=["-fsanitize=address"], + # NOTE: we can not put -lasan at the front of link args, so needs to be run with + # LD_PRELOAD=/usr/lib/libasan.so ASAN_OPTIONS=detect_leaks=0 path/to/exe + # NOTE: this can't find everything unless libpython and cymem are also built with ASAN + ) diff --git a/test/netutils/test_location_store.py b/test/netutils/test_location_store.py index 1b984015..264f35b3 100644 --- a/test/netutils/test_location_store.py +++ b/test/netutils/test_location_store.py @@ -115,6 +115,7 @@ class Base: def test_get_for_player(self) -> None: self.assertEqual(self.store.get_for_player(3), {4: {9}}) self.assertEqual(self.store.get_for_player(1), {1: {13}, 2: {22, 23}}) + self.assertEqual(self.store.get_for_player(9999), {}) def test_get_checked(self) -> None: self.assertEqual(self.store.get_checked(full_state, 0, 1), [11, 12, 13]) @@ -122,18 +123,48 @@ class Base: self.assertEqual(self.store.get_checked(empty_state, 0, 1), []) self.assertEqual(self.store.get_checked(full_state, 0, 3), [9]) + def test_get_checked_exception(self) -> None: + with self.assertRaises(KeyError): + self.store.get_checked(empty_state, 0, 9999) + bad_state = {(0, 6): {1}} + with self.assertRaises(KeyError): + self.store.get_checked(bad_state, 0, 6) + bad_state = {(0, 9999): set()} + with self.assertRaises(KeyError): + self.store.get_checked(bad_state, 0, 9999) + def test_get_missing(self) -> None: self.assertEqual(self.store.get_missing(full_state, 0, 1), []) self.assertEqual(self.store.get_missing(one_state, 0, 1), [11, 13]) self.assertEqual(self.store.get_missing(empty_state, 0, 1), [11, 12, 13]) self.assertEqual(self.store.get_missing(empty_state, 0, 3), [9]) + def test_get_missing_exception(self) -> None: + with self.assertRaises(KeyError): + self.store.get_missing(empty_state, 0, 9999) + bad_state = {(0, 6): {1}} + with self.assertRaises(KeyError): + self.store.get_missing(bad_state, 0, 6) + bad_state = {(0, 9999): set()} + with self.assertRaises(KeyError): + self.store.get_missing(bad_state, 0, 9999) + def test_get_remaining(self) -> None: self.assertEqual(self.store.get_remaining(full_state, 0, 1), []) self.assertEqual(self.store.get_remaining(one_state, 0, 1), [(1, 13), (2, 21)]) self.assertEqual(self.store.get_remaining(empty_state, 0, 1), [(1, 13), (2, 21), (2, 22)]) self.assertEqual(self.store.get_remaining(empty_state, 0, 3), [(4, 99)]) + def test_get_remaining_exception(self) -> None: + with self.assertRaises(KeyError): + self.store.get_remaining(empty_state, 0, 9999) + bad_state = {(0, 6): {1}} + with self.assertRaises(KeyError): + self.store.get_missing(bad_state, 0, 6) + bad_state = {(0, 9999): set()} + with self.assertRaises(KeyError): + self.store.get_remaining(bad_state, 0, 9999) + def test_location_set_intersection(self) -> None: locations = {10, 11, 12} locations.intersection_update(self.store[1]) @@ -181,6 +212,16 @@ class Base: }) self.assertEqual(len(store), 1) self.assertEqual(len(store[1]), 0) + self.assertEqual(sorted(store.find_item(set(), 1)), []) + self.assertEqual(sorted(store.find_item({1}, 1)), []) + self.assertEqual(sorted(store.find_item({1, 2}, 1)), []) + self.assertEqual(store.get_for_player(1), {}) + self.assertEqual(store.get_checked(empty_state, 0, 1), []) + self.assertEqual(store.get_checked(full_state, 0, 1), []) + self.assertEqual(store.get_missing(empty_state, 0, 1), []) + self.assertEqual(store.get_missing(full_state, 0, 1), []) + self.assertEqual(store.get_remaining(empty_state, 0, 1), []) + self.assertEqual(store.get_remaining(full_state, 0, 1), []) def test_no_locations_for_1(self) -> None: store = self.type({