Stardew Valley: Improve generation performance by around 11% by moving calculating from rule evaluation to collect (#4231)

This commit is contained in:
Jouramie 2025-01-18 20:36:01 -05:00 committed by GitHub
parent 1c9409cac9
commit 992f192529
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 42 additions and 56 deletions

View File

@ -1,6 +1,6 @@
import logging import logging
from random import Random from random import Random
from typing import Dict, Any, Iterable, Optional, Union, List, TextIO from typing import Dict, Any, Iterable, Optional, List, TextIO
from BaseClasses import Region, Entrance, Location, Item, Tutorial, ItemClassification, MultiWorld, CollectionState from BaseClasses import Region, Entrance, Location, Item, Tutorial, ItemClassification, MultiWorld, CollectionState
from Options import PerGameCommonOptions from Options import PerGameCommonOptions
@ -88,7 +88,6 @@ class StardewValleyWorld(World):
randomized_entrances: Dict[str, str] randomized_entrances: Dict[str, str]
total_progression_items: int total_progression_items: int
excluded_from_total_progression_items: List[str] = [Event.received_walnuts]
def __init__(self, multiworld: MultiWorld, player: int): def __init__(self, multiworld: MultiWorld, player: int):
super().__init__(multiworld, player) super().__init__(multiworld, player)
@ -176,7 +175,7 @@ class StardewValleyWorld(World):
if self.options.season_randomization == SeasonRandomization.option_disabled: if self.options.season_randomization == SeasonRandomization.option_disabled:
for season in season_pool: for season in season_pool:
self.multiworld.push_precollected(self.create_starting_item(season)) self.multiworld.push_precollected(self.create_item(season))
return return
if [item for item in self.multiworld.precollected_items[self.player] if [item for item in self.multiworld.precollected_items[self.player]
@ -186,12 +185,12 @@ class StardewValleyWorld(World):
if self.options.season_randomization == SeasonRandomization.option_randomized_not_winter: if self.options.season_randomization == SeasonRandomization.option_randomized_not_winter:
season_pool = [season for season in season_pool if season.name != "Winter"] season_pool = [season for season in season_pool if season.name != "Winter"]
starting_season = self.create_starting_item(self.random.choice(season_pool)) starting_season = self.create_item(self.random.choice(season_pool))
self.multiworld.push_precollected(starting_season) self.multiworld.push_precollected(starting_season)
def precollect_farm_type_items(self): def precollect_farm_type_items(self):
if self.options.farm_type == FarmType.option_meadowlands and self.options.building_progression & BuildingProgression.option_progressive: if self.options.farm_type == FarmType.option_meadowlands and self.options.building_progression & BuildingProgression.option_progressive:
self.multiworld.push_precollected(self.create_starting_item("Progressive Coop")) self.multiworld.push_precollected(self.create_item("Progressive Coop"))
def setup_logic_events(self): def setup_logic_events(self):
def register_event(name: str, region: str, rule: StardewRule): def register_event(name: str, region: str, rule: StardewRule):
@ -271,7 +270,7 @@ class StardewValleyWorld(World):
def get_all_location_names(self) -> List[str]: def get_all_location_names(self) -> List[str]:
return list(location.name for location in self.multiworld.get_locations(self.player)) return list(location.name for location in self.multiworld.get_locations(self.player))
def create_item(self, item: Union[str, ItemData], override_classification: ItemClassification = None) -> StardewItem: def create_item(self, item: str | ItemData, override_classification: ItemClassification = None) -> StardewItem:
if isinstance(item, str): if isinstance(item, str):
item = item_table[item] item = item_table[item]
@ -280,12 +279,6 @@ class StardewValleyWorld(World):
return StardewItem(item.name, override_classification, item.code, self.player) return StardewItem(item.name, override_classification, item.code, self.player)
def create_starting_item(self, item: Union[str, ItemData]) -> StardewItem:
if isinstance(item, str):
item = item_table[item]
return StardewItem(item.name, item.classification, item.code, self.player)
def create_event_location(self, location_data: LocationData, rule: StardewRule = None, item: Optional[str] = None): def create_event_location(self, location_data: LocationData, rule: StardewRule = None, item: Optional[str] = None):
if rule is None: if rule is None:
rule = True_() rule = True_()
@ -393,9 +386,19 @@ class StardewValleyWorld(World):
if not change: if not change:
return False return False
player_state = state.prog_items[self.player]
received_progression_count = player_state[Event.received_progression_item]
received_progression_count += 1
if self.total_progression_items:
# Total progression items is not set until all items are created, but collect will be called during the item creation when an item is precollected.
# We can't update the percentage if we don't know the total progression items, can't divide by 0.
player_state[Event.received_progression_percent] = received_progression_count * 100 // self.total_progression_items
player_state[Event.received_progression_item] = received_progression_count
walnut_amount = self.get_walnut_amount(item.name) walnut_amount = self.get_walnut_amount(item.name)
if walnut_amount: if walnut_amount:
state.prog_items[self.player][Event.received_walnuts] += walnut_amount player_state[Event.received_walnuts] += walnut_amount
return True return True
@ -404,9 +407,18 @@ class StardewValleyWorld(World):
if not change: if not change:
return False return False
player_state = state.prog_items[self.player]
received_progression_count = player_state[Event.received_progression_item]
received_progression_count -= 1
if self.total_progression_items:
# We can't update the percentage if we don't know the total progression items, can't divide by 0.
player_state[Event.received_progression_percent] = received_progression_count * 100 // self.total_progression_items
player_state[Event.received_progression_item] = received_progression_count
walnut_amount = self.get_walnut_amount(item.name) walnut_amount = self.get_walnut_amount(item.name)
if walnut_amount: if walnut_amount:
state.prog_items[self.player][Event.received_walnuts] -= walnut_amount player_state[Event.received_walnuts] -= walnut_amount
return True return True

View File

@ -4,6 +4,7 @@ from typing import Iterable, Union, List, Tuple, Hashable, TYPE_CHECKING
from BaseClasses import CollectionState from BaseClasses import CollectionState
from .base import BaseStardewRule, CombinableStardewRule from .base import BaseStardewRule, CombinableStardewRule
from .protocol import StardewRule from .protocol import StardewRule
from ..strings.ap_names.event_names import Event
if TYPE_CHECKING: if TYPE_CHECKING:
from .. import StardewValleyWorld from .. import StardewValleyWorld
@ -87,45 +88,13 @@ class Reach(BaseStardewRule):
return f"Reach {self.resolution_hint} {self.spot}" return f"Reach {self.resolution_hint} {self.spot}"
@dataclass(frozen=True) class HasProgressionPercent(Received):
class HasProgressionPercent(CombinableStardewRule): def __init__(self, player: int, percent: int):
player: int super().__init__(Event.received_progression_percent, player, percent, event=True)
percent: int
def __post_init__(self): def __post_init__(self):
assert self.percent > 0, "HasProgressionPercent rule must be above 0%" assert self.count > 0, "HasProgressionPercent rule must be above 0%"
assert self.percent <= 100, "HasProgressionPercent rule can't require more than 100% of items" assert self.count <= 100, "HasProgressionPercent rule can't require more than 100% of items"
@property
def combination_key(self) -> Hashable:
return HasProgressionPercent.__name__
@property
def value(self):
return self.percent
def __call__(self, state: CollectionState) -> bool:
stardew_world: "StardewValleyWorld" = state.multiworld.worlds[self.player]
total_count = stardew_world.total_progression_items
needed_count = (total_count * self.percent) // 100
player_state = state.prog_items[self.player]
if needed_count <= len(player_state) - len(stardew_world.excluded_from_total_progression_items):
return True
total_count = 0
for item, item_count in player_state.items():
if item in stardew_world.excluded_from_total_progression_items:
continue
total_count += item_count
if total_count >= needed_count:
return True
return False
def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]:
return self, self(state)
def __repr__(self): def __repr__(self):
return f"Received {self.percent}% progression items" return f"Received {self.count}% progression items"

View File

@ -10,3 +10,5 @@ class Event:
victory = event("Victory") victory = event("Victory")
received_walnuts = event("Received Walnuts") received_walnuts = event("Received Walnuts")
received_progression_item = event("Received Progression Item")
received_progression_percent = event("Received Progression Percent")

View File

@ -69,14 +69,17 @@ class TestShipsanityEverything(SVTestBase):
def test_all_shipsanity_locations_require_shipping_bin(self): def test_all_shipsanity_locations_require_shipping_bin(self):
bin_name = "Shipping Bin" bin_name = "Shipping Bin"
self.collect_all_except(bin_name) self.collect_all_except(bin_name)
shipsanity_locations = [location for location in self.get_real_locations() if shipsanity_locations = [location
LocationTags.SHIPSANITY in location_table[location.name].tags] for location in self.get_real_locations()
if LocationTags.SHIPSANITY in location_table[location.name].tags]
bin_item = self.create_item(bin_name) bin_item = self.create_item(bin_name)
for location in shipsanity_locations: for location in shipsanity_locations:
with self.subTest(location.name): with self.subTest(location.name):
self.remove(bin_item)
self.assertFalse(self.world.logic.region.can_reach_location(location.name)(self.multiworld.state)) self.assertFalse(self.world.logic.region.can_reach_location(location.name)(self.multiworld.state))
self.multiworld.state.collect(bin_item)
self.collect(bin_item)
shipsanity_rule = self.world.logic.region.can_reach_location(location.name) shipsanity_rule = self.world.logic.region.can_reach_location(location.name)
self.assert_rule_true(shipsanity_rule, self.multiworld.state) self.assert_rule_true(shipsanity_rule, self.multiworld.state)
self.remove(bin_item) self.remove(bin_item)