MultiServer: speed up location commands (#1926)
* MultiServer: speed up location commands Adds optimized pure python wrapper around locations dict Adds optimized cython implementation of the wrapper, saving cpu time and 80% memory use * Speedups: auto-build on import and build during setup * Speedups: add requirements * CI: don't break with build_ext * Speedups: use C++ compiler for pyximport * Speedups: cleanup and more validation * Speedups: add tests for LocationStore * Setup: delete temp in-place build modules * Speedups: more tests and safer indices The change has no security implications, but ensures that entries[IndexEntry.start] is always valid. * Speedups: add cython3 compatibility * Speedups: remove unused import * Speedups: reformat * Speedup: fix empty set in test * Speedups: use regular dict in Locations.get_for_player * CI: run unittests with beta cython now with 2x nicer names
This commit is contained in:
parent
d35d3b629e
commit
b6e78bd1a3
|
@ -38,12 +38,13 @@ jobs:
|
|||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python setup.py build_exe --yes
|
||||
$NAME="$(ls build)".Split('.',2)[1]
|
||||
$NAME="$(ls build | Select-String -Pattern 'exe')".Split('.',2)[1]
|
||||
$ZIP_NAME="Archipelago_$NAME.7z"
|
||||
echo "$NAME -> $ZIP_NAME"
|
||||
echo "ZIP_NAME=$ZIP_NAME" >> $Env:GITHUB_ENV
|
||||
New-Item -Path dist -ItemType Directory -Force
|
||||
cd build
|
||||
Rename-Item exe.$NAME Archipelago
|
||||
Rename-Item "exe.$NAME" Archipelago
|
||||
7z a -mx=9 -mhe=on -ms "../dist/$ZIP_NAME" Archipelago
|
||||
- name: Store 7z
|
||||
uses: actions/upload-artifact@v3
|
||||
|
|
|
@ -26,12 +26,14 @@ on:
|
|||
jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.os }}
|
||||
name: Test Python ${{ matrix.python.version }} ${{ matrix.os }}
|
||||
name: Test Python ${{ matrix.python.version }} ${{ matrix.os }} ${{ matrix.cython }}
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
cython:
|
||||
- '' # default
|
||||
python:
|
||||
- {version: '3.8'}
|
||||
- {version: '3.9'}
|
||||
|
@ -43,6 +45,9 @@ jobs:
|
|||
os: windows-latest
|
||||
- python: {version: '3.10'} # current
|
||||
os: macos-latest
|
||||
- python: {version: '3.10'} # current
|
||||
os: ubuntu-latest
|
||||
cython: beta
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
@ -50,6 +55,11 @@ jobs:
|
|||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python.version }}
|
||||
- name: Install cython beta
|
||||
if: ${{ matrix.cython == 'beta' }}
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install --pre --upgrade cython
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
|
|
@ -168,6 +168,10 @@ dmypy.json
|
|||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# Cython intermediates
|
||||
_speedups.cpp
|
||||
_speedups.html
|
||||
|
||||
# minecraft server stuff
|
||||
jdk*/
|
||||
minecraft*/
|
||||
|
|
|
@ -38,7 +38,7 @@ import NetUtils
|
|||
import Utils
|
||||
from Utils import version_tuple, restricted_loads, Version, async_start
|
||||
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \
|
||||
SlotType
|
||||
SlotType, LocationStore
|
||||
|
||||
min_client_version = Version(0, 1, 6)
|
||||
colorama.init()
|
||||
|
@ -152,7 +152,9 @@ class Context:
|
|||
"compatibility": int}
|
||||
# team -> slot id -> list of clients authenticated to slot.
|
||||
clients: typing.Dict[int, typing.Dict[int, typing.List[Client]]]
|
||||
locations: typing.Dict[int, typing.Dict[int, typing.Tuple[int, int, int]]]
|
||||
locations: LocationStore # typing.Dict[int, typing.Dict[int, typing.Tuple[int, int, int]]]
|
||||
location_checks: typing.Dict[typing.Tuple[int, int], typing.Set[int]]
|
||||
hints_used: typing.Dict[typing.Tuple[int, int], int]
|
||||
groups: typing.Dict[int, typing.Set[int]]
|
||||
save_version = 2
|
||||
stored_data: typing.Dict[str, object]
|
||||
|
@ -187,8 +189,6 @@ class Context:
|
|||
self.player_name_lookup: typing.Dict[str, team_slot] = {}
|
||||
self.connect_names = {} # names of slots clients can connect to
|
||||
self.allow_releases = {}
|
||||
# player location_id item_id target_player_id
|
||||
self.locations = {}
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.server_password = server_password
|
||||
|
@ -284,6 +284,7 @@ class Context:
|
|||
except websockets.ConnectionClosed:
|
||||
logging.exception(f"Exception during send_msgs, could not send {msg}")
|
||||
await self.disconnect(endpoint)
|
||||
return False
|
||||
else:
|
||||
if self.log_network:
|
||||
logging.info(f"Outgoing message: {msg}")
|
||||
|
@ -297,6 +298,7 @@ class Context:
|
|||
except websockets.ConnectionClosed:
|
||||
logging.exception("Exception during send_encoded_msgs")
|
||||
await self.disconnect(endpoint)
|
||||
return False
|
||||
else:
|
||||
if self.log_network:
|
||||
logging.info(f"Outgoing message: {msg}")
|
||||
|
@ -311,6 +313,7 @@ class Context:
|
|||
websockets.broadcast(sockets, msg)
|
||||
except RuntimeError:
|
||||
logging.exception("Exception during broadcast_send_encoded_msgs")
|
||||
return False
|
||||
else:
|
||||
if self.log_network:
|
||||
logging.info(f"Outgoing broadcast: {msg}")
|
||||
|
@ -413,7 +416,7 @@ class Context:
|
|||
self.seed_name = decoded_obj["seed_name"]
|
||||
self.random.seed(self.seed_name)
|
||||
self.connect_names = decoded_obj['connect_names']
|
||||
self.locations = decoded_obj['locations']
|
||||
self.locations = LocationStore(decoded_obj.pop("locations")) # pre-emptively free memory
|
||||
self.slot_data = decoded_obj['slot_data']
|
||||
for slot, data in self.slot_data.items():
|
||||
self.read_data[f"slot_data_{slot}"] = lambda data=data: data
|
||||
|
@ -902,11 +905,7 @@ def release_player(ctx: Context, team: int, slot: int):
|
|||
|
||||
def collect_player(ctx: Context, team: int, slot: int, is_group: bool = False):
|
||||
"""register any locations that are in the multidata, pointing towards this player"""
|
||||
all_locations = collections.defaultdict(set)
|
||||
for source_slot, location_data in ctx.locations.items():
|
||||
for location_id, values in location_data.items():
|
||||
if values[1] == slot:
|
||||
all_locations[source_slot].add(location_id)
|
||||
all_locations = ctx.locations.get_for_player(slot)
|
||||
|
||||
ctx.broadcast_text_all("%s (Team #%d) has collected their items from other worlds."
|
||||
% (ctx.player_names[(team, slot)], team + 1),
|
||||
|
@ -925,11 +924,7 @@ def collect_player(ctx: Context, team: int, slot: int, is_group: bool = False):
|
|||
|
||||
|
||||
def get_remaining(ctx: Context, team: int, slot: int) -> typing.List[int]:
|
||||
items = []
|
||||
for location_id in ctx.locations[slot]:
|
||||
if location_id not in ctx.location_checks[team, slot]:
|
||||
items.append(ctx.locations[slot][location_id][0]) # item ID
|
||||
return sorted(items)
|
||||
return ctx.locations.get_remaining(ctx.location_checks, team, slot)
|
||||
|
||||
|
||||
def send_items_to(ctx: Context, team: int, target_slot: int, *items: NetworkItem):
|
||||
|
@ -977,13 +972,12 @@ def collect_hints(ctx: Context, team: int, slot: int, item: typing.Union[int, st
|
|||
slots.add(group_id)
|
||||
|
||||
seeked_item_id = item if isinstance(item, int) else ctx.item_names_for_game(ctx.games[slot])[item]
|
||||
for finding_player, check_data in ctx.locations.items():
|
||||
for location_id, (item_id, receiving_player, item_flags) in check_data.items():
|
||||
if receiving_player in slots and item_id == seeked_item_id:
|
||||
found = location_id in ctx.location_checks[team, finding_player]
|
||||
entrance = ctx.er_hint_data.get(finding_player, {}).get(location_id, "")
|
||||
hints.append(NetUtils.Hint(receiving_player, finding_player, location_id, item_id, found, entrance,
|
||||
item_flags))
|
||||
for finding_player, location_id, item_id, receiving_player, item_flags \
|
||||
in ctx.locations.find_item(slots, seeked_item_id):
|
||||
found = location_id in ctx.location_checks[team, finding_player]
|
||||
entrance = ctx.er_hint_data.get(finding_player, {}).get(location_id, "")
|
||||
hints.append(NetUtils.Hint(receiving_player, finding_player, location_id, item_id, found, entrance,
|
||||
item_flags))
|
||||
|
||||
return hints
|
||||
|
||||
|
@ -1555,15 +1549,11 @@ class ClientMessageProcessor(CommonCommandProcessor):
|
|||
|
||||
|
||||
def get_checked_checks(ctx: Context, team: int, slot: int) -> typing.List[int]:
|
||||
return [location_id for
|
||||
location_id in ctx.locations[slot] if
|
||||
location_id in ctx.location_checks[team, slot]]
|
||||
return ctx.locations.get_checked(ctx.location_checks, team, slot)
|
||||
|
||||
|
||||
def get_missing_checks(ctx: Context, team: int, slot: int) -> typing.List[int]:
|
||||
return [location_id for
|
||||
location_id in ctx.locations[slot] if
|
||||
location_id not in ctx.location_checks[team, slot]]
|
||||
return ctx.locations.get_missing(ctx.location_checks, team, slot)
|
||||
|
||||
|
||||
def get_client_points(ctx: Context, client: Client) -> int:
|
||||
|
|
63
NetUtils.py
63
NetUtils.py
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import typing
|
||||
import enum
|
||||
import warnings
|
||||
from json import JSONEncoder, JSONDecoder
|
||||
|
||||
import websockets
|
||||
|
@ -343,3 +344,65 @@ class Hint(typing.NamedTuple):
|
|||
@property
|
||||
def local(self):
|
||||
return self.receiving_player == self.finding_player
|
||||
|
||||
|
||||
class _LocationStore(dict, typing.MutableMapping[int, typing.Dict[int, typing.Tuple[int, int, int]]]):
|
||||
def find_item(self, slots: typing.Set[int], seeked_item_id: int
|
||||
) -> typing.Generator[typing.Tuple[int, int, int, int, int], None, None]:
|
||||
for finding_player, check_data in self.items():
|
||||
for location_id, (item_id, receiving_player, item_flags) in check_data.items():
|
||||
if receiving_player in slots and item_id == seeked_item_id:
|
||||
yield finding_player, location_id, item_id, receiving_player, item_flags
|
||||
|
||||
def get_for_player(self, slot: int) -> typing.Dict[int, typing.Set[int]]:
|
||||
import collections
|
||||
all_locations: typing.Dict[int, typing.Set[int]] = collections.defaultdict(set)
|
||||
for source_slot, location_data in self.items():
|
||||
for location_id, values in location_data.items():
|
||||
if values[1] == slot:
|
||||
all_locations[source_slot].add(location_id)
|
||||
return all_locations
|
||||
|
||||
def get_checked(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int
|
||||
) -> typing.List[int]:
|
||||
checked = state[team, slot]
|
||||
if not checked:
|
||||
# This optimizes the case where everyone connects to a fresh game at the same time.
|
||||
return []
|
||||
return [location_id for
|
||||
location_id in self[slot] if
|
||||
location_id in checked]
|
||||
|
||||
def get_missing(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int
|
||||
) -> typing.List[int]:
|
||||
checked = state[team, slot]
|
||||
if not checked:
|
||||
# This optimizes the case where everyone connects to a fresh game at the same time.
|
||||
return list(self)
|
||||
return [location_id for
|
||||
location_id in self[slot] if
|
||||
location_id not in checked]
|
||||
|
||||
def get_remaining(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int
|
||||
) -> typing.List[int]:
|
||||
checked = state[team, slot]
|
||||
player_locations = self[slot]
|
||||
return sorted([player_locations[location_id][0] for
|
||||
location_id in player_locations if
|
||||
location_id not in checked])
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING: # type-check with pure python implementation until we have a typing stub
|
||||
LocationStore = _LocationStore
|
||||
else:
|
||||
try:
|
||||
import pyximport
|
||||
pyximport.install()
|
||||
except ImportError:
|
||||
pyximport = None
|
||||
try:
|
||||
from _speedups import LocationStore
|
||||
except ImportError:
|
||||
warnings.warn("_speedups not available. Falling back to pure python LocationStore. "
|
||||
"Install a matching C++ compiler for your platform to compile _speedups.")
|
||||
LocationStore = _LocationStore
|
||||
|
|
|
@ -0,0 +1,335 @@
|
|||
#cython: language_level=3
|
||||
#distutils: language = c++
|
||||
|
||||
"""
|
||||
Provides faster implementation of some core parts.
|
||||
This is deliberately .pyx because using a non-compiled "pure python" may be slower.
|
||||
"""
|
||||
|
||||
# pip install cython cymem
|
||||
import cython
|
||||
from cpython cimport PyObject
|
||||
from typing import Any, Dict, Iterable, Iterator, Generator, Sequence, Tuple, TypeVar, Union, Set, List, TYPE_CHECKING
|
||||
from cymem.cymem cimport Pool
|
||||
from libc.stdint cimport int64_t, uint32_t
|
||||
from libcpp.set cimport set as std_set
|
||||
from collections import defaultdict
|
||||
|
||||
ctypedef uint32_t ap_player_t # on AMD64 this is faster (and smaller) than 64bit ints
|
||||
ctypedef uint32_t ap_flags_t
|
||||
ctypedef int64_t ap_id_t
|
||||
|
||||
cdef ap_player_t MAX_PLAYER_ID = 1000000 # limit the size of indexing array
|
||||
cdef size_t INVALID_SIZE = <size_t>(-1) # this is all 0xff... adding 1 results in 0, but it's not negative
|
||||
|
||||
|
||||
cdef struct LocationEntry:
|
||||
# layout is so that
|
||||
# 64bit player: location+sender and item+receiver 128bit comparisons, if supported
|
||||
# 32bit player: aligned to 32/64bit with no unused space
|
||||
ap_id_t location
|
||||
ap_player_t sender
|
||||
ap_player_t receiver
|
||||
ap_id_t item
|
||||
ap_flags_t flags
|
||||
|
||||
|
||||
cdef struct IndexEntry:
|
||||
size_t start
|
||||
size_t count
|
||||
|
||||
|
||||
cdef class LocationStore:
|
||||
"""Compact store for locations and their items in a MultiServer"""
|
||||
# The original implementation uses Dict[int, Dict[int, Tuple(int, int, int]]
|
||||
# with sender, location, (item, receiver, flags).
|
||||
# This implementation is a flat list of (sender, location, item, receiver, flags) using native integers
|
||||
# as well as some mapping arrays used to speed up stuff, saving a lot of memory while speeding up hints.
|
||||
# Using std::map might be worth investigating, but memory overhead would be ~100% compared to arrays.
|
||||
|
||||
cdef Pool _mem
|
||||
cdef object _len
|
||||
cdef LocationEntry* entries # 3.2MB/100k items
|
||||
cdef size_t entry_count
|
||||
cdef IndexEntry* sender_index # 16KB/1000 players
|
||||
cdef size_t sender_index_size
|
||||
cdef list _keys # ~36KB/1000 players, speed up iter (28 per int + 8 per list entry)
|
||||
cdef list _items # ~64KB/1000 players, speed up items (56 per tuple + 8 per list entry)
|
||||
cdef list _proxies # ~92KB/1000 players, speed up self[player] (56 per struct + 28 per len + 8 per list entry)
|
||||
cdef PyObject** _raw_proxies # 8K/1000 players, faster access to _proxies, but does not keep a ref
|
||||
|
||||
def get_size(self):
|
||||
from sys import getsizeof
|
||||
size = getsizeof(self) + getsizeof(self._mem) + getsizeof(self._len) \
|
||||
+ sizeof(LocationEntry) * self.entry_count + sizeof(IndexEntry) * self.sender_index_size
|
||||
size += getsizeof(self._keys) + getsizeof(self._items) + getsizeof(self._proxies)
|
||||
size += sum(sizeof(key) for key in self._keys)
|
||||
size += sum(sizeof(item) for item in self._items)
|
||||
size += sum(sizeof(proxy) for proxy in self._proxies)
|
||||
size += sizeof(self._raw_proxies[0]) * self.sender_index_size
|
||||
return size
|
||||
|
||||
def __cinit__(self, locations_dict: Dict[int, Dict[int, Sequence[int]]]) -> None:
|
||||
self._mem = None
|
||||
self._keys = None
|
||||
self._items = None
|
||||
self._proxies = None
|
||||
self._len = 0
|
||||
self.entries = NULL
|
||||
self.entry_count = 0
|
||||
self.sender_index = NULL
|
||||
self.sender_index_size = 0
|
||||
self._raw_proxies = NULL
|
||||
|
||||
def __init__(self, locations_dict: Dict[int, Dict[int, Sequence[int]]]) -> None:
|
||||
self._mem = Pool()
|
||||
cdef object key
|
||||
self._keys = []
|
||||
self._items = []
|
||||
self._proxies = []
|
||||
|
||||
# iterate over everything to get all maxima and validate everything
|
||||
cdef size_t max_sender = INVALID_SIZE # keep track of highest used player id for indexing
|
||||
cdef size_t sender_count = 0
|
||||
cdef size_t count = 0
|
||||
for sender, locations in locations_dict.items():
|
||||
# we don't require the dict to be sorted here
|
||||
if not isinstance(sender, int) or sender < 1 or sender > MAX_PLAYER_ID:
|
||||
raise ValueError(f"Invalid player id {sender} for location")
|
||||
if max_sender == INVALID_SIZE:
|
||||
max_sender = sender
|
||||
else:
|
||||
max_sender = max(max_sender, sender)
|
||||
for location, data in locations.items():
|
||||
receiver = data[1]
|
||||
if receiver < 1 or receiver > MAX_PLAYER_ID:
|
||||
raise ValueError(f"Invalid player id {receiver} for item")
|
||||
count += 1
|
||||
sender_count += 1
|
||||
|
||||
if not count:
|
||||
raise ValueError("No locations")
|
||||
|
||||
if sender_count != max_sender:
|
||||
# we assume player 0 will never have locations
|
||||
raise ValueError("Player IDs not continuous")
|
||||
|
||||
# allocate the arrays and invalidate index (0xff...)
|
||||
self.entries = <LocationEntry*>self._mem.alloc(count, sizeof(LocationEntry))
|
||||
self.sender_index = <IndexEntry*>self._mem.alloc(max_sender + 1, sizeof(IndexEntry))
|
||||
self._raw_proxies = <PyObject**>self._mem.alloc(max_sender + 1, sizeof(PyObject*))
|
||||
|
||||
# build entries and index
|
||||
cdef size_t i = 0
|
||||
for sender, locations in sorted(locations_dict.items()):
|
||||
self.sender_index[sender].start = i
|
||||
self.sender_index[sender].count = 0
|
||||
# Sorting locations here makes it possible to write a faster lookup without an additional index.
|
||||
for location, data in sorted(locations.items()):
|
||||
self.entries[i].sender = sender
|
||||
self.entries[i].location = location
|
||||
self.entries[i].item = data[0]
|
||||
self.entries[i].receiver = data[1]
|
||||
if len(data) > 2:
|
||||
self.entries[i].flags = data[2] # initialized to 0 during alloc
|
||||
# Ignoring extra data. warn?
|
||||
self.sender_index[sender].count += 1
|
||||
i += 1
|
||||
|
||||
# build pyobject caches
|
||||
self._proxies.append(None) # player 0
|
||||
assert self.sender_index[0].count == 0
|
||||
for i in range(1, max_sender + 1):
|
||||
if self.sender_index[i].count == 0 and self.sender_index[i].start >= count:
|
||||
self.sender_index[i].start = 0 # do not point outside valid entries
|
||||
assert self.sender_index[i].start < count
|
||||
key = i # allocate python integer
|
||||
proxy = PlayerLocationProxy(self, i)
|
||||
self._keys.append(key)
|
||||
self._items.append((key, proxy))
|
||||
self._proxies.append(proxy)
|
||||
self._raw_proxies[i] = <PyObject*>proxy
|
||||
|
||||
self.sender_index_size = max_sender + 1
|
||||
self.entry_count = count
|
||||
self._len = sender_count
|
||||
|
||||
# fake dict access
|
||||
def __len__(self) -> int:
|
||||
return self._len
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
return self._keys.__iter__()
|
||||
|
||||
def __getitem__(self, key: int) -> Any:
|
||||
# figure out if player actually exists in the multidata and return a proxy
|
||||
cdef size_t i = key # NOTE: this may raise TypeError
|
||||
if i < 1 or i >= self.sender_index_size:
|
||||
raise KeyError(key)
|
||||
return <object>self._raw_proxies[key]
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
def get(self, key: int, default: T) -> Union[PlayerLocationProxy, T]:
|
||||
# calling into self.__getitem__ here is slow, but this is not used in MultiServer
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def items(self) -> Iterable[Tuple[int, PlayerLocationProxy]]:
|
||||
return self._items
|
||||
|
||||
# specialized accessors
|
||||
def find_item(self, slots: Set[int], seeked_item_id: int) -> Generator[Tuple[int, int, int, int, int], None, None]:
|
||||
cdef ap_id_t item = seeked_item_id
|
||||
cdef ap_player_t receiver
|
||||
cdef std_set[ap_player_t] receivers
|
||||
cdef size_t slot_count = len(slots)
|
||||
if slot_count == 1:
|
||||
# specialized implementation for single slot
|
||||
receiver = list(slots)[0]
|
||||
with nogil:
|
||||
for entry in self.entries[:self.entry_count]:
|
||||
if entry.item == item and entry.receiver == receiver:
|
||||
with gil:
|
||||
yield entry.sender, entry.location, entry.item, entry.receiver, entry.flags
|
||||
elif slot_count:
|
||||
# generic implementation with lookup in set
|
||||
for receiver in slots:
|
||||
receivers.insert(receiver)
|
||||
with nogil:
|
||||
for entry in self.entries[:self.entry_count]:
|
||||
if entry.item == item and receivers.count(entry.receiver):
|
||||
with gil:
|
||||
yield entry.sender, entry.location, entry.item, entry.receiver, entry.flags
|
||||
|
||||
def get_for_player(self, slot: int) -> Dict[int, Set[int]]:
|
||||
cdef ap_player_t receiver = slot
|
||||
all_locations: Dict[int, Set[int]] = {}
|
||||
with nogil:
|
||||
for entry in self.entries[:self.entry_count]:
|
||||
if entry.receiver == receiver:
|
||||
with gil:
|
||||
sender: int = entry.sender
|
||||
if sender not in all_locations:
|
||||
all_locations[sender] = set()
|
||||
all_locations[sender].add(entry.location)
|
||||
return all_locations
|
||||
|
||||
if TYPE_CHECKING:
|
||||
State = Dict[Tuple[int, int], Set[int]]
|
||||
else:
|
||||
State = Union[Tuple[int, int], Set[int], defaultdict]
|
||||
|
||||
def get_checked(self, state: State, team: int, slot: int) -> List[int]:
|
||||
# This used to validate checks actually exist. A remnant from the past.
|
||||
# If the order of locations becomes relevant at some point, we could not do sorted(set), so leaving it.
|
||||
cdef set checked = state[team, slot]
|
||||
|
||||
if not len(checked):
|
||||
# Skips loop if none have been checked.
|
||||
# This optimizes the case where everyone connects to a fresh game at the same time.
|
||||
return []
|
||||
|
||||
# Unless the set is close to empty, it's cheaper to use the python set directly, so we do that.
|
||||
cdef LocationEntry* entry
|
||||
cdef ap_player_t sender = slot
|
||||
cdef size_t start = self.sender_index[sender].start
|
||||
cdef size_t count = self.sender_index[sender].count
|
||||
return [entry.location for
|
||||
entry in self.entries[start:start+count] if
|
||||
entry.location in checked]
|
||||
|
||||
def get_missing(self, state: State, team: int, slot: int) -> List[int]:
|
||||
cdef LocationEntry* entry
|
||||
cdef ap_player_t sender = slot
|
||||
cdef size_t start = self.sender_index[sender].start
|
||||
cdef size_t count = self.sender_index[sender].count
|
||||
cdef set checked = state[team, slot]
|
||||
if not len(checked):
|
||||
# Skip `in` if none have been checked.
|
||||
# This optimizes the case where everyone connects to a fresh game at the same time.
|
||||
return [entry.location for
|
||||
entry in self.entries[start:start + count]]
|
||||
else:
|
||||
# Unless the set is close to empty, it's cheaper to use the python set directly, so we do that.
|
||||
return [entry.location for
|
||||
entry in self.entries[start:start + count] if
|
||||
entry.location not in checked]
|
||||
|
||||
def get_remaining(self, state: State, team: int, slot: int) -> List[int]:
|
||||
cdef LocationEntry* entry
|
||||
cdef ap_player_t sender = slot
|
||||
cdef size_t start = self.sender_index[sender].start
|
||||
cdef size_t count = self.sender_index[sender].count
|
||||
cdef set checked = state[team, slot]
|
||||
return sorted([entry.item for
|
||||
entry in self.entries[start:start+count] if
|
||||
entry.location not in checked])
|
||||
|
||||
|
||||
@cython.internal # unsafe. disable direct import
|
||||
cdef class PlayerLocationProxy:
|
||||
cdef LocationStore _store
|
||||
cdef size_t _player
|
||||
cdef object _len
|
||||
|
||||
def __init__(self, store: LocationStore, player: int) -> None:
|
||||
self._store = store
|
||||
self._player = player
|
||||
self._len = self._store.sender_index[self._player].count
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._store.sender_index[self._player].count
|
||||
|
||||
def __iter__(self) -> Generator[int, None, None]:
|
||||
cdef LocationEntry* entry
|
||||
cdef size_t i
|
||||
cdef size_t off = self._store.sender_index[self._player].start
|
||||
for i in range(self._store.sender_index[self._player].count):
|
||||
entry = self._store.entries + off + i
|
||||
yield entry.location
|
||||
|
||||
cdef LocationEntry* _get(self, ap_id_t loc):
|
||||
# This requires locations to be sorted.
|
||||
# This is always going to be slower than a pure python dict, because constructing the result tuple takes as long
|
||||
# as the search in a python dict, which stores a pointer to an existing tuple.
|
||||
cdef LocationEntry* entry = NULL
|
||||
# binary search
|
||||
cdef size_t l = self._store.sender_index[self._player].start
|
||||
cdef size_t r = l + self._store.sender_index[self._player].count
|
||||
cdef size_t m
|
||||
while l < r:
|
||||
m = (l + r) // 2
|
||||
entry = self._store.entries + m
|
||||
if entry.location < loc:
|
||||
l = m + 1
|
||||
else:
|
||||
r = m
|
||||
if entry: # count != 0
|
||||
entry = self._store.entries + l
|
||||
if entry.location == loc:
|
||||
return entry
|
||||
return NULL
|
||||
|
||||
def __getitem__(self, key: int) -> Tuple[int, int, int]:
|
||||
cdef LocationEntry* entry = self._get(key)
|
||||
if entry:
|
||||
return entry.item, entry.receiver, entry.flags
|
||||
raise KeyError(f"No location {key} for player {self._player}")
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
def get(self, key: int, default: T) -> Union[Tuple[int, int, int], T]:
|
||||
cdef LocationEntry* entry = self._get(key)
|
||||
if entry:
|
||||
return entry.item, entry.receiver, entry.flags
|
||||
return default
|
||||
|
||||
def items(self) -> Generator[Tuple[int, Tuple[int, int, int]], None, None]:
|
||||
cdef LocationEntry* entry
|
||||
start = self._store.sender_index[self._player].start
|
||||
count = self._store.sender_index[self._player].count
|
||||
for entry in self._store.entries[start:start+count]:
|
||||
yield entry.location, (entry.item, entry.receiver, entry.flags)
|
|
@ -0,0 +1,8 @@
|
|||
# This file is required to get pyximport to work with C++.
|
||||
# Switching from std::set to a pure C implementation is still on the table to simplify everything.
|
||||
|
||||
def make_ext(modname, pyxfilename):
|
||||
from distutils.extension import Extension
|
||||
return Extension(name=modname,
|
||||
sources=[pyxfilename],
|
||||
language='c++')
|
|
@ -8,3 +8,5 @@ kivy>=2.2.0
|
|||
bsdiff4>=1.2.3
|
||||
platformdirs>=3.8.0
|
||||
certifi>=2023.5.7
|
||||
cython>=0.29.35
|
||||
cymem>=2.0.7
|
||||
|
|
17
setup.py
17
setup.py
|
@ -57,6 +57,7 @@ if __name__ == "__main__":
|
|||
|
||||
from worlds.LauncherComponents import components, icon_paths
|
||||
from Utils import version_tuple, is_windows, is_linux
|
||||
from Cython.Build import cythonize
|
||||
|
||||
|
||||
# On Python < 3.10 LogicMixin is not currently supported.
|
||||
|
@ -292,17 +293,27 @@ class BuildExeCommand(cx_Freeze.command.build_exe.BuildEXE):
|
|||
sni_thread = threading.Thread(target=download_SNI, name="SNI Downloader")
|
||||
sni_thread.start()
|
||||
|
||||
# pre build steps
|
||||
# pre-build steps
|
||||
print(f"Outputting to: {self.buildfolder}")
|
||||
os.makedirs(self.buildfolder, exist_ok=True)
|
||||
import ModuleUpdate
|
||||
ModuleUpdate.requirements_files.add(os.path.join("WebHostLib", "requirements.txt"))
|
||||
ModuleUpdate.update(yes=self.yes)
|
||||
|
||||
# auto-build cython modules
|
||||
build_ext = self.distribution.get_command_obj("build_ext")
|
||||
build_ext.inplace = True
|
||||
self.run_command("build_ext")
|
||||
|
||||
# regular cx build
|
||||
self.buildtime = datetime.datetime.utcnow()
|
||||
super().run()
|
||||
|
||||
# delete in-place built modules, otherwise this interferes with future pyximport
|
||||
for path in build_ext.get_output_mapping().values():
|
||||
print(f"deleting temp {path}")
|
||||
os.unlink(path)
|
||||
|
||||
# need to finish download before copying
|
||||
sni_thread.join()
|
||||
|
||||
|
@ -585,10 +596,10 @@ cx_Freeze.setup(
|
|||
version=f"{version_tuple.major}.{version_tuple.minor}.{version_tuple.build}",
|
||||
description="Archipelago",
|
||||
executables=exes,
|
||||
ext_modules=[], # required to disable auto-discovery with setuptools>=61
|
||||
ext_modules=cythonize("_speedups.pyx"),
|
||||
options={
|
||||
"build_exe": {
|
||||
"packages": ["worlds", "kivy"],
|
||||
"packages": ["worlds", "kivy", "_speedups", "cymem"],
|
||||
"includes": [],
|
||||
"excludes": ["numpy", "Cython", "PySide2", "PIL",
|
||||
"pandas"],
|
||||
|
|
|
@ -0,0 +1,217 @@
|
|||
# Tests for _speedups.LocationStore and NetUtils._LocationStore
|
||||
import typing
|
||||
import unittest
|
||||
from NetUtils import LocationStore, _LocationStore
|
||||
|
||||
|
||||
sample_data = {
|
||||
1: {
|
||||
11: (21, 2, 7),
|
||||
12: (22, 2, 0),
|
||||
13: (13, 1, 0),
|
||||
},
|
||||
2: {
|
||||
23: (11, 1, 0),
|
||||
22: (12, 1, 0),
|
||||
21: (23, 2, 0),
|
||||
},
|
||||
4: {
|
||||
9: (99, 3, 0),
|
||||
},
|
||||
3: {
|
||||
9: (99, 4, 0),
|
||||
},
|
||||
}
|
||||
|
||||
empty_state = {
|
||||
(0, slot): set() for slot in sample_data
|
||||
}
|
||||
|
||||
full_state = {
|
||||
(0, slot): set(locations) for (slot, locations) in sample_data.items()
|
||||
}
|
||||
|
||||
one_state = {
|
||||
(0, 1): {12}
|
||||
}
|
||||
|
||||
|
||||
class Base:
|
||||
class TestLocationStore(unittest.TestCase):
|
||||
store: typing.Union[LocationStore, _LocationStore]
|
||||
|
||||
def test_len(self):
|
||||
self.assertEqual(len(self.store), 4)
|
||||
self.assertEqual(len(self.store[1]), 3)
|
||||
|
||||
def test_key_error(self):
|
||||
with self.assertRaises(KeyError):
|
||||
_ = self.store[0]
|
||||
with self.assertRaises(KeyError):
|
||||
_ = self.store[5]
|
||||
locations = self.store[1] # no Exception
|
||||
with self.assertRaises(KeyError):
|
||||
_ = locations[7]
|
||||
_ = locations[11] # no Exception
|
||||
|
||||
def test_getitem(self):
|
||||
self.assertEqual(self.store[1][11], (21, 2, 7))
|
||||
self.assertEqual(self.store[1][13], (13, 1, 0))
|
||||
self.assertEqual(self.store[2][22], (12, 1, 0))
|
||||
self.assertEqual(self.store[4][9], (99, 3, 0))
|
||||
|
||||
def test_get(self):
|
||||
self.assertEqual(self.store.get(1, None), self.store[1])
|
||||
self.assertEqual(self.store.get(0, None), None)
|
||||
self.assertEqual(self.store[1].get(11, (None, None, None)), self.store[1][11])
|
||||
self.assertEqual(self.store[1].get(10, (None, None, None)), (None, None, None))
|
||||
|
||||
def test_iter(self):
|
||||
self.assertEqual(sorted(self.store), [1, 2, 3, 4])
|
||||
self.assertEqual(len(self.store), len(sample_data))
|
||||
self.assertEqual(list(self.store[1]), [11, 12, 13])
|
||||
self.assertEqual(len(self.store[1]), len(sample_data[1]))
|
||||
|
||||
def test_items(self):
|
||||
self.assertEqual(sorted(p for p, _ in self.store.items()), sorted(self.store))
|
||||
self.assertEqual(sorted(p for p, _ in self.store[1].items()), sorted(self.store[1]))
|
||||
self.assertEqual(sorted(self.store.items())[0][0], 1)
|
||||
self.assertEqual(sorted(self.store.items())[0][1], self.store[1])
|
||||
self.assertEqual(sorted(self.store[1].items())[0][0], 11)
|
||||
self.assertEqual(sorted(self.store[1].items())[0][1], self.store[1][11])
|
||||
|
||||
def test_find_item(self):
|
||||
self.assertEqual(sorted(self.store.find_item(set(), 99)), [])
|
||||
self.assertEqual(sorted(self.store.find_item({3}, 1)), [])
|
||||
self.assertEqual(sorted(self.store.find_item({5}, 99)), [])
|
||||
self.assertEqual(sorted(self.store.find_item({3}, 99)),
|
||||
[(4, 9, 99, 3, 0)])
|
||||
self.assertEqual(sorted(self.store.find_item({3, 4}, 99)),
|
||||
[(3, 9, 99, 4, 0), (4, 9, 99, 3, 0)])
|
||||
|
||||
def test_get_for_player(self):
|
||||
self.assertEqual(self.store.get_for_player(3), {4: {9}})
|
||||
self.assertEqual(self.store.get_for_player(1), {1: {13}, 2: {22, 23}})
|
||||
|
||||
def get_checked(self):
|
||||
self.assertEqual(self.store.get_checked(full_state, 0, 1), [11, 12, 13])
|
||||
self.assertEqual(self.store.get_checked(one_state, 0, 1), [12])
|
||||
self.assertEqual(self.store.get_checked(empty_state, 0, 1), [])
|
||||
self.assertEqual(self.store.get_checked(full_state, 0, 3), [9])
|
||||
|
||||
def get_missing(self):
|
||||
self.assertEqual(self.store.get_missing(full_state, 0, 1), [])
|
||||
self.assertEqual(self.store.get_missing(one_state, 0, 1), [11, 13])
|
||||
self.assertEqual(self.store.get_missing(empty_state, 0, 1), [11, 12, 13])
|
||||
self.assertEqual(self.store.get_missing(empty_state, 0, 3), [9])
|
||||
|
||||
def get_remaining(self):
|
||||
self.assertEqual(self.store.get_remaining(full_state, 0, 1), [])
|
||||
self.assertEqual(self.store.get_remaining(one_state, 0, 1), [13, 21])
|
||||
self.assertEqual(self.store.get_remaining(empty_state, 0, 1), [13, 21, 22])
|
||||
self.assertEqual(self.store.get_remaining(empty_state, 0, 3), [99])
|
||||
|
||||
|
||||
class TestPurePythonLocationStore(Base.TestLocationStore):
|
||||
def setUp(self) -> None:
|
||||
self.store = _LocationStore(sample_data)
|
||||
super().setUp()
|
||||
|
||||
|
||||
@unittest.skipIf(LocationStore is _LocationStore, "_speedups not available")
|
||||
class TestSpeedupsLocationStore(Base.TestLocationStore):
|
||||
def setUp(self) -> None:
|
||||
self.store = LocationStore(sample_data)
|
||||
super().setUp()
|
||||
|
||||
|
||||
@unittest.skipIf(LocationStore is _LocationStore, "_speedups not available")
|
||||
class TestSpeedupsLocationStoreConstructor(unittest.TestCase):
|
||||
def test_float_key(self):
|
||||
with self.assertRaises(Exception):
|
||||
LocationStore({
|
||||
1: {1: (1, 1, 1)},
|
||||
1.1: {1: (1, 1, 1)},
|
||||
3: {1: (1, 1, 1)}
|
||||
})
|
||||
|
||||
def test_string_key(self):
|
||||
with self.assertRaises(Exception):
|
||||
LocationStore({
|
||||
"1": {1: (1, 1, 1)},
|
||||
})
|
||||
|
||||
def test_hole(self):
|
||||
with self.assertRaises(Exception):
|
||||
LocationStore({
|
||||
1: {1: (1, 1, 1)},
|
||||
3: {1: (1, 1, 1)},
|
||||
})
|
||||
|
||||
def test_no_slot1(self):
|
||||
with self.assertRaises(Exception):
|
||||
LocationStore({
|
||||
2: {1: (1, 1, 1)},
|
||||
3: {1: (1, 1, 1)},
|
||||
})
|
||||
|
||||
def test_slot0(self):
|
||||
with self.assertRaises(Exception):
|
||||
LocationStore({
|
||||
0: {1: (1, 1, 1)},
|
||||
1: {1: (1, 1, 1)},
|
||||
})
|
||||
with self.assertRaises(Exception):
|
||||
LocationStore({
|
||||
0: {1: (1, 1, 1)},
|
||||
2: {1: (1, 1, 1)},
|
||||
})
|
||||
|
||||
def test_high_player_number(self):
|
||||
with self.assertRaises(Exception):
|
||||
LocationStore({
|
||||
1 << 32: {1: (1, 1, 1)},
|
||||
})
|
||||
|
||||
def test_no_players(self):
|
||||
try: # either is fine: raise during init, or behave like {}
|
||||
store = LocationStore({})
|
||||
self.assertEqual(len(store), 0)
|
||||
with self.assertRaises(KeyError):
|
||||
_ = store[1]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def test_no_locations(self):
|
||||
try: # either is fine: raise during init, or behave like {1: {}}
|
||||
store = LocationStore({
|
||||
1: {},
|
||||
})
|
||||
self.assertEqual(len(store), 1)
|
||||
self.assertEqual(len(store[1]), 0)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def test_no_locations_for_1(self):
|
||||
store = LocationStore({
|
||||
1: {},
|
||||
2: {1: (1, 2, 3)},
|
||||
})
|
||||
self.assertEqual(len(store), 2)
|
||||
self.assertEqual(len(store[1]), 0)
|
||||
self.assertEqual(len(store[2]), 1)
|
||||
|
||||
def test_no_locations_for_last(self):
|
||||
store = LocationStore({
|
||||
1: {1: (1, 2, 3)},
|
||||
2: {},
|
||||
})
|
||||
self.assertEqual(len(store), 2)
|
||||
self.assertEqual(len(store[1]), 1)
|
||||
self.assertEqual(len(store[2]), 0)
|
||||
|
||||
def test_not_a_tuple(self):
|
||||
with self.assertRaises(Exception):
|
||||
LocationStore({
|
||||
1: {1: None},
|
||||
})
|
Loading…
Reference in New Issue