MultiServer: speed up location commands (#1926)

* MultiServer: speed up location commands

Adds optimized pure python wrapper around locations dict
Adds optimized cython implementation of the wrapper, saving cpu time and 80% memory use

* Speedups: auto-build on import and build during setup

* Speedups: add requirements

* CI: don't break with build_ext

* Speedups: use C++ compiler for pyximport

* Speedups: cleanup and more validation

* Speedups: add tests for LocationStore

* Setup: delete temp in-place build modules

* Speedups: more tests and safer indices

The change has no security implications, but ensures that entries[IndexEntry.start] is always valid.

* Speedups: add cython3 compatibility

* Speedups: remove unused import

* Speedups: reformat

* Speedup: fix empty set in test

* Speedups: use regular dict in Locations.get_for_player

* CI: run unittests with beta cython

now with 2x nicer names
This commit is contained in:
black-sliver 2023-07-04 19:12:43 +02:00 committed by GitHub
parent d35d3b629e
commit b6e78bd1a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 675 additions and 34 deletions

View File

@ -38,12 +38,13 @@ jobs:
run: |
python -m pip install --upgrade pip
python setup.py build_exe --yes
$NAME="$(ls build)".Split('.',2)[1]
$NAME="$(ls build | Select-String -Pattern 'exe')".Split('.',2)[1]
$ZIP_NAME="Archipelago_$NAME.7z"
echo "$NAME -> $ZIP_NAME"
echo "ZIP_NAME=$ZIP_NAME" >> $Env:GITHUB_ENV
New-Item -Path dist -ItemType Directory -Force
cd build
Rename-Item exe.$NAME Archipelago
Rename-Item "exe.$NAME" Archipelago
7z a -mx=9 -mhe=on -ms "../dist/$ZIP_NAME" Archipelago
- name: Store 7z
uses: actions/upload-artifact@v3

View File

@ -26,12 +26,14 @@ on:
jobs:
build:
runs-on: ${{ matrix.os }}
name: Test Python ${{ matrix.python.version }} ${{ matrix.os }}
name: Test Python ${{ matrix.python.version }} ${{ matrix.os }} ${{ matrix.cython }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
cython:
- '' # default
python:
- {version: '3.8'}
- {version: '3.9'}
@ -43,6 +45,9 @@ jobs:
os: windows-latest
- python: {version: '3.10'} # current
os: macos-latest
- python: {version: '3.10'} # current
os: ubuntu-latest
cython: beta
steps:
- uses: actions/checkout@v3
@ -50,6 +55,11 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python.version }}
- name: Install cython beta
if: ${{ matrix.cython == 'beta' }}
run: |
python -m pip install --upgrade pip
python -m pip install --pre --upgrade cython
- name: Install dependencies
run: |
python -m pip install --upgrade pip

4
.gitignore vendored
View File

@ -168,6 +168,10 @@ dmypy.json
# Cython debug symbols
cython_debug/
# Cython intermediates
_speedups.cpp
_speedups.html
# minecraft server stuff
jdk*/
minecraft*/

View File

