Core: fix exceptions coming from LocationStore (#4358)
* Speedups: add instructions for ASAN * Speedups: move typevars out of classes * Speedups, NetUtils: raise correct exceptions * Speedups: double-check malloc * Tests: more LocationStore tests
This commit is contained in:
		
							parent
							
								
									f79657b41a
								
							
						
					
					
						commit
						3fb0b57d19
					
				|  | @ -410,6 +410,8 @@ class _LocationStore(dict, typing.MutableMapping[int, typing.Dict[int, typing.Tu | ||||||
|         checked = state[team, slot] |         checked = state[team, slot] | ||||||
|         if not checked: |         if not checked: | ||||||
|             # This optimizes the case where everyone connects to a fresh game at the same time. |             # This optimizes the case where everyone connects to a fresh game at the same time. | ||||||
|  |             if slot not in self: | ||||||
|  |                 raise KeyError(slot) | ||||||
|             return [] |             return [] | ||||||
|         return [location_id for |         return [location_id for | ||||||
|                 location_id in self[slot] if |                 location_id in self[slot] if | ||||||
|  |  | ||||||
|  | @ -69,6 +69,14 @@ cdef struct IndexEntry: | ||||||
|     size_t count |     size_t count | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     State = Dict[Tuple[int, int], Set[int]] | ||||||
|  | else: | ||||||
|  |     State = Union[Tuple[int, int], Set[int], defaultdict] | ||||||
|  | 
 | ||||||
|  | T = TypeVar('T') | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @cython.auto_pickle(False) | @cython.auto_pickle(False) | ||||||
| cdef class LocationStore: | cdef class LocationStore: | ||||||
|     """Compact store for locations and their items in a MultiServer""" |     """Compact store for locations and their items in a MultiServer""" | ||||||
|  | @ -137,10 +145,16 @@ cdef class LocationStore: | ||||||
|             warnings.warn("Game has no locations") |             warnings.warn("Game has no locations") | ||||||
| 
 | 
 | ||||||
|         # allocate the arrays and invalidate index (0xff...) |         # allocate the arrays and invalidate index (0xff...) | ||||||
|         self.entries = <LocationEntry*>self._mem.alloc(count, sizeof(LocationEntry)) |         if count: | ||||||
|  |             # leaving entries as NULL if there are none, makes potential memory errors more visible | ||||||
|  |             self.entries = <LocationEntry*>self._mem.alloc(count, sizeof(LocationEntry)) | ||||||
|         self.sender_index = <IndexEntry*>self._mem.alloc(max_sender + 1, sizeof(IndexEntry)) |         self.sender_index = <IndexEntry*>self._mem.alloc(max_sender + 1, sizeof(IndexEntry)) | ||||||
|         self._raw_proxies = <PyObject**>self._mem.alloc(max_sender + 1, sizeof(PyObject*)) |         self._raw_proxies = <PyObject**>self._mem.alloc(max_sender + 1, sizeof(PyObject*)) | ||||||
| 
 | 
 | ||||||
|  |         assert (not self.entries) == (not count) | ||||||
|  |         assert self.sender_index | ||||||
|  |         assert self._raw_proxies | ||||||
|  | 
 | ||||||
|         # build entries and index |         # build entries and index | ||||||
|         cdef size_t i = 0 |         cdef size_t i = 0 | ||||||
|         for sender, locations in sorted(locations_dict.items()): |         for sender, locations in sorted(locations_dict.items()): | ||||||
|  | @ -190,8 +204,6 @@ cdef class LocationStore: | ||||||
|             raise KeyError(key) |             raise KeyError(key) | ||||||
|         return <object>self._raw_proxies[key] |         return <object>self._raw_proxies[key] | ||||||
| 
 | 
 | ||||||
|     T = TypeVar('T') |  | ||||||
| 
 |  | ||||||
|     def get(self, key: int, default: T) -> Union[PlayerLocationProxy, T]: |     def get(self, key: int, default: T) -> Union[PlayerLocationProxy, T]: | ||||||
|         # calling into self.__getitem__ here is slow, but this is not used in MultiServer |         # calling into self.__getitem__ here is slow, but this is not used in MultiServer | ||||||
|         try: |         try: | ||||||
|  | @ -246,12 +258,11 @@ cdef class LocationStore: | ||||||
|                         all_locations[sender].add(entry.location) |                         all_locations[sender].add(entry.location) | ||||||
|         return all_locations |         return all_locations | ||||||
| 
 | 
 | ||||||
|     if TYPE_CHECKING: |  | ||||||
|         State = Dict[Tuple[int, int], Set[int]] |  | ||||||
|     else: |  | ||||||
|         State = Union[Tuple[int, int], Set[int], defaultdict] |  | ||||||
| 
 |  | ||||||
|     def get_checked(self, state: State, team: int, slot: int) -> List[int]: |     def get_checked(self, state: State, team: int, slot: int) -> List[int]: | ||||||
|  |         cdef ap_player_t sender = slot | ||||||
|  |         if sender < 0 or sender >= self.sender_index_size: | ||||||
|  |             raise KeyError(slot) | ||||||
|  | 
 | ||||||
|         # This used to validate checks actually exist. A remnant from the past. |         # This used to validate checks actually exist. A remnant from the past. | ||||||
|         # If the order of locations becomes relevant at some point, we could not do sorted(set), so leaving it. |         # If the order of locations becomes relevant at some point, we could not do sorted(set), so leaving it. | ||||||
|         cdef set checked = state[team, slot] |         cdef set checked = state[team, slot] | ||||||
|  | @ -263,7 +274,6 @@ cdef class LocationStore: | ||||||
| 
 | 
 | ||||||
|         # Unless the set is close to empty, it's cheaper to use the python set directly, so we do that. |         # Unless the set is close to empty, it's cheaper to use the python set directly, so we do that. | ||||||
|         cdef LocationEntry* entry |         cdef LocationEntry* entry | ||||||
|         cdef ap_player_t sender = slot |  | ||||||
|         cdef size_t start = self.sender_index[sender].start |         cdef size_t start = self.sender_index[sender].start | ||||||
|         cdef size_t count = self.sender_index[sender].count |         cdef size_t count = self.sender_index[sender].count | ||||||
|         return [entry.location for |         return [entry.location for | ||||||
|  | @ -273,9 +283,11 @@ cdef class LocationStore: | ||||||
|     def get_missing(self, state: State, team: int, slot: int) -> List[int]: |     def get_missing(self, state: State, team: int, slot: int) -> List[int]: | ||||||
|         cdef LocationEntry* entry |         cdef LocationEntry* entry | ||||||
|         cdef ap_player_t sender = slot |         cdef ap_player_t sender = slot | ||||||
|  |         if sender < 0 or sender >= self.sender_index_size: | ||||||
|  |             raise KeyError(slot) | ||||||
|  |         cdef set checked = state[team, slot] | ||||||
|         cdef size_t start = self.sender_index[sender].start |         cdef size_t start = self.sender_index[sender].start | ||||||
|         cdef size_t count = self.sender_index[sender].count |         cdef size_t count = self.sender_index[sender].count | ||||||
|         cdef set checked = state[team, slot] |  | ||||||
|         if not len(checked): |         if not len(checked): | ||||||
|             # Skip `in` if none have been checked. |             # Skip `in` if none have been checked. | ||||||
|             # This optimizes the case where everyone connects to a fresh game at the same time. |             # This optimizes the case where everyone connects to a fresh game at the same time. | ||||||
|  | @ -290,9 +302,11 @@ cdef class LocationStore: | ||||||
|     def get_remaining(self, state: State, team: int, slot: int) -> List[Tuple[int, int]]: |     def get_remaining(self, state: State, team: int, slot: int) -> List[Tuple[int, int]]: | ||||||
|         cdef LocationEntry* entry |         cdef LocationEntry* entry | ||||||
|         cdef ap_player_t sender = slot |         cdef ap_player_t sender = slot | ||||||
|  |         if sender < 0 or sender >= self.sender_index_size: | ||||||
|  |             raise KeyError(slot) | ||||||
|  |         cdef set checked = state[team, slot] | ||||||
|         cdef size_t start = self.sender_index[sender].start |         cdef size_t start = self.sender_index[sender].start | ||||||
|         cdef size_t count = self.sender_index[sender].count |         cdef size_t count = self.sender_index[sender].count | ||||||
|         cdef set checked = state[team, slot] |  | ||||||
|         return sorted([(entry.receiver, entry.item) for |         return sorted([(entry.receiver, entry.item) for | ||||||
|                         entry in self.entries[start:start+count] if |                         entry in self.entries[start:start+count] if | ||||||
|                         entry.location not in checked]) |                         entry.location not in checked]) | ||||||
|  | @ -328,7 +342,8 @@ cdef class PlayerLocationProxy: | ||||||
|         cdef LocationEntry* entry = NULL |         cdef LocationEntry* entry = NULL | ||||||
|         # binary search |         # binary search | ||||||
|         cdef size_t l = self._store.sender_index[self._player].start |         cdef size_t l = self._store.sender_index[self._player].start | ||||||
|         cdef size_t r = l + self._store.sender_index[self._player].count |         cdef size_t e = l + self._store.sender_index[self._player].count | ||||||
|  |         cdef size_t r = e | ||||||
|         cdef size_t m |         cdef size_t m | ||||||
|         while l < r: |         while l < r: | ||||||
|             m = (l + r) // 2 |             m = (l + r) // 2 | ||||||
|  | @ -337,7 +352,7 @@ cdef class PlayerLocationProxy: | ||||||
|                 l = m + 1 |                 l = m + 1 | ||||||
|             else: |             else: | ||||||
|                 r = m |                 r = m | ||||||
|         if entry:  # count != 0 |         if l < e: | ||||||
|             entry = self._store.entries + l |             entry = self._store.entries + l | ||||||
|             if entry.location == loc: |             if entry.location == loc: | ||||||
|                 return entry |                 return entry | ||||||
|  | @ -349,8 +364,6 @@ cdef class PlayerLocationProxy: | ||||||
|             return entry.item, entry.receiver, entry.flags |             return entry.item, entry.receiver, entry.flags | ||||||
|         raise KeyError(f"No location {key} for player {self._player}") |         raise KeyError(f"No location {key} for player {self._player}") | ||||||
| 
 | 
 | ||||||
|     T = TypeVar('T') |  | ||||||
| 
 |  | ||||||
|     def get(self, key: int, default: T) -> Union[Tuple[int, int, int], T]: |     def get(self, key: int, default: T) -> Union[Tuple[int, int, int], T]: | ||||||
|         cdef LocationEntry* entry = self._get(key) |         cdef LocationEntry* entry = self._get(key) | ||||||
|         if entry: |         if entry: | ||||||
|  |  | ||||||
|  | @ -3,8 +3,16 @@ import os | ||||||
| 
 | 
 | ||||||
| def make_ext(modname, pyxfilename): | def make_ext(modname, pyxfilename): | ||||||
|     from distutils.extension import Extension |     from distutils.extension import Extension | ||||||
|     return Extension(name=modname, |     return Extension( | ||||||
|                      sources=[pyxfilename], |         name=modname, | ||||||
|                      depends=["intset.h"], |         sources=[pyxfilename], | ||||||
|                      include_dirs=[os.getcwd()], |         depends=["intset.h"], | ||||||
|                      language="c") |         include_dirs=[os.getcwd()], | ||||||
|  |         language="c", | ||||||
|  |         # to enable ASAN and debug build: | ||||||
|  |         # extra_compile_args=["-fsanitize=address", "-UNDEBUG", "-Og", "-g"], | ||||||
|  |         # extra_objects=["-fsanitize=address"], | ||||||
|  |         # NOTE: we can not put -lasan at the front of link args, so needs to be run with | ||||||
|  |         #       LD_PRELOAD=/usr/lib/libasan.so ASAN_OPTIONS=detect_leaks=0 path/to/exe | ||||||
|  |         # NOTE: this can't find everything unless libpython and cymem are also built with ASAN | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  | @ -115,6 +115,7 @@ class Base: | ||||||
|         def test_get_for_player(self) -> None: |         def test_get_for_player(self) -> None: | ||||||
|             self.assertEqual(self.store.get_for_player(3), {4: {9}}) |             self.assertEqual(self.store.get_for_player(3), {4: {9}}) | ||||||
|             self.assertEqual(self.store.get_for_player(1), {1: {13}, 2: {22, 23}}) |             self.assertEqual(self.store.get_for_player(1), {1: {13}, 2: {22, 23}}) | ||||||
|  |             self.assertEqual(self.store.get_for_player(9999), {}) | ||||||
| 
 | 
 | ||||||
|         def test_get_checked(self) -> None: |         def test_get_checked(self) -> None: | ||||||
|             self.assertEqual(self.store.get_checked(full_state, 0, 1), [11, 12, 13]) |             self.assertEqual(self.store.get_checked(full_state, 0, 1), [11, 12, 13]) | ||||||
|  | @ -122,18 +123,48 @@ class Base: | ||||||
|             self.assertEqual(self.store.get_checked(empty_state, 0, 1), []) |             self.assertEqual(self.store.get_checked(empty_state, 0, 1), []) | ||||||
|             self.assertEqual(self.store.get_checked(full_state, 0, 3), [9]) |             self.assertEqual(self.store.get_checked(full_state, 0, 3), [9]) | ||||||
| 
 | 
 | ||||||
|  |         def test_get_checked_exception(self) -> None: | ||||||
|  |             with self.assertRaises(KeyError): | ||||||
|  |                 self.store.get_checked(empty_state, 0, 9999) | ||||||
|  |             bad_state = {(0, 6): {1}} | ||||||
|  |             with self.assertRaises(KeyError): | ||||||
|  |                 self.store.get_checked(bad_state, 0, 6) | ||||||
|  |             bad_state = {(0, 9999): set()} | ||||||
|  |             with self.assertRaises(KeyError): | ||||||
|  |                 self.store.get_checked(bad_state, 0, 9999) | ||||||
|  | 
 | ||||||
|         def test_get_missing(self) -> None: |         def test_get_missing(self) -> None: | ||||||
|             self.assertEqual(self.store.get_missing(full_state, 0, 1), []) |             self.assertEqual(self.store.get_missing(full_state, 0, 1), []) | ||||||
|             self.assertEqual(self.store.get_missing(one_state, 0, 1), [11, 13]) |             self.assertEqual(self.store.get_missing(one_state, 0, 1), [11, 13]) | ||||||
|             self.assertEqual(self.store.get_missing(empty_state, 0, 1), [11, 12, 13]) |             self.assertEqual(self.store.get_missing(empty_state, 0, 1), [11, 12, 13]) | ||||||
|             self.assertEqual(self.store.get_missing(empty_state, 0, 3), [9]) |             self.assertEqual(self.store.get_missing(empty_state, 0, 3), [9]) | ||||||
| 
 | 
 | ||||||
|  |         def test_get_missing_exception(self) -> None: | ||||||
|  |             with self.assertRaises(KeyError): | ||||||
|  |                 self.store.get_missing(empty_state, 0, 9999) | ||||||
|  |             bad_state = {(0, 6): {1}} | ||||||
|  |             with self.assertRaises(KeyError): | ||||||
|  |                 self.store.get_missing(bad_state, 0, 6) | ||||||
|  |             bad_state = {(0, 9999): set()} | ||||||
|  |             with self.assertRaises(KeyError): | ||||||
|  |                 self.store.get_missing(bad_state, 0, 9999) | ||||||
|  | 
 | ||||||
|         def test_get_remaining(self) -> None: |         def test_get_remaining(self) -> None: | ||||||
|             self.assertEqual(self.store.get_remaining(full_state, 0, 1), []) |             self.assertEqual(self.store.get_remaining(full_state, 0, 1), []) | ||||||
|             self.assertEqual(self.store.get_remaining(one_state, 0, 1), [(1, 13), (2, 21)]) |             self.assertEqual(self.store.get_remaining(one_state, 0, 1), [(1, 13), (2, 21)]) | ||||||
|             self.assertEqual(self.store.get_remaining(empty_state, 0, 1), [(1, 13), (2, 21), (2, 22)]) |             self.assertEqual(self.store.get_remaining(empty_state, 0, 1), [(1, 13), (2, 21), (2, 22)]) | ||||||
|             self.assertEqual(self.store.get_remaining(empty_state, 0, 3), [(4, 99)]) |             self.assertEqual(self.store.get_remaining(empty_state, 0, 3), [(4, 99)]) | ||||||
| 
 | 
 | ||||||
|  |         def test_get_remaining_exception(self) -> None: | ||||||
|  |             with self.assertRaises(KeyError): | ||||||
|  |                 self.store.get_remaining(empty_state, 0, 9999) | ||||||
|  |             bad_state = {(0, 6): {1}} | ||||||
|  |             with self.assertRaises(KeyError): | ||||||
|  |                 self.store.get_missing(bad_state, 0, 6) | ||||||
|  |             bad_state = {(0, 9999): set()} | ||||||
|  |             with self.assertRaises(KeyError): | ||||||
|  |                 self.store.get_remaining(bad_state, 0, 9999) | ||||||
|  | 
 | ||||||
|         def test_location_set_intersection(self) -> None: |         def test_location_set_intersection(self) -> None: | ||||||
|             locations = {10, 11, 12} |             locations = {10, 11, 12} | ||||||
|             locations.intersection_update(self.store[1]) |             locations.intersection_update(self.store[1]) | ||||||
|  | @ -181,6 +212,16 @@ class Base: | ||||||
|                 }) |                 }) | ||||||
|                 self.assertEqual(len(store), 1) |                 self.assertEqual(len(store), 1) | ||||||
|                 self.assertEqual(len(store[1]), 0) |                 self.assertEqual(len(store[1]), 0) | ||||||
|  |                 self.assertEqual(sorted(store.find_item(set(), 1)), []) | ||||||
|  |                 self.assertEqual(sorted(store.find_item({1}, 1)), []) | ||||||
|  |                 self.assertEqual(sorted(store.find_item({1, 2}, 1)), []) | ||||||
|  |                 self.assertEqual(store.get_for_player(1), {}) | ||||||
|  |                 self.assertEqual(store.get_checked(empty_state, 0, 1), []) | ||||||
|  |                 self.assertEqual(store.get_checked(full_state, 0, 1), []) | ||||||
|  |                 self.assertEqual(store.get_missing(empty_state, 0, 1), []) | ||||||
|  |                 self.assertEqual(store.get_missing(full_state, 0, 1), []) | ||||||
|  |                 self.assertEqual(store.get_remaining(empty_state, 0, 1), []) | ||||||
|  |                 self.assertEqual(store.get_remaining(full_state, 0, 1), []) | ||||||
| 
 | 
 | ||||||
|         def test_no_locations_for_1(self) -> None: |         def test_no_locations_for_1(self) -> None: | ||||||
|             store = self.type({ |             store = self.type({ | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue