Core: fix exceptions coming from LocationStore (#4358)
* Speedups: add instructions for ASAN * Speedups: move typevars out of classes * Speedups, NetUtils: raise correct exceptions * Speedups: double-check malloc * Tests: more LocationStore tests
This commit is contained in:
		
							parent
							
								
									f79657b41a
								
							
						
					
					
						commit
						3fb0b57d19
					
				| 
						 | 
					@ -410,6 +410,8 @@ class _LocationStore(dict, typing.MutableMapping[int, typing.Dict[int, typing.Tu
 | 
				
			||||||
        checked = state[team, slot]
 | 
					        checked = state[team, slot]
 | 
				
			||||||
        if not checked:
 | 
					        if not checked:
 | 
				
			||||||
            # This optimizes the case where everyone connects to a fresh game at the same time.
 | 
					            # This optimizes the case where everyone connects to a fresh game at the same time.
 | 
				
			||||||
 | 
					            if slot not in self:
 | 
				
			||||||
 | 
					                raise KeyError(slot)
 | 
				
			||||||
            return []
 | 
					            return []
 | 
				
			||||||
        return [location_id for
 | 
					        return [location_id for
 | 
				
			||||||
                location_id in self[slot] if
 | 
					                location_id in self[slot] if
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -69,6 +69,14 @@ cdef struct IndexEntry:
 | 
				
			||||||
    size_t count
 | 
					    size_t count
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if TYPE_CHECKING:
 | 
				
			||||||
 | 
					    State = Dict[Tuple[int, int], Set[int]]
 | 
				
			||||||
 | 
					else:
 | 
				
			||||||
 | 
					    State = Union[Tuple[int, int], Set[int], defaultdict]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					T = TypeVar('T')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@cython.auto_pickle(False)
 | 
					@cython.auto_pickle(False)
 | 
				
			||||||
cdef class LocationStore:
 | 
					cdef class LocationStore:
 | 
				
			||||||
    """Compact store for locations and their items in a MultiServer"""
 | 
					    """Compact store for locations and their items in a MultiServer"""
 | 
				
			||||||
| 
						 | 
					@ -137,10 +145,16 @@ cdef class LocationStore:
 | 
				
			||||||
            warnings.warn("Game has no locations")
 | 
					            warnings.warn("Game has no locations")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # allocate the arrays and invalidate index (0xff...)
 | 
					        # allocate the arrays and invalidate index (0xff...)
 | 
				
			||||||
        self.entries = <LocationEntry*>self._mem.alloc(count, sizeof(LocationEntry))
 | 
					        if count:
 | 
				
			||||||
 | 
					            # leaving entries as NULL if there are none, makes potential memory errors more visible
 | 
				
			||||||
 | 
					            self.entries = <LocationEntry*>self._mem.alloc(count, sizeof(LocationEntry))
 | 
				
			||||||
        self.sender_index = <IndexEntry*>self._mem.alloc(max_sender + 1, sizeof(IndexEntry))
 | 
					        self.sender_index = <IndexEntry*>self._mem.alloc(max_sender + 1, sizeof(IndexEntry))
 | 
				
			||||||
        self._raw_proxies = <PyObject**>self._mem.alloc(max_sender + 1, sizeof(PyObject*))
 | 
					        self._raw_proxies = <PyObject**>self._mem.alloc(max_sender + 1, sizeof(PyObject*))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert (not self.entries) == (not count)
 | 
				
			||||||
 | 
					        assert self.sender_index
 | 
				
			||||||
 | 
					        assert self._raw_proxies
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # build entries and index
 | 
					        # build entries and index
 | 
				
			||||||
        cdef size_t i = 0
 | 
					        cdef size_t i = 0
 | 
				
			||||||
        for sender, locations in sorted(locations_dict.items()):
 | 
					        for sender, locations in sorted(locations_dict.items()):
 | 
				
			||||||
| 
						 | 
					@ -190,8 +204,6 @@ cdef class LocationStore:
 | 
				
			||||||
            raise KeyError(key)
 | 
					            raise KeyError(key)
 | 
				
			||||||
        return <object>self._raw_proxies[key]
 | 
					        return <object>self._raw_proxies[key]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    T = TypeVar('T')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get(self, key: int, default: T) -> Union[PlayerLocationProxy, T]:
 | 
					    def get(self, key: int, default: T) -> Union[PlayerLocationProxy, T]:
 | 
				
			||||||
        # calling into self.__getitem__ here is slow, but this is not used in MultiServer
 | 
					        # calling into self.__getitem__ here is slow, but this is not used in MultiServer
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
| 
						 | 
					@ -246,12 +258,11 @@ cdef class LocationStore:
 | 
				
			||||||
                        all_locations[sender].add(entry.location)
 | 
					                        all_locations[sender].add(entry.location)
 | 
				
			||||||
        return all_locations
 | 
					        return all_locations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if TYPE_CHECKING:
 | 
					 | 
				
			||||||
        State = Dict[Tuple[int, int], Set[int]]
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        State = Union[Tuple[int, int], Set[int], defaultdict]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_checked(self, state: State, team: int, slot: int) -> List[int]:
 | 
					    def get_checked(self, state: State, team: int, slot: int) -> List[int]:
 | 
				
			||||||
 | 
					        cdef ap_player_t sender = slot
 | 
				
			||||||
 | 
					        if sender < 0 or sender >= self.sender_index_size:
 | 
				
			||||||
 | 
					            raise KeyError(slot)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # This used to validate checks actually exist. A remnant from the past.
 | 
					        # This used to validate checks actually exist. A remnant from the past.
 | 
				
			||||||
        # If the order of locations becomes relevant at some point, we could not do sorted(set), so leaving it.
 | 
					        # If the order of locations becomes relevant at some point, we could not do sorted(set), so leaving it.
 | 
				
			||||||
        cdef set checked = state[team, slot]
 | 
					        cdef set checked = state[team, slot]
 | 
				
			||||||
| 
						 | 
					@ -263,7 +274,6 @@ cdef class LocationStore:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Unless the set is close to empty, it's cheaper to use the python set directly, so we do that.
 | 
					        # Unless the set is close to empty, it's cheaper to use the python set directly, so we do that.
 | 
				
			||||||
        cdef LocationEntry* entry
 | 
					        cdef LocationEntry* entry
 | 
				
			||||||
        cdef ap_player_t sender = slot
 | 
					 | 
				
			||||||
        cdef size_t start = self.sender_index[sender].start
 | 
					        cdef size_t start = self.sender_index[sender].start
 | 
				
			||||||
        cdef size_t count = self.sender_index[sender].count
 | 
					        cdef size_t count = self.sender_index[sender].count
 | 
				
			||||||
        return [entry.location for
 | 
					        return [entry.location for
 | 
				
			||||||
| 
						 | 
					@ -273,9 +283,11 @@ cdef class LocationStore:
 | 
				
			||||||
    def get_missing(self, state: State, team: int, slot: int) -> List[int]:
 | 
					    def get_missing(self, state: State, team: int, slot: int) -> List[int]:
 | 
				
			||||||
        cdef LocationEntry* entry
 | 
					        cdef LocationEntry* entry
 | 
				
			||||||
        cdef ap_player_t sender = slot
 | 
					        cdef ap_player_t sender = slot
 | 
				
			||||||
 | 
					        if sender < 0 or sender >= self.sender_index_size:
 | 
				
			||||||
 | 
					            raise KeyError(slot)
 | 
				
			||||||
 | 
					        cdef set checked = state[team, slot]
 | 
				
			||||||
        cdef size_t start = self.sender_index[sender].start
 | 
					        cdef size_t start = self.sender_index[sender].start
 | 
				
			||||||
        cdef size_t count = self.sender_index[sender].count
 | 
					        cdef size_t count = self.sender_index[sender].count
 | 
				
			||||||
        cdef set checked = state[team, slot]
 | 
					 | 
				
			||||||
        if not len(checked):
 | 
					        if not len(checked):
 | 
				
			||||||
            # Skip `in` if none have been checked.
 | 
					            # Skip `in` if none have been checked.
 | 
				
			||||||
            # This optimizes the case where everyone connects to a fresh game at the same time.
 | 
					            # This optimizes the case where everyone connects to a fresh game at the same time.
 | 
				
			||||||
| 
						 | 
					@ -290,9 +302,11 @@ cdef class LocationStore:
 | 
				
			||||||
    def get_remaining(self, state: State, team: int, slot: int) -> List[Tuple[int, int]]:
 | 
					    def get_remaining(self, state: State, team: int, slot: int) -> List[Tuple[int, int]]:
 | 
				
			||||||
        cdef LocationEntry* entry
 | 
					        cdef LocationEntry* entry
 | 
				
			||||||
        cdef ap_player_t sender = slot
 | 
					        cdef ap_player_t sender = slot
 | 
				
			||||||
 | 
					        if sender < 0 or sender >= self.sender_index_size:
 | 
				
			||||||
 | 
					            raise KeyError(slot)
 | 
				
			||||||
 | 
					        cdef set checked = state[team, slot]
 | 
				
			||||||
        cdef size_t start = self.sender_index[sender].start
 | 
					        cdef size_t start = self.sender_index[sender].start
 | 
				
			||||||
        cdef size_t count = self.sender_index[sender].count
 | 
					        cdef size_t count = self.sender_index[sender].count
 | 
				
			||||||
        cdef set checked = state[team, slot]
 | 
					 | 
				
			||||||
        return sorted([(entry.receiver, entry.item) for
 | 
					        return sorted([(entry.receiver, entry.item) for
 | 
				
			||||||
                        entry in self.entries[start:start+count] if
 | 
					                        entry in self.entries[start:start+count] if
 | 
				
			||||||
                        entry.location not in checked])
 | 
					                        entry.location not in checked])
 | 
				
			||||||
| 
						 | 
					@ -328,7 +342,8 @@ cdef class PlayerLocationProxy:
 | 
				
			||||||
        cdef LocationEntry* entry = NULL
 | 
					        cdef LocationEntry* entry = NULL
 | 
				
			||||||
        # binary search
 | 
					        # binary search
 | 
				
			||||||
        cdef size_t l = self._store.sender_index[self._player].start
 | 
					        cdef size_t l = self._store.sender_index[self._player].start
 | 
				
			||||||
        cdef size_t r = l + self._store.sender_index[self._player].count
 | 
					        cdef size_t e = l + self._store.sender_index[self._player].count
 | 
				
			||||||
 | 
					        cdef size_t r = e
 | 
				
			||||||
        cdef size_t m
 | 
					        cdef size_t m
 | 
				
			||||||
        while l < r:
 | 
					        while l < r:
 | 
				
			||||||
            m = (l + r) // 2
 | 
					            m = (l + r) // 2
 | 
				
			||||||
| 
						 | 
					@ -337,7 +352,7 @@ cdef class PlayerLocationProxy:
 | 
				
			||||||
                l = m + 1
 | 
					                l = m + 1
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                r = m
 | 
					                r = m
 | 
				
			||||||
        if entry:  # count != 0
 | 
					        if l < e:
 | 
				
			||||||
            entry = self._store.entries + l
 | 
					            entry = self._store.entries + l
 | 
				
			||||||
            if entry.location == loc:
 | 
					            if entry.location == loc:
 | 
				
			||||||
                return entry
 | 
					                return entry
 | 
				
			||||||
| 
						 | 
					@ -349,8 +364,6 @@ cdef class PlayerLocationProxy:
 | 
				
			||||||
            return entry.item, entry.receiver, entry.flags
 | 
					            return entry.item, entry.receiver, entry.flags
 | 
				
			||||||
        raise KeyError(f"No location {key} for player {self._player}")
 | 
					        raise KeyError(f"No location {key} for player {self._player}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    T = TypeVar('T')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get(self, key: int, default: T) -> Union[Tuple[int, int, int], T]:
 | 
					    def get(self, key: int, default: T) -> Union[Tuple[int, int, int], T]:
 | 
				
			||||||
        cdef LocationEntry* entry = self._get(key)
 | 
					        cdef LocationEntry* entry = self._get(key)
 | 
				
			||||||
        if entry:
 | 
					        if entry:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,8 +3,16 @@ import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def make_ext(modname, pyxfilename):
 | 
					def make_ext(modname, pyxfilename):
 | 
				
			||||||
    from distutils.extension import Extension
 | 
					    from distutils.extension import Extension
 | 
				
			||||||
    return Extension(name=modname,
 | 
					    return Extension(
 | 
				
			||||||
                     sources=[pyxfilename],
 | 
					        name=modname,
 | 
				
			||||||
                     depends=["intset.h"],
 | 
					        sources=[pyxfilename],
 | 
				
			||||||
                     include_dirs=[os.getcwd()],
 | 
					        depends=["intset.h"],
 | 
				
			||||||
                     language="c")
 | 
					        include_dirs=[os.getcwd()],
 | 
				
			||||||
 | 
					        language="c",
 | 
				
			||||||
 | 
					        # to enable ASAN and debug build:
 | 
				
			||||||
 | 
					        # extra_compile_args=["-fsanitize=address", "-UNDEBUG", "-Og", "-g"],
 | 
				
			||||||
 | 
					        # extra_objects=["-fsanitize=address"],
 | 
				
			||||||
 | 
					        # NOTE: we can not put -lasan at the front of link args, so needs to be run with
 | 
				
			||||||
 | 
					        #       LD_PRELOAD=/usr/lib/libasan.so ASAN_OPTIONS=detect_leaks=0 path/to/exe
 | 
				
			||||||
 | 
					        # NOTE: this can't find everything unless libpython and cymem are also built with ASAN
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -115,6 +115,7 @@ class Base:
 | 
				
			||||||
        def test_get_for_player(self) -> None:
 | 
					        def test_get_for_player(self) -> None:
 | 
				
			||||||
            self.assertEqual(self.store.get_for_player(3), {4: {9}})
 | 
					            self.assertEqual(self.store.get_for_player(3), {4: {9}})
 | 
				
			||||||
            self.assertEqual(self.store.get_for_player(1), {1: {13}, 2: {22, 23}})
 | 
					            self.assertEqual(self.store.get_for_player(1), {1: {13}, 2: {22, 23}})
 | 
				
			||||||
 | 
					            self.assertEqual(self.store.get_for_player(9999), {})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def test_get_checked(self) -> None:
 | 
					        def test_get_checked(self) -> None:
 | 
				
			||||||
            self.assertEqual(self.store.get_checked(full_state, 0, 1), [11, 12, 13])
 | 
					            self.assertEqual(self.store.get_checked(full_state, 0, 1), [11, 12, 13])
 | 
				
			||||||
| 
						 | 
					@ -122,18 +123,48 @@ class Base:
 | 
				
			||||||
            self.assertEqual(self.store.get_checked(empty_state, 0, 1), [])
 | 
					            self.assertEqual(self.store.get_checked(empty_state, 0, 1), [])
 | 
				
			||||||
            self.assertEqual(self.store.get_checked(full_state, 0, 3), [9])
 | 
					            self.assertEqual(self.store.get_checked(full_state, 0, 3), [9])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def test_get_checked_exception(self) -> None:
 | 
				
			||||||
 | 
					            with self.assertRaises(KeyError):
 | 
				
			||||||
 | 
					                self.store.get_checked(empty_state, 0, 9999)
 | 
				
			||||||
 | 
					            bad_state = {(0, 6): {1}}
 | 
				
			||||||
 | 
					            with self.assertRaises(KeyError):
 | 
				
			||||||
 | 
					                self.store.get_checked(bad_state, 0, 6)
 | 
				
			||||||
 | 
					            bad_state = {(0, 9999): set()}
 | 
				
			||||||
 | 
					            with self.assertRaises(KeyError):
 | 
				
			||||||
 | 
					                self.store.get_checked(bad_state, 0, 9999)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def test_get_missing(self) -> None:
 | 
					        def test_get_missing(self) -> None:
 | 
				
			||||||
            self.assertEqual(self.store.get_missing(full_state, 0, 1), [])
 | 
					            self.assertEqual(self.store.get_missing(full_state, 0, 1), [])
 | 
				
			||||||
            self.assertEqual(self.store.get_missing(one_state, 0, 1), [11, 13])
 | 
					            self.assertEqual(self.store.get_missing(one_state, 0, 1), [11, 13])
 | 
				
			||||||
            self.assertEqual(self.store.get_missing(empty_state, 0, 1), [11, 12, 13])
 | 
					            self.assertEqual(self.store.get_missing(empty_state, 0, 1), [11, 12, 13])
 | 
				
			||||||
            self.assertEqual(self.store.get_missing(empty_state, 0, 3), [9])
 | 
					            self.assertEqual(self.store.get_missing(empty_state, 0, 3), [9])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def test_get_missing_exception(self) -> None:
 | 
				
			||||||
 | 
					            with self.assertRaises(KeyError):
 | 
				
			||||||
 | 
					                self.store.get_missing(empty_state, 0, 9999)
 | 
				
			||||||
 | 
					            bad_state = {(0, 6): {1}}
 | 
				
			||||||
 | 
					            with self.assertRaises(KeyError):
 | 
				
			||||||
 | 
					                self.store.get_missing(bad_state, 0, 6)
 | 
				
			||||||
 | 
					            bad_state = {(0, 9999): set()}
 | 
				
			||||||
 | 
					            with self.assertRaises(KeyError):
 | 
				
			||||||
 | 
					                self.store.get_missing(bad_state, 0, 9999)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def test_get_remaining(self) -> None:
 | 
					        def test_get_remaining(self) -> None:
 | 
				
			||||||
            self.assertEqual(self.store.get_remaining(full_state, 0, 1), [])
 | 
					            self.assertEqual(self.store.get_remaining(full_state, 0, 1), [])
 | 
				
			||||||
            self.assertEqual(self.store.get_remaining(one_state, 0, 1), [(1, 13), (2, 21)])
 | 
					            self.assertEqual(self.store.get_remaining(one_state, 0, 1), [(1, 13), (2, 21)])
 | 
				
			||||||
            self.assertEqual(self.store.get_remaining(empty_state, 0, 1), [(1, 13), (2, 21), (2, 22)])
 | 
					            self.assertEqual(self.store.get_remaining(empty_state, 0, 1), [(1, 13), (2, 21), (2, 22)])
 | 
				
			||||||
            self.assertEqual(self.store.get_remaining(empty_state, 0, 3), [(4, 99)])
 | 
					            self.assertEqual(self.store.get_remaining(empty_state, 0, 3), [(4, 99)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def test_get_remaining_exception(self) -> None:
 | 
				
			||||||
 | 
					            with self.assertRaises(KeyError):
 | 
				
			||||||
 | 
					                self.store.get_remaining(empty_state, 0, 9999)
 | 
				
			||||||
 | 
					            bad_state = {(0, 6): {1}}
 | 
				
			||||||
 | 
					            with self.assertRaises(KeyError):
 | 
				
			||||||
 | 
					                self.store.get_missing(bad_state, 0, 6)
 | 
				
			||||||
 | 
					            bad_state = {(0, 9999): set()}
 | 
				
			||||||
 | 
					            with self.assertRaises(KeyError):
 | 
				
			||||||
 | 
					                self.store.get_remaining(bad_state, 0, 9999)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def test_location_set_intersection(self) -> None:
 | 
					        def test_location_set_intersection(self) -> None:
 | 
				
			||||||
            locations = {10, 11, 12}
 | 
					            locations = {10, 11, 12}
 | 
				
			||||||
            locations.intersection_update(self.store[1])
 | 
					            locations.intersection_update(self.store[1])
 | 
				
			||||||
| 
						 | 
					@ -181,6 +212,16 @@ class Base:
 | 
				
			||||||
                })
 | 
					                })
 | 
				
			||||||
                self.assertEqual(len(store), 1)
 | 
					                self.assertEqual(len(store), 1)
 | 
				
			||||||
                self.assertEqual(len(store[1]), 0)
 | 
					                self.assertEqual(len(store[1]), 0)
 | 
				
			||||||
 | 
					                self.assertEqual(sorted(store.find_item(set(), 1)), [])
 | 
				
			||||||
 | 
					                self.assertEqual(sorted(store.find_item({1}, 1)), [])
 | 
				
			||||||
 | 
					                self.assertEqual(sorted(store.find_item({1, 2}, 1)), [])
 | 
				
			||||||
 | 
					                self.assertEqual(store.get_for_player(1), {})
 | 
				
			||||||
 | 
					                self.assertEqual(store.get_checked(empty_state, 0, 1), [])
 | 
				
			||||||
 | 
					                self.assertEqual(store.get_checked(full_state, 0, 1), [])
 | 
				
			||||||
 | 
					                self.assertEqual(store.get_missing(empty_state, 0, 1), [])
 | 
				
			||||||
 | 
					                self.assertEqual(store.get_missing(full_state, 0, 1), [])
 | 
				
			||||||
 | 
					                self.assertEqual(store.get_remaining(empty_state, 0, 1), [])
 | 
				
			||||||
 | 
					                self.assertEqual(store.get_remaining(full_state, 0, 1), [])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def test_no_locations_for_1(self) -> None:
 | 
					        def test_no_locations_for_1(self) -> None:
 | 
				
			||||||
            store = self.type({
 | 
					            store = self.type({
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue