The Witness: Small code refactor (cast_not_none) (#3798)

* cast not none

* ruff

* Missed a spot
This commit is contained in:
NewSoupVi 2024-10-02 00:02:17 +02:00 committed by GitHub
parent f06f95d03d
commit 0ec9039ca6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 16 additions and 10 deletions

View File

@ -14,7 +14,7 @@ from .data import static_items as static_witness_items
from .data import static_locations as static_witness_locations from .data import static_locations as static_witness_locations
from .data import static_logic as static_witness_logic from .data import static_logic as static_witness_logic
from .data.item_definition_classes import DoorItemDefinition, ItemData from .data.item_definition_classes import DoorItemDefinition, ItemData
from .data.utils import get_audio_logs from .data.utils import cast_not_none, get_audio_logs
from .hints import CompactHintData, create_all_hints, make_compact_hint_data, make_laser_hints from .hints import CompactHintData, create_all_hints, make_compact_hint_data, make_laser_hints
from .locations import WitnessPlayerLocations from .locations import WitnessPlayerLocations
from .options import TheWitnessOptions, witness_option_groups from .options import TheWitnessOptions, witness_option_groups
@ -55,7 +55,7 @@ class WitnessWorld(World):
item_name_to_id = { item_name_to_id = {
# ITEM_DATA doesn't have any event items in it # ITEM_DATA doesn't have any event items in it
name: cast(int, data.ap_code) for name, data in static_witness_items.ITEM_DATA.items() name: cast_not_none(data.ap_code) for name, data in static_witness_items.ITEM_DATA.items()
} }
location_name_to_id = static_witness_locations.ALL_LOCATIONS_TO_ID location_name_to_id = static_witness_locations.ALL_LOCATIONS_TO_ID
item_name_groups = static_witness_items.ITEM_GROUPS item_name_groups = static_witness_items.ITEM_GROUPS
@ -336,7 +336,7 @@ class WitnessWorld(World):
for item_name, hint in laser_hints.items(): for item_name, hint in laser_hints.items():
item_def = cast(DoorItemDefinition, static_witness_logic.ALL_ITEMS[item_name]) item_def = cast(DoorItemDefinition, static_witness_logic.ALL_ITEMS[item_name])
self.laser_ids_to_hints[int(item_def.panel_id_hexes[0], 16)] = make_compact_hint_data(hint, self.player) self.laser_ids_to_hints[int(item_def.panel_id_hexes[0], 16)] = make_compact_hint_data(hint, self.player)
already_hinted_locations.add(cast(Location, hint.location)) already_hinted_locations.add(cast_not_none(hint.location))
# Audio Log Hints # Audio Log Hints

View File

@ -1,7 +1,7 @@
from math import floor from math import floor
from pkgutil import get_data from pkgutil import get_data
from random import Random from random import Random
from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Set, Tuple, TypeVar from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, TypeVar
T = TypeVar("T") T = TypeVar("T")
@ -13,6 +13,11 @@ T = TypeVar("T")
WitnessRule = FrozenSet[FrozenSet[str]] WitnessRule = FrozenSet[FrozenSet[str]]
def cast_not_none(value: Optional[T]) -> T:
assert value is not None
return value
def weighted_sample(world_random: Random, population: List[T], weights: List[float], k: int) -> List[T]: def weighted_sample(world_random: Random, population: List[T], weights: List[float], k: int) -> List[T]:
positions = range(len(population)) positions = range(len(population))
indices: List[int] = [] indices: List[int] = []

View File

@ -15,7 +15,7 @@ from .data.item_definition_classes import (
ProgressiveItemDefinition, ProgressiveItemDefinition,
WeightedItemDefinition, WeightedItemDefinition,
) )
from .data.utils import build_weighted_int_list from .data.utils import build_weighted_int_list, cast_not_none
from .locations import WitnessPlayerLocations from .locations import WitnessPlayerLocations
from .player_logic import WitnessPlayerLogic from .player_logic import WitnessPlayerLogic
@ -200,7 +200,7 @@ class WitnessPlayerItems:
""" """
return [ return [
# data.ap_code is guaranteed for a symbol definition # data.ap_code is guaranteed for a symbol definition
cast(int, data.ap_code) for name, data in static_witness_items.ITEM_DATA.items() cast_not_none(data.ap_code) for name, data in static_witness_items.ITEM_DATA.items()
if name not in self.item_data.keys() and data.definition.category is ItemCategory.SYMBOL if name not in self.item_data.keys() and data.definition.category is ItemCategory.SYMBOL
] ]
@ -211,8 +211,8 @@ class WitnessPlayerItems:
if isinstance(item.definition, ProgressiveItemDefinition): if isinstance(item.definition, ProgressiveItemDefinition):
# Note: we need to reference the static table here rather than the player-specific one because the child # Note: we need to reference the static table here rather than the player-specific one because the child
# items were removed from the pool when we pruned out all progression items not in the options. # items were removed from the pool when we pruned out all progression items not in the options.
output[cast(int, item.ap_code)] = [cast(int, static_witness_items.ITEM_DATA[child_item].ap_code) output[cast_not_none(item.ap_code)] = [cast_not_none(static_witness_items.ITEM_DATA[child_item].ap_code)
for child_item in item.definition.child_item_names] for child_item in item.definition.child_item_names]
return output return output

View File

@ -1,4 +1,4 @@
from typing import Any, ClassVar, Dict, Iterable, List, Mapping, Union, cast from typing import Any, ClassVar, Dict, Iterable, List, Mapping, Union
from BaseClasses import CollectionState, Entrance, Item, Location, Region from BaseClasses import CollectionState, Entrance, Item, Location, Region
@ -7,6 +7,7 @@ from test.general import gen_steps, setup_multiworld
from test.multiworld.test_multiworlds import MultiworldTestBase from test.multiworld.test_multiworlds import MultiworldTestBase
from .. import WitnessWorld from .. import WitnessWorld
from ..data.utils import cast_not_none
class WitnessTestBase(WorldTestBase): class WitnessTestBase(WorldTestBase):
@ -32,7 +33,7 @@ class WitnessTestBase(WorldTestBase):
event_items = [item for item in self.multiworld.get_items() if item.name == item_name] event_items = [item for item in self.multiworld.get_items() if item.name == item_name]
self.assertTrue(event_items, f"Event item {item_name} does not exist.") self.assertTrue(event_items, f"Event item {item_name} does not exist.")
event_locations = [cast(Location, event_item.location) for event_item in event_items] event_locations = [cast_not_none(event_item.location) for event_item in event_items]
# Checking for an access dependency on an event item requires a bit of extra work, # Checking for an access dependency on an event item requires a bit of extra work,
# as state.remove forces a sweep, which will pick up the event item again right after we tried to remove it. # as state.remove forces a sweep, which will pick up the event item again right after we tried to remove it.