Core: clean up BaseClasses a bit (#1731)

This commit is contained in:
el-u 2023-05-25 01:24:12 +02:00 committed by GitHub
parent f4d9c294a3
commit c46d8afcfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 54 additions and 44 deletions

View File

@ -7,9 +7,9 @@ import random
import secrets import secrets
import typing # this can go away when Python 3.8 support is dropped import typing # this can go away when Python 3.8 support is dropped
from argparse import Namespace from argparse import Namespace
from collections import ChainMap, Counter, OrderedDict, deque from collections import ChainMap, Counter, deque
from enum import IntEnum, IntFlag from enum import IntEnum, IntFlag
from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, TypedDict, Union from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Optional, Set, Tuple, TypedDict, Union
import NetUtils import NetUtils
import Options import Options
@ -28,15 +28,15 @@ class Group(TypedDict, total=False):
link_replacement: bool link_replacement: bool
class ThreadBarrierProxy(): class ThreadBarrierProxy:
"""Passes through getattr while passthrough is True""" """Passes through getattr while passthrough is True"""
def __init__(self, obj: Any): def __init__(self, obj: object) -> None:
self.passthrough = True self.passthrough = True
self.obj = obj self.obj = obj
def __getattr__(self, item): def __getattr__(self, name: str) -> Any:
if self.passthrough: if self.passthrough:
return getattr(self.obj, item) return getattr(self.obj, name)
else: else:
raise RuntimeError("You are in a threaded context and global random state was removed for your safety. " raise RuntimeError("You are in a threaded context and global random state was removed for your safety. "
"Please use multiworld.per_slot_randoms[player] or randomize ahead of output.") "Please use multiworld.per_slot_randoms[player] or randomize ahead of output.")
@ -1028,15 +1028,19 @@ class Item:
def flags(self) -> int: def flags(self) -> int:
return self.classification.as_flag() return self.classification.as_flag()
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, Item):
return NotImplemented
return self.name == other.name and self.player == other.player return self.name == other.name and self.player == other.player
def __lt__(self, other: Item) -> bool: def __lt__(self, other: object) -> bool:
if not isinstance(other, Item):
return NotImplemented
if other.player != self.player: if other.player != self.player:
return other.player < self.player return other.player < self.player
return self.name < other.name return self.name < other.name
def __hash__(self): def __hash__(self) -> int:
return hash((self.name, self.player)) return hash((self.name, self.player))
def __repr__(self) -> str: def __repr__(self) -> str:
@ -1048,33 +1052,44 @@ class Item:
return f"{self.name} (Player {self.player})" return f"{self.name} (Player {self.player})"
class Spoiler(): class EntranceInfo(TypedDict, total=False):
multiworld: MultiWorld player: int
unreachables: Set[Location] entrance: str
exit: str
direction: str
def __init__(self, world):
self.multiworld = world class Spoiler:
multiworld: MultiWorld
hashes: Dict[int, str]
entrances: Dict[Tuple[str, str, int], EntranceInfo]
playthrough: Dict[str, Union[List[str], Dict[str, str]]] # sphere "0" is list, others are dict
unreachables: Set[Location]
paths: Dict[str, List[Union[Tuple[str, str], Tuple[str, None]]]] # last step takes no further exits
def __init__(self, multiworld: MultiWorld) -> None:
self.multiworld = multiworld
self.hashes = {} self.hashes = {}
self.entrances = OrderedDict() self.entrances = {}
self.playthrough = {} self.playthrough = {}
self.unreachables = set() self.unreachables = set()
self.paths = {} self.paths = {}
def set_entrance(self, entrance: str, exit_: str, direction: str, player: int): def set_entrance(self, entrance: str, exit_: str, direction: str, player: int) -> None:
if self.multiworld.players == 1: if self.multiworld.players == 1:
self.entrances[(entrance, direction, player)] = OrderedDict( self.entrances[(entrance, direction, player)] = \
[('entrance', entrance), ('exit', exit_), ('direction', direction)]) {"entrance": entrance, "exit": exit_, "direction": direction}
else: else:
self.entrances[(entrance, direction, player)] = OrderedDict( self.entrances[(entrance, direction, player)] = \
[('player', player), ('entrance', entrance), ('exit', exit_), ('direction', direction)]) {"player": player, "entrance": entrance, "exit": exit_, "direction": direction}
def create_playthrough(self, create_paths: bool = True): def create_playthrough(self, create_paths: bool = True) -> None:
"""Destructive to the world while it is run, damage gets repaired afterwards.""" """Destructive to the world while it is run, damage gets repaired afterwards."""
from itertools import chain from itertools import chain
# get locations containing progress items # get locations containing progress items
multiworld = self.multiworld multiworld = self.multiworld
prog_locations = {location for location in multiworld.get_filled_locations() if location.item.advancement} prog_locations = {location for location in multiworld.get_filled_locations() if location.item.advancement}
state_cache = [None] state_cache: List[Optional[CollectionState]] = [None]
collection_spheres: List[Set[Location]] = [] collection_spheres: List[Set[Location]] = []
state = CollectionState(multiworld) state = CollectionState(multiworld)
sphere_candidates = set(prog_locations) sphere_candidates = set(prog_locations)
@ -1183,17 +1198,17 @@ class Spoiler():
for item in removed_precollected: for item in removed_precollected:
multiworld.push_precollected(item) multiworld.push_precollected(item)
def create_paths(self, state: CollectionState, collection_spheres: List[Set[Location]]): def create_paths(self, state: CollectionState, collection_spheres: List[Set[Location]]) -> None:
from itertools import zip_longest from itertools import zip_longest
multiworld = self.multiworld multiworld = self.multiworld
def flist_to_iter(node): def flist_to_iter(path_value: Optional[PathValue]) -> Iterator[str]:
while node: while path_value:
value, node = node region_or_entrance, path_value = path_value
yield value yield region_or_entrance
def get_path(state, region): def get_path(state: CollectionState, region: Region) -> List[Union[Tuple[str, str], Tuple[str, None]]]:
reversed_path_as_flist = state.path.get(region, (region, None)) reversed_path_as_flist: PathValue = state.path.get(region, (str(region), None))
string_path_flat = reversed(list(map(str, flist_to_iter(reversed_path_as_flist)))) string_path_flat = reversed(list(map(str, flist_to_iter(reversed_path_as_flist))))
# Now we combine the flat string list into (region, exit) pairs # Now we combine the flat string list into (region, exit) pairs
pathsiter = iter(string_path_flat) pathsiter = iter(string_path_flat)
@ -1219,14 +1234,11 @@ class Spoiler():
self.paths[str(multiworld.get_region('Inverted Big Bomb Shop', player))] = \ self.paths[str(multiworld.get_region('Inverted Big Bomb Shop', player))] = \
get_path(state, multiworld.get_region('Inverted Big Bomb Shop', player)) get_path(state, multiworld.get_region('Inverted Big Bomb Shop', player))
def to_file(self, filename: str): def to_file(self, filename: str) -> None:
def write_option(option_key: str, option_obj: type(Options.Option)): def write_option(option_key: str, option_obj: Options.AssembleOptions) -> None:
res = getattr(self.multiworld, option_key)[player] res = getattr(self.multiworld, option_key)[player]
display_name = getattr(option_obj, "display_name", option_key) display_name = getattr(option_obj, "display_name", option_key)
try: outfile.write(f"{display_name + ':':33}{res.current_option_name}\n")
outfile.write(f'{display_name + ":":33}{res.current_option_name}\n')
except:
raise Exception
with open(filename, 'w', encoding="utf-8-sig") as outfile: with open(filename, 'w', encoding="utf-8-sig") as outfile:
outfile.write( outfile.write(
@ -1259,15 +1271,15 @@ class Spoiler():
AutoWorld.call_all(self.multiworld, "write_spoiler", outfile) AutoWorld.call_all(self.multiworld, "write_spoiler", outfile)
locations = [(str(location), str(location.item) if location.item is not None else "Nothing") locations = [(str(location), str(location.item) if location.item is not None else "Nothing")
for location in self.multiworld.get_locations() if location.show_in_spoiler] for location in self.multiworld.get_locations() if location.show_in_spoiler]
outfile.write('\n\nLocations:\n\n') outfile.write('\n\nLocations:\n\n')
outfile.write('\n'.join( outfile.write('\n'.join(
['%s: %s' % (location, item) for location, item in locations])) ['%s: %s' % (location, item) for location, item in locations]))
outfile.write('\n\nPlaythrough:\n\n') outfile.write('\n\nPlaythrough:\n\n')
outfile.write('\n'.join(['%s: {\n%s\n}' % (sphere_nr, '\n'.join( outfile.write('\n'.join(['%s: {\n%s\n}' % (sphere_nr, '\n'.join(
[' %s: %s' % (location, item) for (location, item) in sphere.items()] if sphere_nr != '0' else [ [f" {location}: {item}" for (location, item) in sphere.items()] if isinstance(sphere, dict) else
f' {item}' for item in sphere])) for (sphere_nr, sphere) in self.playthrough.items()])) [f" {item}" for item in sphere])) for (sphere_nr, sphere) in self.playthrough.items()]))
if self.unreachables: if self.unreachables:
outfile.write('\n\nUnreachable Items:\n\n') outfile.write('\n\nUnreachable Items:\n\n')
outfile.write( outfile.write(
@ -1328,23 +1340,21 @@ class PlandoOptions(IntFlag):
@classmethod @classmethod
def _handle_part(cls, part: str, base: PlandoOptions) -> PlandoOptions: def _handle_part(cls, part: str, base: PlandoOptions) -> PlandoOptions:
try: try:
part = cls[part] return base | cls[part]
except Exception as e: except Exception as e:
raise KeyError(f"{part} is not a recognized name for a plando module. " raise KeyError(f"{part} is not a recognized name for a plando module. "
f"Known options: {', '.join(flag.name for flag in cls)}") from e f"Known options: {', '.join(str(flag.name) for flag in cls)}") from e
else:
return base | part
def __str__(self) -> str: def __str__(self) -> str:
if self.value: if self.value:
return ", ".join(flag.name for flag in PlandoOptions if self.value & flag.value) return ", ".join(str(flag.name) for flag in PlandoOptions if self.value & flag.value)
return "None" return "None"
seeddigits = 20 seeddigits = 20
def get_seed(seed=None) -> int: def get_seed(seed: Optional[int] = None) -> int:
if seed is None: if seed is None:
random.seed(None) random.seed(None)
return random.randint(0, pow(10, seeddigits) - 1) return random.randint(0, pow(10, seeddigits) - 1)