@ -38,7 +38,7 @@ import NetUtils
import Utils
from Utils import version_tuple, restricted_loads, Version, async_start
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \
SlotType
SlotType, LocationStore
min_client_version = Version(0, 1, 6)
colorama.init()
@ -152,7 +152,9 @@ class Context:
"compatibility": int}
# team -> slot id -> list of clients authenticated to slot.
clients: typing.Dict[int, typing.Dict[int, typing.List[Client]]]
locations: typing.Dict[int, typing.Dict[int, typing.Tuple[int, int, int]]]
locations: LocationStore # typing.Dict[int, typing.Dict[int, typing.Tuple[int, int, int]]]
location_checks: typing.Dict[typing.Tuple[int, int], typing.Set[int]]
hints_used: typing.Dict[typing.Tuple[int, int], int]
groups: typing.Dict[int, typing.Set[int]]
save_version = 2
stored_data: typing.Dict[str, object]
@ -187,8 +189,6 @@ class Context:
self.player_name_lookup: typing.Dict[str, team_slot] = {}
self.connect_names = {} # names of slots clients can connect to
self.allow_releases = {}
# player location_id item_id target_player_id
self.locations = {}
self.host = host
self.port = port
self.server_password = server_password
@ -284,6 +284,7 @@ class Context:
except websockets.ConnectionClosed:
logging.exception(f"Exception during send_msgs, could not send {msg}")
await self.disconnect(endpoint)
return False
else:
if self.log_network:
logging.info(f"Outgoing message: {msg}")
@ -297,6 +298,7 @@ class Context:
except websockets.ConnectionClosed:
logging.exception("Exception during send_encoded_msgs")
await self.disconnect(endpoint)
return False
else:
if self.log_network:
logging.info(f"Outgoing message: {msg}")
@ -311,6 +313,7 @@ class Context:
websockets.broadcast(sockets, msg)
except RuntimeError:
logging.exception("Exception during broadcast_send_encoded_msgs")
return False
else:
if self.log_network:
logging.info(f"Outgoing broadcast: {msg}")
@ -413,7 +416,7 @@ class Context:
self.seed_name = decoded_obj["seed_name"]
self.random.seed(self.seed_name)
self.connect_names = decoded_obj['connect_names']
self.locations = decoded_obj['locations']
self.locations = LocationStore(decoded_obj.pop("locations")) # pre-emptively free memory
self.slot_data = decoded_obj['slot_data']
for slot, data in self.slot_data.items():
self.read_data[f"slot_data_{slot}"] = lambda data=data: data
@ -902,11 +905,7 @@ def release_player(ctx: Context, team: int, slot: int):
def collect_player(ctx: Context, team: int, slot: int, is_group: bool = False):
"""register any locations that are in the multidata, pointing towards this player"""
all_locations = collections.defaultdict(set)
for source_slot, location_data in ctx.locations.items():
for location_id, values in location_data.items():
if values[1] == slot:
all_locations[source_slot].add(location_id)
all_locations = ctx.locations.get_for_player(slot)
ctx.broadcast_text_all("%s (Team #%d) has collected their items from other worlds."
% (ctx.player_names[(team, slot)], team + 1),
@ -925,11 +924,7 @@ def collect_player(ctx: Context, team: int, slot: int, is_group: bool = False):
def get_remaining(ctx: Context, team: int, slot: int) -> typing.List[int]:
items = []
for location_id in ctx.locations[slot]:
if location_id not in ctx.location_checks[team, slot]:
items.append(ctx.locations[slot][location_id][0]) # item ID
return sorted(items)
return ctx.locations.get_remaining(ctx.location_checks, team, slot)
def send_items_to(ctx: Context, team: int, target_slot: int, *items: NetworkItem):
@ -977,13 +972,12 @@ def collect_hints(ctx: Context, team: int, slot: int, item: typing.Union[int, st
slots.add(group_id)
seeked_item_id = item if isinstance(item, int) else ctx.item_names_for_game(ctx.games[slot])[item]
for finding_player, check_data in ctx.locations.items():
for location_id, (item_id, receiving_player, item_flags) in check_data.items():
if receiving_player in slots and item_id == seeked_item_id:
found = location_id in ctx.location_checks[team, finding_player]
entrance = ctx.er_hint_data.get(finding_player, {}).get(location_id, "")
hints.append(NetUtils.Hint(receiving_player, finding_player, location_id, item_id, found, entrance,
item_flags))
for finding_player, location_id, item_id, receiving_player, item_flags \
in ctx.locations.find_item(slots, seeked_item_id):
found = location_id in ctx.location_checks[team, finding_player]
entrance = ctx.er_hint_data.get(finding_player, {}).get(location_id, "")
hints.append(NetUtils.Hint(receiving_player, finding_player, location_id, item_id, found, entrance,
item_flags))
return hints
@ -1555,15 +1549,11 @@ class ClientMessageProcessor(CommonCommandProcessor):
def get_checked_checks(ctx: Context, team: int, slot: int) -> typing.List[int]:
return [location_id for
location_id in ctx.locations[slot] if
location_id in ctx.location_checks[team, slot]]
return ctx.locations.get_checked(ctx.location_checks, team, slot)
def get_missing_checks(ctx: Context, team: int, slot: int) -> typing.List[int]:
return [location_id for
location_id in ctx.locations[slot] if
location_id not in ctx.location_checks[team, slot]]
return ctx.locations.get_missing(ctx.location_checks, team, slot)
def get_client_points(ctx: Context, client: Client) -> int:

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import typing
import enum
import warnings
from json import JSONEncoder, JSONDecoder
import websockets
@ -343,3 +344,65 @@ class Hint(typing.NamedTuple):
@property
def local(self):
return self.receiving_player == self.finding_player
class _LocationStore(dict, typing.MutableMapping[int, typing.Dict[int, typing.Tuple[int, int, int]]]):
def find_item(self, slots: typing.Set[int], seeked_item_id: int
) -> typing.Generator[typing.Tuple[int, int, int, int, int], None, None]:
for finding_player, check_data in self.items():
for location_id, (item_id, receiving_player, item_flags) in check_data.items():
if receiving_player in slots and item_id == seeked_item_id:
yield finding_player, location_id, item_id, receiving_player, item_flags
def get_for_player(self, slot: int) -> typing.Dict[int, typing.Set[int]]:
import collections
all_locations: typing.Dict[int, typing.Set[int]] = collections.defaultdict(set)
for source_slot, location_data in self.items():
for location_id, values in location_data.items():
if values[1] == slot:
all_locations[source_slot].add(location_id)
return all_locations
def get_checked(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int
) -> typing.List[int]:
checked = state[team, slot]
if not checked:
# This optimizes the case where everyone connects to a fresh game at the same time.
return []
return [location_id for
location_id in self[slot] if
location_id in checked]
def get_missing(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int
) -> typing.List[int]:
checked = state[team, slot]
if not checked:
# This optimizes the case where everyone connects to a fresh game at the same time.
return list(self)
return [location_id for
location_id in self[slot] if
location_id not in checked]
def get_remaining(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int
) -> typing.List[int]:
checked = state[team, slot]
player_locations = self[slot]
return sorted([player_locations[location_id][0] for
location_id in player_locations if
location_id not in checked])
if typing.TYPE_CHECKING: # type-check with pure python implementation until we have a typing stub
LocationStore = _LocationStore
else:
try:
import pyximport
pyximport.install()
except ImportError:
pyximport = None
try:
from _speedups import LocationStore
except ImportError:
warnings.warn("_speedups not available. Falling back to pure python LocationStore. "
"Install a matching C++ compiler for your platform to compile _speedups.")
LocationStore = _LocationStore

335
_speedups.pyx Normal file
View File

@ -0,0 +1,335 @@
#cython: language_level=3
#distutils: language = c++
"""
Provides faster implementation of some core parts.
This is deliberately .pyx because using a non-compiled "pure python" may be slower.
"""
# pip install cython cymem
import cython
from cpython cimport PyObject
from typing import Any, Dict, Iterable, Iterator, Generator, Sequence, Tuple, TypeVar, Union, Set, List, TYPE_CHECKING
from cymem.cymem cimport Pool
from libc.stdint cimport int64_t, uint32_t
from libcpp.set cimport set as std_set
from collections import defaultdict
ctypedef uint32_t ap_player_t # on AMD64 this is faster (and smaller) than 64bit ints
ctypedef uint32_t ap_flags_t
ctypedef int64_t ap_id_t
cdef ap_player_t MAX_PLAYER_ID = 1000000 # limit the size of indexing array
cdef size_t INVALID_SIZE = <size_t>(-1) # this is all 0xff... adding 1 results in 0, but it's not negative
cdef struct LocationEntry:
# layout is so that
# 64bit player: location+sender and item+receiver 128bit comparisons, if supported
# 32bit player: aligned to 32/64bit with no unused space
ap_id_t location
ap_player_t sender
ap_player_t receiver
ap_id_t item
ap_flags_t flags
cdef struct IndexEntry:
size_t start
size_t count
cdef class LocationStore:
"""Compact store for locations and their items in a MultiServer"""
# The original implementation uses Dict[int, Dict[int, Tuple(int, int, int]]
# with sender, location, (item, receiver, flags).
# This implementation is a flat list of (sender, location, item, receiver, flags) using native integers
# as well as some mapping arrays used to speed up stuff, saving a lot of memory while speeding up hints.
# Using std::map might be worth investigating, but memory overhead would be ~100% compared to arrays.
cdef Pool _mem
cdef object _len
cdef LocationEntry* entries # 3.2MB/100k items
cdef size_t entry_count
cdef IndexEntry* sender_index # 16KB/1000 players
cdef size_t sender_index_size
cdef list _keys # ~36KB/1000 players, speed up iter (28 per int + 8 per list entry)
cdef list _items # ~64KB/1000 players, speed up items (56 per tuple + 8 per list entry)
cdef list _proxies # ~92KB/1000 players, speed up self[player] (56 per struct + 28 per len + 8 per list entry)
cdef PyObject** _raw_proxies # 8K/1000 players, faster access to _proxies, but does not keep a ref
def get_size(self):
from sys import getsizeof
size = getsizeof(self) + getsizeof(self._mem) + getsizeof(self._len) \
+ sizeof(LocationEntry) * self.entry_count + sizeof(IndexEntry) * self.sender_index_size
size += getsizeof(self._keys) + getsizeof(self._items) + getsizeof(self._proxies)
size += sum(sizeof(key) for key in self._keys)
size += sum(sizeof(item) for item in self._items)
size += sum(sizeof(proxy) for proxy in self._proxies)
size += sizeof(self._raw_proxies[0]) * self.sender_index_size
return size
def __cinit__(self, locations_dict: Dict[int, Dict[int, Sequence[int]]]) -> None:
self._mem = None
self._keys = None
self._items = None
self._proxies = None
self._len = 0
self.entries = NULL
self.entry_count = 0
self.sender_index = NULL
self.sender_index_size = 0
self._raw_proxies = NULL
def __init__(self, locations_dict: Dict[int, Dict[int, Sequence[int]]]) -> None:
self._mem = Pool()
cdef object key
self._keys = []
self._items = []
self._proxies = []
# iterate over everything to get all maxima and validate everything
cdef size_t max_sender = INVALID_SIZE # keep track of highest used player id for indexing
cdef size_t sender_count = 0
cdef size_t count = 0
for sender, locations in locations_dict.items():
# we don't require the dict to be sorted here
if not isinstance(sender, int) or sender < 1 or sender > MAX_PLAYER_ID:
raise ValueError(f"Invalid player id {sender} for location")
if max_sender == INVALID_SIZE:
max_sender = sender
else:
max_sender = max(max_sender, sender)
for location, data in locations.items():
receiver = data[1]
if receiver < 1 or receiver > MAX_PLAYER_ID:
raise ValueError(f"Invalid player id {receiver} for item")
count += 1
sender_count += 1
if not count:
raise ValueError("No locations")
if sender_count != max_sender:
# we assume player 0 will never have locations
raise ValueError("Player IDs not continuous")
# allocate the arrays and invalidate index (0xff...)
self.entries = <LocationEntry*>self._mem.alloc(count, sizeof(LocationEntry))
self.sender_index = <IndexEntry*>self._mem.alloc(max_sender + 1, sizeof(IndexEntry))
self._raw_proxies = <PyObject**>self._mem.alloc(max_sender + 1, sizeof(PyObject*))
# build entries and index
cdef size_t i = 0
for sender, locations in sorted(locations_dict.items()):
self.sender_index[sender].start = i
self.sender_index[sender].count = 0
# Sorting locations here makes it possible to write a faster lookup without an additional index.
for location, data in sorted(locations.items()):
self.entries[i].sender = sender
self.entries[i].location = location
self.entries[i].item = data[0]
self.entries[i].receiver = data[1]
if len(data) > 2:
self.entries[i].flags = data[2] # initialized to 0 during alloc
# Ignoring extra data. warn?
self.sender_index[sender].count += 1
i += 1
# build pyobject caches
self._proxies.append(None) # player 0
assert self.sender_index[0].count == 0
for i in range(1, max_sender + 1):
if self.sender_index[i].count == 0 and self.sender_index[i].start >= count:
self.sender_index[i].start = 0 # do not point outside valid entries
assert self.sender_index[i].start < count
key = i # allocate python integer
proxy = PlayerLocationProxy(self, i)
self._keys.append(key)
self._items.append((key, proxy))
self._proxies.append(proxy)
self._raw_proxies[i] = <PyObject*>proxy
self.sender_index_size = max_sender + 1
self.entry_count = count
self._len = sender_count
# fake dict access
def __len__(self) -> int:
return self._len
def __iter__(self) -> Iterator[int]:
return self._keys.__iter__()
def __getitem__(self, key: int) -> Any:
# figure out if player actually exists in the multidata and return a proxy
cdef size_t i = key # NOTE: this may raise TypeError
if i < 1 or i >= self.sender_index_size:
raise KeyError(key)
return <object>self._raw_proxies[key]
T = TypeVar('T')
def get(self, key: int, default: T) -> Union[PlayerLocationProxy, T]:
# calling into self.__getitem__ here is slow, but this is not used in MultiServer
try:
return self[key]
except KeyError:
return default
def items(self) -> Iterable[Tuple[int, PlayerLocationProxy]]:
return self._items
# specialized accessors
def find_item(self, slots: Set[int], seeked_item_id: int) -> Generator[Tuple[int, int, int, int, int], None, None]:
cdef ap_id_t item = seeked_item_id
cdef ap_player_t receiver
cdef std_set[ap_player_t] receivers
cdef size_t slot_count = len(slots)
if slot_count == 1:
# specialized implementation for single slot
receiver = list(slots)[0]
with nogil:
for entry in self.entries[:self.entry_count]:
if entry.item == item and entry.receiver == receiver:
with gil:
yield entry.sender, entry.location, entry.item, entry.receiver, entry.flags
elif slot_count:
# generic implementation with lookup in set
for receiver in slots:
receivers.insert(receiver)
with nogil:
for entry in self.entries[:self.entry_count]:
if entry.item == item and receivers.count(entry.receiver):
with gil:
yield entry.sender, entry.location, entry.item, entry.receiver, entry.flags
def get_for_player(self, slot: int) -> Dict[int, Set[int]]:
cdef ap_player_t receiver = slot
all_locations: Dict[int, Set[int]] = {}
with nogil:
for entry in self.entries[:self.entry_count]:
if entry.receiver == receiver:
with gil:
sender: int = entry.sender
if sender not in all_locations:
all_locations[sender] = set()
all_locations[sender].add(entry.location)
return all_locations
if TYPE_CHECKING:
State = Dict[Tuple[int, int], Set[int]]
else:
State = Union[Tuple[int, int], Set[int], defaultdict]
def get_checked(self, state: State, team: int, slot: int) -> List[int]:
# This used to validate checks actually exist. A remnant from the past.
# If the order of locations becomes relevant at some point, we could not do sorted(set), so leaving it.
cdef set checked = state[team, slot]
if not len(checked):
# Skips loop if none have been checked.
# This optimizes the case where everyone connects to a fresh game at the same time.
return []
# Unless the set is close to empty, it's cheaper to use the python set directly, so we do that.
cdef LocationEntry* entry
cdef ap_player_t sender = slot
cdef size_t start = self.sender_index[sender].start
cdef size_t count = self.sender_index[sender].count
return [entry.location for
entry in self.entries[start:start+count] if
entry.location in checked]
def get_missing(self, state: State, team: int, slot: int) -> List[int]:
cdef LocationEntry* entry
cdef ap_player_t sender = slot
cdef size_t start = self.sender_index[sender].start
cdef size_t count = self.sender_index[sender].count
cdef set checked = state[team, slot]
if not len(checked):
# Skip `in` if none have been checked.
# This optimizes the case where everyone connects to a fresh game at the same time.
return [entry.location for
entry in self.entries[start:start + count]]
else:
# Unless the set is close to empty, it's cheaper to use the python set directly, so we do that.
return [entry.location for
entry in self.entries[start:start + count] if
entry.location not in checked]
def get_remaining(self, state: State, team: int, slot: int) -> List[int]:
cdef LocationEntry* entry
cdef ap_player_t sender = slot
cdef size_t start = self.sender_index[sender].start
cdef size_t count = self.sender_index[sender].count
cdef set checked = state[team, slot]
return sorted([entry.item for
entry in self.entries[start:start+count] if
entry.location not in checked])
@cython.internal # unsafe. disable direct import
cdef class PlayerLocationProxy:
cdef LocationStore _store
cdef size_t _player
cdef object _len
def __init__(self, store: LocationStore, player: int) -> None:
self._store = store
self._player = player
self._len = self._store.sender_index[self._player].count
def __len__(self) -> int:
return self._store.sender_index[self._player].count
def __iter__(self) -> Generator[int, None, None]:
cdef LocationEntry* entry
cdef size_t i
cdef size_t off = self._store.sender_index[self._player].start
for i in range(self._store.sender_index[self._player].count):
entry = self._store.entries + off + i
yield entry.location
cdef LocationEntry* _get(self, ap_id_t loc):
# This requires locations to be sorted.
# This is always going to be slower than a pure python dict, because constructing the result tuple takes as long
# as the search in a python dict, which stores a pointer to an existing tuple.
cdef LocationEntry* entry = NULL
# binary search
cdef size_t l = self._store.sender_index[self._player].start
cdef size_t r = l + self._store.sender_index[self._player].count
cdef size_t m
while l < r:
m = (l + r) // 2
entry = self._store.entries + m
if entry.location < loc:
l = m + 1
else:
r = m
if entry: # count != 0
entry = self._store.entries + l
if entry.location == loc:
return entry
return NULL
def __getitem__(self, key: int) -> Tuple[int, int, int]:
cdef LocationEntry* entry = self._get(key)
if entry:
return entry.item, entry.receiver, entry.flags
raise KeyError(f"No location {key} for player {self._player}")
T = TypeVar('T')
def get(self, key: int, default: T) -> Union[Tuple[int, int, int], T]:
cdef LocationEntry* entry = self._get(key)
if entry:
return entry.item, entry.receiver, entry.flags
return default
def items(self) -> Generator[Tuple[int, Tuple[int, int, int]], None, None]:
cdef LocationEntry* entry
start = self._store.sender_index[self._player].start
count = self._store.sender_index[self._player].count
for entry in self._store.entries[start:start+count]:
yield entry.location, (entry.item, entry.receiver, entry.flags)

8
_speedups.pyxbld Normal file
View File

@ -0,0 +1,8 @@
# This file is required to get pyximport to work with C++.
# Switching from std::set to a pure C implementation is still on the table to simplify everything.
def make_ext(modname, pyxfilename):
from distutils.extension import Extension
return Extension(name=modname,
sources=[pyxfilename],
language='c++')

View File

@ -8,3 +8,5 @@ kivy>=2.2.0
bsdiff4>=1.2.3
platformdirs>=3.8.0
certifi>=2023.5.7
cython>=0.29.35
cymem>=2.0.7

View File

@ -57,6 +57,7 @@ if __name__ == "__main__":
from worlds.LauncherComponents import components, icon_paths
from Utils import version_tuple, is_windows, is_linux
from Cython.Build import cythonize
# On Python < 3.10 LogicMixin is not currently supported.
@ -292,17 +293,27 @@ class BuildExeCommand(cx_Freeze.command.build_exe.BuildEXE):
sni_thread = threading.Thread(target=download_SNI, name="SNI Downloader")
sni_thread.start()
# pre build steps
# pre-build steps
print(f"Outputting to: {self.buildfolder}")
os.makedirs(self.buildfolder, exist_ok=True)
import ModuleUpdate
ModuleUpdate.requirements_files.add(os.path.join("WebHostLib", "requirements.txt"))
ModuleUpdate.update(yes=self.yes)
# auto-build cython modules
build_ext = self.distribution.get_command_obj("build_ext")
build_ext.inplace = True
self.run_command("build_ext")
# regular cx build
self.buildtime = datetime.datetime.utcnow()
super().run()
# delete in-place built modules, otherwise this interferes with future pyximport
for path in build_ext.get_output_mapping().values():
print(f"deleting temp {path}")
os.unlink(path)
# need to finish download before copying
sni_thread.join()
@ -585,10 +596,10 @@ cx_Freeze.setup(
version=f"{version_tuple.major}.{version_tuple.minor}.{version_tuple.build}",
description="Archipelago",
executables=exes,
ext_modules=[], # required to disable auto-discovery with setuptools>=61
ext_modules=cythonize("_speedups.pyx"),
options={
"build_exe": {
"packages": ["worlds", "kivy"],
"packages": ["worlds", "kivy", "_speedups", "cymem"],
"includes": [],
"excludes": ["numpy", "Cython", "PySide2", "PIL",
"pandas"],

View File

@ -0,0 +1,217 @@
# Tests for _speedups.LocationStore and NetUtils._LocationStore
import typing
import unittest
from NetUtils import LocationStore, _LocationStore
sample_data = {
1: {
11: (21, 2, 7),
12: (22, 2, 0),
13: (13, 1, 0),
},
2: {
23: (11, 1, 0),
22: (12, 1, 0),
21: (23, 2, 0),
},
4: {
9: (99, 3, 0),
},
3: {
9: (99, 4, 0),
},
}
empty_state = {
(0, slot): set() for slot in sample_data
}
full_state = {
(0, slot): set(locations) for (slot, locations) in sample_data.items()
}
one_state = {
(0, 1): {12}
}
class Base:
class TestLocationStore(unittest.TestCase):
store: typing.Union[LocationStore, _LocationStore]
def test_len(self):
self.assertEqual(len(self.store), 4)
self.assertEqual(len(self.store[1]), 3)
def test_key_error(self):
with self.assertRaises(KeyError):
_ = self.store[0]
with self.assertRaises(KeyError):
_ = self.store[5]
locations = self.store[1] # no Exception
with self.assertRaises(KeyError):
_ = locations[7]
_ = locations[11] # no Exception
def test_getitem(self):
self.assertEqual(self.store[1][11], (21, 2, 7))
self.assertEqual(self.store[1][13], (13, 1, 0))
self.assertEqual(self.store[2][22], (12, 1, 0))
self.assertEqual(self.store[4][9], (99, 3, 0))
def test_get(self):
self.assertEqual(self.store.get(1, None), self.store[1])
self.assertEqual(self.store.get(0, None), None)
self.assertEqual(self.store[1].get(11, (None, None, None)), self.store[1][11])
self.assertEqual(self.store[1].get(10, (None, None, None)), (None, None, None))
def test_iter(self):
self.assertEqual(sorted(self.store), [1, 2, 3, 4])
self.assertEqual(len(self.store), len(sample_data))
self.assertEqual(list(self.store[1]), [11, 12, 13])
self.assertEqual(len(self.store[1]), len(sample_data[1]))
def test_items(self):
self.assertEqual(sorted(p for p, _ in self.store.items()), sorted(self.store))
self.assertEqual(sorted(p for p, _ in self.store[1].items()), sorted(self.store[1]))
self.assertEqual(sorted(self.store.items())[0][0], 1)
self.assertEqual(sorted(self.store.items())[0][1], self.store[1])
self.assertEqual(sorted(self.store[1].items())[0][0], 11)
self.assertEqual(sorted(self.store[1].items())[0][1], self.store[1][11])
def test_find_item(self):
self.assertEqual(sorted(self.store.find_item(set(), 99)), [])
self.assertEqual(sorted(self.store.find_item({3}, 1)), [])
self.assertEqual(sorted(self.store.find_item({5}, 99)), [])
self.assertEqual(sorted(self.store.find_item({3}, 99)),
[(4, 9, 99, 3, 0)])
self.assertEqual(sorted(self.store.find_item({3, 4}, 99)),
[(3, 9, 99, 4, 0), (4, 9, 99, 3, 0)])
def test_get_for_player(self):
self.assertEqual(self.store.get_for_player(3), {4: {9}})
self.assertEqual(self.store.get_for_player(1), {1: {13}, 2: {22, 23}})
def get_checked(self):
self.assertEqual(self.store.get_checked(full_state, 0, 1), [11, 12, 13])
self.assertEqual(self.store.get_checked(one_state, 0, 1), [12])
self.assertEqual(self.store.get_checked(empty_state, 0, 1), [])
self.assertEqual(self.store.get_checked(full_state, 0, 3), [9])
def get_missing(self):
self.assertEqual(self.store.get_missing(full_state, 0, 1), [])
self.assertEqual(self.store.get_missing(one_state, 0, 1), [11, 13])
self.assertEqual(self.store.get_missing(empty_state, 0, 1), [11, 12, 13])
self.assertEqual(self.store.get_missing(empty_state, 0, 3), [9])
def get_remaining(self):
self.assertEqual(self.store.get_remaining(full_state, 0, 1), [])
self.assertEqual(self.store.get_remaining(one_state, 0, 1), [13, 21])
self.assertEqual(self.store.get_remaining(empty_state, 0, 1), [13, 21, 22])
self.assertEqual(self.store.get_remaining(empty_state, 0, 3), [99])
class TestPurePythonLocationStore(Base.TestLocationStore):
def setUp(self) -> None:
self.store = _LocationStore(sample_data)
super().setUp()
@unittest.skipIf(LocationStore is _LocationStore, "_speedups not available")
class TestSpeedupsLocationStore(Base.TestLocationStore):
def setUp(self) -> None:
self.store = LocationStore(sample_data)
super().setUp()
@unittest.skipIf(LocationStore is _LocationStore, "_speedups not available")
class TestSpeedupsLocationStoreConstructor(unittest.TestCase):
def test_float_key(self):
with self.assertRaises(Exception):
LocationStore({
1: {1: (1, 1, 1)},
1.1: {1: (1, 1, 1)},
3: {1: (1, 1, 1)}
})
def test_string_key(self):
with self.assertRaises(Exception):
LocationStore({
"1": {1: (1, 1, 1)},
})
def test_hole(self):
with self.assertRaises(Exception):
LocationStore({
1: {1: (1, 1, 1)},
3: {1: (1, 1, 1)},
})
def test_no_slot1(self):
with self.assertRaises(Exception):
LocationStore({
2: {1: (1, 1, 1)},
3: {1: (1, 1, 1)},
})
def test_slot0(self):
with self.assertRaises(Exception):
LocationStore({
0: {1: (1, 1, 1)},
1: {1: (1, 1, 1)},
})
with self.assertRaises(Exception):
LocationStore({
0: {1: (1, 1, 1)},
2: {1: (1, 1, 1)},
})
def test_high_player_number(self):
with self.assertRaises(Exception):
LocationStore({
1 << 32: {1: (1, 1, 1)},
})
def test_no_players(self):
try: # either is fine: raise during init, or behave like {}
store = LocationStore({})
self.assertEqual(len(store), 0)
with self.assertRaises(KeyError):
_ = store[1]
except ValueError:
pass
def test_no_locations(self):
try: # either is fine: raise during init, or behave like {1: {}}
store = LocationStore({
1: {},
})
self.assertEqual(len(store), 1)
self.assertEqual(len(store[1]), 0)
except ValueError:
pass
def test_no_locations_for_1(self):
store = LocationStore({
1: {},
2: {1: (1, 2, 3)},
})
self.assertEqual(len(store), 2)
self.assertEqual(len(store[1]), 0)
self.assertEqual(len(store[2]), 1)
def test_no_locations_for_last(self):
store = LocationStore({
1: {1: (1, 2, 3)},
2: {},
})
self.assertEqual(len(store), 2)
self.assertEqual(len(store[1]), 1)
self.assertEqual(len(store[2]), 0)
def test_not_a_tuple(self):
with self.assertRaises(Exception):
LocationStore({
1: {1: None},
})

View File