Core: fix exceptions coming from LocationStore (#4358)
* Speedups: add instructions for ASAN * Speedups: move typevars out of classes * Speedups, NetUtils: raise correct exceptions * Speedups: double-check malloc * Tests: more LocationStore tests
This commit is contained in:
parent
f79657b41a
commit
3fb0b57d19
|
@ -410,6 +410,8 @@ class _LocationStore(dict, typing.MutableMapping[int, typing.Dict[int, typing.Tu
|
||||||
checked = state[team, slot]
|
checked = state[team, slot]
|
||||||
if not checked:
|
if not checked:
|
||||||
# This optimizes the case where everyone connects to a fresh game at the same time.
|
# This optimizes the case where everyone connects to a fresh game at the same time.
|
||||||
|
if slot not in self:
|
||||||
|
raise KeyError(slot)
|
||||||
return []
|
return []
|
||||||
return [location_id for
|
return [location_id for
|
||||||
location_id in self[slot] if
|
location_id in self[slot] if
|
||||||
|
|
|
@ -69,6 +69,14 @@ cdef struct IndexEntry:
|
||||||
size_t count
|
size_t count
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
State = Dict[Tuple[int, int], Set[int]]
|
||||||
|
else:
|
||||||
|
State = Union[Tuple[int, int], Set[int], defaultdict]
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
@cython.auto_pickle(False)
|
@cython.auto_pickle(False)
|
||||||
cdef class LocationStore:
|
cdef class LocationStore:
|
||||||
"""Compact store for locations and their items in a MultiServer"""
|
"""Compact store for locations and their items in a MultiServer"""
|
||||||
|
@ -137,10 +145,16 @@ cdef class LocationStore:
|
||||||
warnings.warn("Game has no locations")
|
warnings.warn("Game has no locations")
|
||||||
|
|
||||||
# allocate the arrays and invalidate index (0xff...)
|
# allocate the arrays and invalidate index (0xff...)
|
||||||
|
if count:
|
||||||
|
# leaving entries as NULL if there are none, makes potential memory errors more visible
|
||||||
self.entries = <LocationEntry*>self._mem.alloc(count, sizeof(LocationEntry))
|
self.entries = <LocationEntry*>self._mem.alloc(count, sizeof(LocationEntry))
|
||||||
self.sender_index = <IndexEntry*>self._mem.alloc(max_sender + 1, sizeof(IndexEntry))
|
self.sender_index = <IndexEntry*>self._mem.alloc(max_sender + 1, sizeof(IndexEntry))
|
||||||
self._raw_proxies = <PyObject**>self._mem.alloc(max_sender + 1, sizeof(PyObject*))
|
self._raw_proxies = <PyObject**>self._mem.alloc(max_sender + 1, sizeof(PyObject*))
|
||||||
|
|
||||||
|
assert (not self.entries) == (not count)
|
||||||
|
assert self.sender_index
|
||||||
|
assert self._raw_proxies
|
||||||
|
|
||||||
# build entries and index
|
# build entries and index
|
||||||
cdef size_t i = 0
|
cdef size_t i = 0
|
||||||
for sender, locations in sorted(locations_dict.items()):
|
for sender, locations in sorted(locations_dict.items()):
|
||||||
|
@ -190,8 +204,6 @@ cdef class LocationStore:
|
||||||
raise KeyError(key)
|
raise KeyError(key)
|
||||||
return <object>self._raw_proxies[key]
|
return <object>self._raw_proxies[key]
|
||||||
|
|
||||||
T = TypeVar('T')
|
|
||||||
|
|
||||||
def get(self, key: int, default: T) -> Union[PlayerLocationProxy, T]:
|
def get(self, key: int, default: T) -> Union[PlayerLocationProxy, T]:
|
||||||
# calling into self.__getitem__ here is slow, but this is not used in MultiServer
|
# calling into self.__getitem__ here is slow, but this is not used in MultiServer
|
||||||
try:
|
try:
|
||||||
|
@ -246,12 +258,11 @@ cdef class LocationStore:
|
||||||
all_locations[sender].add(entry.location)
|
all_locations[sender].add(entry.location)
|
||||||
return all_locations
|
return all_locations
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
State = Dict[Tuple[int, int], Set[int]]
|
|
||||||
else:
|
|
||||||
State = Union[Tuple[int, int], Set[int], defaultdict]
|
|
||||||
|
|
||||||
def get_checked(self, state: State, team: int, slot: int) -> List[int]:
|
def get_checked(self, state: State, team: int, slot: int) -> List[int]:
|
||||||
|
cdef ap_player_t sender = slot
|
||||||
|
if sender < 0 or sender >= self.sender_index_size:
|
||||||
|
raise KeyError(slot)
|
||||||
|
|
||||||
# This used to validate checks actually exist. A remnant from the past.
|
# This used to validate checks actually exist. A remnant from the past.
|
||||||
# If the order of locations becomes relevant at some point, we could not do sorted(set), so leaving it.
|
# If the order of locations becomes relevant at some point, we could not do sorted(set), so leaving it.
|
||||||
cdef set checked = state[team, slot]
|
cdef set checked = state[team, slot]
|
||||||
|
@ -263,7 +274,6 @@ cdef class LocationStore:
|
||||||
|
|
||||||
# Unless the set is close to empty, it's cheaper to use the python set directly, so we do that.
|
# Unless the set is close to empty, it's cheaper to use the python set directly, so we do that.
|
||||||
cdef LocationEntry* entry
|
cdef LocationEntry* entry
|
||||||
cdef ap_player_t sender = slot
|
|
||||||
cdef size_t start = self.sender_index[sender].start
|
cdef size_t start = self.sender_index[sender].start
|
||||||
cdef size_t count = self.sender_index[sender].count
|
cdef size_t count = self.sender_index[sender].count
|
||||||
return [entry.location for
|
return [entry.location for
|
||||||
|
@ -273,9 +283,11 @@ cdef class LocationStore:
|
||||||
def get_missing(self, state: State, team: int, slot: int) -> List[int]:
|
def get_missing(self, state: State, team: int, slot: int) -> List[int]:
|
||||||
cdef LocationEntry* entry
|
cdef LocationEntry* entry
|
||||||
cdef ap_player_t sender = slot
|
cdef ap_player_t sender = slot
|
||||||
|
if sender < 0 or sender >= self.sender_index_size:
|
||||||
|
raise KeyError(slot)
|
||||||
|
cdef set checked = state[team, slot]
|
||||||
cdef size_t start = self.sender_index[sender].start
|
cdef size_t start = self.sender_index[sender].start
|
||||||
cdef size_t count = self.sender_index[sender].count
|
cdef size_t count = self.sender_index[sender].count
|
||||||
cdef set checked = state[team, slot]
|
|
||||||
if not len(checked):
|
if not len(checked):
|
||||||
# Skip `in` if none have been checked.
|
# Skip `in` if none have been checked.
|
||||||
# This optimizes the case where everyone connects to a fresh game at the same time.
|
# This optimizes the case where everyone connects to a fresh game at the same time.
|
||||||
|
@ -290,9 +302,11 @@ cdef class LocationStore:
|
||||||
def get_remaining(self, state: State, team: int, slot: int) -> List[Tuple[int, int]]:
|
def get_remaining(self, state: State, team: int, slot: int) -> List[Tuple[int, int]]:
|
||||||
cdef LocationEntry* entry
|
cdef LocationEntry* entry
|
||||||
cdef ap_player_t sender = slot
|
cdef ap_player_t sender = slot
|
||||||
|
if sender < 0 or sender >= self.sender_index_size:
|
||||||
|
raise KeyError(slot)
|
||||||
|
cdef set checked = state[team, slot]
|
||||||
cdef size_t start = self.sender_index[sender].start
|
cdef size_t start = self.sender_index[sender].start
|
||||||
cdef size_t count = self.sender_index[sender].count
|
cdef size_t count = self.sender_index[sender].count
|
||||||
cdef set checked = state[team, slot]
|
|
||||||
return sorted([(entry.receiver, entry.item) for
|
return sorted([(entry.receiver, entry.item) for
|
||||||
entry in self.entries[start:start+count] if
|
entry in self.entries[start:start+count] if
|
||||||
entry.location not in checked])
|
entry.location not in checked])
|
||||||
|
@ -328,7 +342,8 @@ cdef class PlayerLocationProxy:
|
||||||
cdef LocationEntry* entry = NULL
|
cdef LocationEntry* entry = NULL
|
||||||
# binary search
|
# binary search
|
||||||
cdef size_t l = self._store.sender_index[self._player].start
|
cdef size_t l = self._store.sender_index[self._player].start
|
||||||
cdef size_t r = l + self._store.sender_index[self._player].count
|
cdef size_t e = l + self._store.sender_index[self._player].count
|
||||||
|
cdef size_t r = e
|
||||||
cdef size_t m
|
cdef size_t m
|
||||||
while l < r:
|
while l < r:
|
||||||
m = (l + r) // 2
|
m = (l + r) // 2
|
||||||
|
@ -337,7 +352,7 @@ cdef class PlayerLocationProxy:
|
||||||
l = m + 1
|
l = m + 1
|
||||||
else:
|
else:
|
||||||
r = m
|
r = m
|
||||||
if entry: # count != 0
|
if l < e:
|
||||||
entry = self._store.entries + l
|
entry = self._store.entries + l
|
||||||
if entry.location == loc:
|
if entry.location == loc:
|
||||||
return entry
|
return entry
|
||||||
|
@ -349,8 +364,6 @@ cdef class PlayerLocationProxy:
|
||||||
return entry.item, entry.receiver, entry.flags
|
return entry.item, entry.receiver, entry.flags
|
||||||
raise KeyError(f"No location {key} for player {self._player}")
|
raise KeyError(f"No location {key} for player {self._player}")
|
||||||
|
|
||||||
T = TypeVar('T')
|
|
||||||
|
|
||||||
def get(self, key: int, default: T) -> Union[Tuple[int, int, int], T]:
|
def get(self, key: int, default: T) -> Union[Tuple[int, int, int], T]:
|
||||||
cdef LocationEntry* entry = self._get(key)
|
cdef LocationEntry* entry = self._get(key)
|
||||||
if entry:
|
if entry:
|
||||||
|
|
|
@ -3,8 +3,16 @@ import os
|
||||||
|
|
||||||
def make_ext(modname, pyxfilename):
|
def make_ext(modname, pyxfilename):
|
||||||
from distutils.extension import Extension
|
from distutils.extension import Extension
|
||||||
return Extension(name=modname,
|
return Extension(
|
||||||
|
name=modname,
|
||||||
sources=[pyxfilename],
|
sources=[pyxfilename],
|
||||||
depends=["intset.h"],
|
depends=["intset.h"],
|
||||||
include_dirs=[os.getcwd()],
|
include_dirs=[os.getcwd()],
|
||||||
language="c")
|
language="c",
|
||||||
|
# to enable ASAN and debug build:
|
||||||
|
# extra_compile_args=["-fsanitize=address", "-UNDEBUG", "-Og", "-g"],
|
||||||
|
# extra_objects=["-fsanitize=address"],
|
||||||
|
# NOTE: we can not put -lasan at the front of link args, so needs to be run with
|
||||||
|
# LD_PRELOAD=/usr/lib/libasan.so ASAN_OPTIONS=detect_leaks=0 path/to/exe
|
||||||
|
# NOTE: this can't find everything unless libpython and cymem are also built with ASAN
|
||||||
|
)
|
||||||
|
|
|
@ -115,6 +115,7 @@ class Base:
|
||||||
def test_get_for_player(self) -> None:
|
def test_get_for_player(self) -> None:
|
||||||
self.assertEqual(self.store.get_for_player(3), {4: {9}})
|
self.assertEqual(self.store.get_for_player(3), {4: {9}})
|
||||||
self.assertEqual(self.store.get_for_player(1), {1: {13}, 2: {22, 23}})
|
self.assertEqual(self.store.get_for_player(1), {1: {13}, 2: {22, 23}})
|
||||||
|
self.assertEqual(self.store.get_for_player(9999), {})
|
||||||
|
|
||||||
def test_get_checked(self) -> None:
|
def test_get_checked(self) -> None:
|
||||||
self.assertEqual(self.store.get_checked(full_state, 0, 1), [11, 12, 13])
|
self.assertEqual(self.store.get_checked(full_state, 0, 1), [11, 12, 13])
|
||||||
|
@ -122,18 +123,48 @@ class Base:
|
||||||
self.assertEqual(self.store.get_checked(empty_state, 0, 1), [])
|
self.assertEqual(self.store.get_checked(empty_state, 0, 1), [])
|
||||||
self.assertEqual(self.store.get_checked(full_state, 0, 3), [9])
|
self.assertEqual(self.store.get_checked(full_state, 0, 3), [9])
|
||||||
|
|
||||||
|
def test_get_checked_exception(self) -> None:
|
||||||
|
with self.assertRaises(KeyError):
|
||||||
|
self.store.get_checked(empty_state, 0, 9999)
|
||||||
|
bad_state = {(0, 6): {1}}
|
||||||
|
with self.assertRaises(KeyError):
|
||||||
|
self.store.get_checked(bad_state, 0, 6)
|
||||||
|
bad_state = {(0, 9999): set()}
|
||||||
|
with self.assertRaises(KeyError):
|
||||||
|
self.store.get_checked(bad_state, 0, 9999)
|
||||||
|
|
||||||
def test_get_missing(self) -> None:
|
def test_get_missing(self) -> None:
|
||||||
self.assertEqual(self.store.get_missing(full_state, 0, 1), [])
|
self.assertEqual(self.store.get_missing(full_state, 0, 1), [])
|
||||||
self.assertEqual(self.store.get_missing(one_state, 0, 1), [11, 13])
|
self.assertEqual(self.store.get_missing(one_state, 0, 1), [11, 13])
|
||||||
self.assertEqual(self.store.get_missing(empty_state, 0, 1), [11, 12, 13])
|
self.assertEqual(self.store.get_missing(empty_state, 0, 1), [11, 12, 13])
|
||||||
self.assertEqual(self.store.get_missing(empty_state, 0, 3), [9])
|
self.assertEqual(self.store.get_missing(empty_state, 0, 3), [9])
|
||||||
|
|
||||||
|
def test_get_missing_exception(self) -> None:
|
||||||
|
with self.assertRaises(KeyError):
|
||||||
|
self.store.get_missing(empty_state, 0, 9999)
|
||||||
|
bad_state = {(0, 6): {1}}
|
||||||
|
with self.assertRaises(KeyError):
|
||||||
|
self.store.get_missing(bad_state, 0, 6)
|
||||||
|
bad_state = {(0, 9999): set()}
|
||||||
|
with self.assertRaises(KeyError):
|
||||||
|
self.store.get_missing(bad_state, 0, 9999)
|
||||||
|
|
||||||
def test_get_remaining(self) -> None:
|
def test_get_remaining(self) -> None:
|
||||||
self.assertEqual(self.store.get_remaining(full_state, 0, 1), [])
|
self.assertEqual(self.store.get_remaining(full_state, 0, 1), [])
|
||||||
self.assertEqual(self.store.get_remaining(one_state, 0, 1), [(1, 13), (2, 21)])
|
self.assertEqual(self.store.get_remaining(one_state, 0, 1), [(1, 13), (2, 21)])
|
||||||
self.assertEqual(self.store.get_remaining(empty_state, 0, 1), [(1, 13), (2, 21), (2, 22)])
|
self.assertEqual(self.store.get_remaining(empty_state, 0, 1), [(1, 13), (2, 21), (2, 22)])
|
||||||
self.assertEqual(self.store.get_remaining(empty_state, 0, 3), [(4, 99)])
|
self.assertEqual(self.store.get_remaining(empty_state, 0, 3), [(4, 99)])
|
||||||
|
|
||||||
|
def test_get_remaining_exception(self) -> None:
|
||||||
|
with self.assertRaises(KeyError):
|
||||||
|
self.store.get_remaining(empty_state, 0, 9999)
|
||||||
|
bad_state = {(0, 6): {1}}
|
||||||
|
with self.assertRaises(KeyError):
|
||||||
|
self.store.get_missing(bad_state, 0, 6)
|
||||||
|
bad_state = {(0, 9999): set()}
|
||||||
|
with self.assertRaises(KeyError):
|
||||||
|
self.store.get_remaining(bad_state, 0, 9999)
|
||||||
|
|
||||||
def test_location_set_intersection(self) -> None:
|
def test_location_set_intersection(self) -> None:
|
||||||
locations = {10, 11, 12}
|
locations = {10, 11, 12}
|
||||||
locations.intersection_update(self.store[1])
|
locations.intersection_update(self.store[1])
|
||||||
|
@ -181,6 +212,16 @@ class Base:
|
||||||
})
|
})
|
||||||
self.assertEqual(len(store), 1)
|
self.assertEqual(len(store), 1)
|
||||||
self.assertEqual(len(store[1]), 0)
|
self.assertEqual(len(store[1]), 0)
|
||||||
|
self.assertEqual(sorted(store.find_item(set(), 1)), [])
|
||||||
|
self.assertEqual(sorted(store.find_item({1}, 1)), [])
|
||||||
|
self.assertEqual(sorted(store.find_item({1, 2}, 1)), [])
|
||||||
|
self.assertEqual(store.get_for_player(1), {})
|
||||||
|
self.assertEqual(store.get_checked(empty_state, 0, 1), [])
|
||||||
|
self.assertEqual(store.get_checked(full_state, 0, 1), [])
|
||||||
|
self.assertEqual(store.get_missing(empty_state, 0, 1), [])
|
||||||
|
self.assertEqual(store.get_missing(full_state, 0, 1), [])
|
||||||
|
self.assertEqual(store.get_remaining(empty_state, 0, 1), [])
|
||||||
|
self.assertEqual(store.get_remaining(full_state, 0, 1), [])
|
||||||
|
|
||||||
def test_no_locations_for_1(self) -> None:
|
def test_no_locations_for_1(self) -> None:
|
||||||
store = self.type({
|
store = self.type({
|
||||||
|
|
Loading…
Reference in New Issue