# Archipelago/worlds/stardew_valley/stardew_rule/base.py
#
# 482 lines
# 18 KiB
# Python

from __future__ import annotations
from abc import ABC, abstractmethod
from collections import deque, Counter
from dataclasses import dataclass, field
from functools import cached_property
from itertools import chain
from threading import Lock
from typing import Iterable, Dict, List, Union, Sized, Hashable, Callable, Tuple, Set, Optional, cast
from BaseClasses import CollectionState
from .literal import true_, false_, LiteralStardewRule
from .protocol import StardewRule
# Marker appended by Has.__str__ / Has.__repr__ when no rule is registered for the requested item.
MISSING_ITEM = "THIS ITEM IS MISSING"
class BaseStardewRule(StardewRule, ABC):
    """
    Default ``|`` and ``&`` operators for Stardew rules.

    When the other operand is a literal (``true_``/``false_``) or already an aggregate of the requested kind, that operand
    builds the result. Otherwise, a new ``Or``/``And`` wrapping both rules is created.
    """

    def __or__(self, other) -> StardewRule:
        other_builds_result = other is true_ or other is false_ or type(other) is Or
        if not other_builds_result:
            return Or(self, other)
        return other | self

    def __and__(self, other) -> StardewRule:
        other_builds_result = other is true_ or other is false_ or type(other) is And
        if not other_builds_result:
            return And(self, other)
        return other & self
class CombinableStardewRule(BaseStardewRule, ABC):
    """
    Rule that can be merged with another rule sharing the same ``combination_key``.

    Merging is delegated to ``And.combine`` / ``Or.combine``, which pick one of the two rules based on ``value``.
    """

    @property
    @abstractmethod
    def combination_key(self) -> Hashable:
        raise NotImplementedError

    @property
    @abstractmethod
    def value(self):
        raise NotImplementedError

    def is_same_rule(self, other: CombinableStardewRule):
        """Tell whether ``other`` has the same combination key as this rule."""
        return self.combination_key == other.combination_key

    def add_into(self, rules: Dict[Hashable, CombinableStardewRule], reducer: Callable[[CombinableStardewRule, CombinableStardewRule], CombinableStardewRule]) \
            -> Dict[Hashable, CombinableStardewRule]:
        """
        Return a copy of ``rules`` that includes this rule under its combination key.

        If the key is already present, the stored rule is replaced by ``reducer(self, stored_rule)``. The ``rules`` mapping
        passed in is never modified.
        """
        key = self.combination_key
        result = dict(rules)
        result[key] = reducer(self, result[key]) if key in result else self
        return result

    def __and__(self, other):
        mergeable = isinstance(other, CombinableStardewRule) and self.is_same_rule(other)
        if not mergeable:
            return super().__and__(other)
        return And.combine(self, other)

    def __or__(self, other):
        mergeable = isinstance(other, CombinableStardewRule) and self.is_same_rule(other)
        if not mergeable:
            return super().__or__(other)
        return Or.combine(self, other)
class _SimplificationState:
    """
    Progress of the simplification of an aggregating rule's non-combinable sub-rules.

    ``rules_to_simplify`` is ``None`` until simplification starts. After that it holds the rules still waiting to be
    simplified. ``simplified_rules`` holds the results of the rules already simplified. ``locked`` flags a state that is
    currently in use, so that a re-entrant evaluation can work on a copy instead.
    """
    original_simplifiable_rules: Tuple[StardewRule, ...]
    rules_to_simplify: Optional[deque[StardewRule]]
    simplified_rules: Set[StardewRule]
    # The "lock" is a plain boolean flag, see acquire/release.
    locked: bool

    def __init__(self, simplifiable_rules: Tuple[StardewRule, ...], rules_to_simplify: Optional[deque[StardewRule]] = None,
                 simplified_rules: Optional[Set[StardewRule]] = None):
        if simplified_rules is None:
            simplified_rules = set()

        self.original_simplifiable_rules = simplifiable_rules
        self.rules_to_simplify = rules_to_simplify
        self.simplified_rules = simplified_rules
        self.locked = False

    @property
    def is_simplified(self):
        """True once simplification has started and no rule is left waiting to be simplified."""
        return self.rules_to_simplify is not None and not self.rules_to_simplify

    def short_circuit(self, complement: LiteralStardewRule):
        """Mark the state as fully simplified down to ``complement`` alone."""
        self.rules_to_simplify = deque()
        self.simplified_rules = {complement}

    def try_popleft(self):
        """Drop the first rule waiting to be simplified. Does nothing if the queue is empty."""
        try:
            self.rules_to_simplify.popleft()
        except IndexError:
            pass

    def acquire_copy(self):
        """
        Return an independent copy of this state that is already acquired.

        The queue and the set are copied, so the copy can progress without touching this state. A state whose
        simplification has not started yet (``rules_to_simplify is None``) is copied as "not started", instead of failing
        on ``None.copy()``.
        """
        rules_to_simplify = None if self.rules_to_simplify is None else self.rules_to_simplify.copy()
        state = _SimplificationState(self.original_simplifiable_rules, rules_to_simplify, self.simplified_rules.copy())
        state.acquire()
        return state

    def merge(self, other: _SimplificationState):
        """Return a fresh, not-started state over the simplifiable rules of both states. Progress of either is dropped."""
        return _SimplificationState(self.original_simplifiable_rules + other.original_simplifiable_rules)

    def add(self, rule: StardewRule):
        """Return a fresh, not-started state over this state's simplifiable rules plus ``rule``."""
        return _SimplificationState(self.original_simplifiable_rules + (rule,))

    def acquire(self):
        """
        Try to mark this state as in use. Return True on success, False if it was already in use.

        This just sets a boolean to True and is absolutely not thread safe. It only works because AP is single-threaded.
        """
        if self.locked is True:
            return False

        self.locked = True
        return True

    def release(self):
        """Mark this state as no longer in use. The state must currently be acquired."""
        assert self.locked
        self.locked = False
class AggregatingStardewRule(BaseStardewRule, ABC):
    """
    Logic for both "And" and "Or" rules.

    Sub-rules are split into two groups:
    - Combinable rules are merged per ``combination_key`` with ``combine`` and evaluated before the others.
    - Every other rule is simplified lazily, and statefully, while being evaluated.

    Subclasses provide:
    - ``identity``: the literal that never affects the result.
    - ``complement``: the literal that decides the result on its own.
    - ``symbol``: used by ``__str__``/``__repr__``.
    - ``combine``.
    """
    identity: LiteralStardewRule
    complement: LiteralStardewRule
    symbol: str

    combinable_rules: Dict[Hashable, CombinableStardewRule]
    simplification_state: _SimplificationState
    # Sub-rule that produced the complement value during the last evaluation, re-checked first on the next one.
    _last_short_circuiting_rule: Optional[StardewRule] = None

    def __init__(self, *rules: StardewRule, _combinable_rules=None, _simplification_state=None):
        """
        :param rules: sub-rules to aggregate. Only used when ``_combinable_rules`` is not given.
        :param _combinable_rules: internal. Pre-split combinable rules, as built by the ``|``/``&`` operators.
        :param _simplification_state: internal. Simplification state to use together with ``_combinable_rules``.
        """
        if _combinable_rules is None:
            assert rules, "Can't create an aggregating condition without rules"
            rules, _combinable_rules = self.split_rules(rules)
            _simplification_state = _SimplificationState(rules)

        self.combinable_rules = _combinable_rules
        self.simplification_state = _simplification_state

    @property
    def original_rules(self):
        """All sub-rules before simplification: combinable rules first, then the simplifiable ones."""
        return RepeatableChain(self.combinable_rules.values(), self.simplification_state.original_simplifiable_rules)

    @property
    def current_rules(self):
        """Sub-rules as of the current simplification progress. These are the original rules if simplification has not started."""
        if self.simplification_state.rules_to_simplify is None:
            return self.original_rules

        return RepeatableChain(self.combinable_rules.values(), self.simplification_state.simplified_rules, self.simplification_state.rules_to_simplify)

    @classmethod
    def split_rules(cls, rules: Iterable[StardewRule]) -> Tuple[Tuple[StardewRule, ...], Dict[Hashable, CombinableStardewRule]]:
        """
        Split ``rules`` into a tuple of simplifiable rules and a dict of combinable rules merged by combination key.

        Rules of exactly the same type as ``cls`` are flattened: their own simplifiable and combinable rules are pulled in.
        """
        other_rules = []
        reduced_rules = {}
        for rule in rules:
            if isinstance(rule, CombinableStardewRule):
                key = rule.combination_key
                if key not in reduced_rules:
                    reduced_rules[key] = rule
                    continue

                reduced_rules[key] = cls.combine(reduced_rules[key], rule)
                continue

            if type(rule) is cls:
                other_rules.extend(rule.simplification_state.original_simplifiable_rules)  # noqa
                reduced_rules = cls.merge(reduced_rules, rule.combinable_rules)  # noqa
                continue

            other_rules.append(rule)

        return tuple(other_rules), reduced_rules

    @classmethod
    def merge(cls, left: Dict[Hashable, CombinableStardewRule], right: Dict[Hashable, CombinableStardewRule]) -> Dict[Hashable, CombinableStardewRule]:
        """Return a new dict with the rules of both sides. Rules sharing a key are reduced with ``combine``."""
        reduced_rules = dict(left)
        for key, rule in right.items():
            if key not in reduced_rules:
                reduced_rules[key] = rule
                continue

            reduced_rules[key] = cls.combine(reduced_rules[key], rule)

        return reduced_rules

    @staticmethod
    @abstractmethod
    def combine(left: CombinableStardewRule, right: CombinableStardewRule) -> CombinableStardewRule:
        raise NotImplementedError

    def short_circuit_simplification(self):
        """Permanently reduce this rule to its complement, and return ``(complement, complement.value)``."""
        self.simplification_state.short_circuit(self.complement)
        self.combinable_rules = {}
        return self.complement, self.complement.value

    def short_circuit_evaluation(self, rule):
        """Remember ``rule`` as the one that short-circuited, and return ``(self, complement.value)``."""
        self._last_short_circuiting_rule = rule
        return self, self.complement.value

    def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]:
        """
        The global idea here is the same as short-circuiting operators, applied to evaluation and rule simplification.
        """
        # Directly checking last rule that short-circuited, in case state has not changed.
        # Compared to None explicitly, so a rule object's own truthiness can never skip this fast path.
        last_rule = self._last_short_circuiting_rule
        if last_rule is not None:
            if last_rule(state) is self.complement.value:
                return self.short_circuit_evaluation(last_rule)
            self._last_short_circuiting_rule = None

        # Combinable rules are considered already simplified, so we evaluate them right away to go faster.
        for rule in self.combinable_rules.values():
            if rule(state) is self.complement.value:
                return self.short_circuit_evaluation(rule)

        if self.simplification_state.is_simplified:
            # The rule is fully simplified, so now we can only evaluate.
            for rule in self.simplification_state.simplified_rules:
                if rule(state) is self.complement.value:
                    return self.short_circuit_evaluation(rule)
            return self, self.identity.value

        return self.evaluate_while_simplifying_stateful(state)

    def evaluate_while_simplifying_stateful(self, state):
        """Evaluate while continuing the stateful simplification. The state is acquired for the duration of the call."""
        local_state = self.simplification_state
        try:
            # Creating a new copy, so we don't modify the rules while we're already evaluating it. This can happen if a rule is
            # used for an entrance and a location. When evaluating a given rule that requires access to a region, the region
            # cache can get an update. If it does, we could enter this rule again. Since the simplification is stateful, the
            # set of simplified rules can be modified while it's being iterated on, and cause a crash.
            #
            # After investigation, for millions of calls to this method, a copy was acquired 425 times.
            # Merging the simplification state in the parent call was deemed useless.
            if not local_state.acquire():
                local_state = local_state.acquire_copy()
                self.simplification_state = local_state

            # Evaluate what has already been simplified first. It is faster than simplifying "new" rules. Also, if we reach
            # this point and there are already simplified rules, we assume one of them short-circuited and might again, so
            # we can leave early.
            for rule in local_state.simplified_rules:
                if rule(state) is self.complement.value:
                    return self.short_circuit_evaluation(rule)

            # If the queue is None, simplification has not started yet. Otherwise, continue where we left off.
            if local_state.rules_to_simplify is None:
                rules_to_simplify = frozenset(local_state.original_simplifiable_rules)
                if self.complement in rules_to_simplify:
                    return self.short_circuit_simplification()
                local_state.rules_to_simplify = deque(rules_to_simplify)

            # Continue simplification where we left off.
            while local_state.rules_to_simplify:
                result = self.evaluate_rule_while_simplifying_stateful(local_state, state)
                local_state.try_popleft()
                if result is not None:
                    return result

            # The whole rule has been simplified and evaluated without short-circuit.
            return self, self.identity.value
        finally:
            local_state.release()

    def evaluate_rule_while_simplifying_stateful(self, local_state, state):
        """
        Simplify and evaluate the first rule waiting in ``local_state``.

        Returns a ``(rule, value)`` pair when the whole aggregate short-circuits, otherwise None. Popping the evaluated rule
        from the queue is left to the caller.
        """
        simplified, value = local_state.rules_to_simplify[0].evaluate_while_simplifying(state)

        # Identity is removed from the resulting simplification since it does not affect the result.
        if simplified is self.identity:
            return None

        # If we find a complement here, we know the rule will always short-circuit, whatever the state.
        if simplified is self.complement:
            return self.short_circuit_simplification()

        # Keep the simplified rule to be reevaluated later.
        local_state.simplified_rules.add(simplified)

        # Now we use the value to short-circuit if it is the complement.
        if value is self.complement.value:
            return self.short_circuit_evaluation(simplified)

        return None

    def __str__(self):
        return f"({self.symbol.join(str(rule) for rule in self.original_rules)})"

    def __repr__(self):
        return f"({self.symbol.join(repr(rule) for rule in self.original_rules)})"

    def __eq__(self, other):
        return (isinstance(other, type(self)) and self.combinable_rules == other.combinable_rules and
                self.simplification_state.original_simplifiable_rules == other.simplification_state.original_simplifiable_rules)

    def __hash__(self):
        # NOTE(review): rules with more than 5 sub-rules hash by identity while __eq__ stays structural. Two equal large
        # rules can therefore hash differently, which breaks the usual hash/eq contract. This is presumably a deliberate
        # trade-off that avoids hashing big trees; confirm before relying on set/dict deduplication of large rules.
        if len(self.combinable_rules) + len(self.simplification_state.original_simplifiable_rules) > 5:
            return id(self)

        return hash((*self.combinable_rules.values(), self.simplification_state.original_simplifiable_rules))
class Or(AggregatingStardewRule):
    """Disjunction of rules: true as soon as one sub-rule is true, false only when every sub-rule is false."""
    identity = false_
    complement = true_
    symbol = " | "

    def __call__(self, state: CollectionState) -> bool:
        _, value = self.evaluate_while_simplifying(state)
        return value

    def __or__(self, other):
        # A literal either decides the result or vanishes from it, so let it build the result.
        if other is true_ or other is false_:
            return other | self

        if isinstance(other, CombinableStardewRule):
            # Fold the combinable rule into a copy of our combinable rules. The simplification state object is reused.
            combined = other.add_into(self.combinable_rules, self.combine)
            return Or(_combinable_rules=combined, _simplification_state=self.simplification_state)

        if type(other) is not Or:
            extended_state = self.simplification_state.add(other)
            return Or(_combinable_rules=self.combinable_rules, _simplification_state=extended_state)

        other = cast(Or, other)
        merged_rules = self.merge(self.combinable_rules, other.combinable_rules)
        merged_state = self.simplification_state.merge(other.simplification_state)
        return Or(_combinable_rules=merged_rules, _simplification_state=merged_state)

    @staticmethod
    def combine(left: CombinableStardewRule, right: CombinableStardewRule) -> CombinableStardewRule:
        """Keep whichever of two same-key rules has the lowest ``value``. On a tie, ``left`` is kept."""
        return min((left, right), key=lambda rule: rule.value)
class And(AggregatingStardewRule):
    """Conjunction of rules: false as soon as one sub-rule is false, true only when every sub-rule is true."""
    identity = true_
    complement = false_
    symbol = " & "

    def __call__(self, state: CollectionState) -> bool:
        _, value = self.evaluate_while_simplifying(state)
        return value

    def __and__(self, other):
        # A literal either decides the result or vanishes from it, so let it build the result.
        if other is true_ or other is false_:
            return other & self

        if isinstance(other, CombinableStardewRule):
            # Fold the combinable rule into a copy of our combinable rules. The simplification state object is reused.
            combined = other.add_into(self.combinable_rules, self.combine)
            return And(_combinable_rules=combined, _simplification_state=self.simplification_state)

        if type(other) is not And:
            extended_state = self.simplification_state.add(other)
            return And(_combinable_rules=self.combinable_rules, _simplification_state=extended_state)

        other = cast(And, other)
        merged_rules = self.merge(self.combinable_rules, other.combinable_rules)
        merged_state = self.simplification_state.merge(other.simplification_state)
        return And(_combinable_rules=merged_rules, _simplification_state=merged_state)

    @staticmethod
    def combine(left: CombinableStardewRule, right: CombinableStardewRule) -> CombinableStardewRule:
        """Keep whichever of two same-key rules has the highest ``value``. On a tie, ``left`` is kept."""
        return max((left, right), key=lambda rule: rule.value)
class Count(BaseStardewRule):
    """
    Rule that is true when at least ``count`` of the given ``rules`` are true.

    Each entry of ``rules`` counts once, so duplicated rules count once per occurrence. There are two strategies:
    - When the list holds enough duplicates, each distinct rule is evaluated once and weighted by its number of occurrences
      (``evaluate_with_shortcircuit``).
    - Otherwise, every entry is evaluated in order (``evaluate_without_shortcircuit``).

    Both strategies stop as soon as the outcome is known.
    """
    count: int
    rules: List[StardewRule]
    counter: Counter[StardewRule]
    evaluate: Callable[[CollectionState], bool]
    # Total number of entries. Only set (non-None) for the short-circuit strategy.
    total: Optional[int]
    # Original rule -> simplified rule. Only set (non-None) for the short-circuit strategy.
    rule_mapping: Optional[Dict[StardewRule, StardewRule]]

    def __init__(self, rules: List[StardewRule], count: int):
        self.count = count
        self.counter = Counter(rules)

        # ``rules`` is checked first so an empty list does not divide by zero.
        if rules and len(self.counter) / len(rules) < .66:
            # Checking if it's worth using the count operation with shortcircuit or not. Value should be fine-tuned when Count
            # has more usage.
            self.total = sum(self.counter.values())
            self.rules = sorted(self.counter.keys(), key=lambda x: self.counter[x], reverse=True)
            self.rule_mapping = {}
            self.evaluate = self.evaluate_with_shortcircuit
        else:
            # Own copy: evaluate_without_shortcircuit replaces entries with their simplified form. That must not leak
            # into the caller's list.
            self.rules = list(rules)
            self.total = None
            self.rule_mapping = None
            self.evaluate = self.evaluate_without_shortcircuit

    def __call__(self, state: CollectionState) -> bool:
        return self.evaluate(state)

    def evaluate_without_shortcircuit(self, state: CollectionState) -> bool:
        """Evaluate every entry in order, storing each simplified rule back into ``self.rules``."""
        c = 0
        rules_count = self.rules_count
        for i, rule in enumerate(self.rules):
            self.rules[i], value = rule.evaluate_while_simplifying(state)
            if value:
                c += 1
            if c >= self.count:
                return True
            # rules_count - i - 1 entries are left. Stop once even all of them being true could not reach count.
            if c + rules_count - i - 1 < self.count:
                break

        # Reached only when the loop did not return True. For an empty rule list, this correctly handles count <= 0.
        return c >= self.count

    def evaluate_with_shortcircuit(self, state: CollectionState) -> bool:
        """Evaluate each distinct rule once (most frequent first), weighting it by its number of occurrences."""
        c = 0
        # t is the highest count still reachable: the total minus the weight of every rule found false.
        t = self.total
        for rule in self.rules:
            evaluation_value = self.call_evaluate_while_simplifying_cached(rule, state)
            rule_value = self.counter[rule]

            if evaluation_value:
                c += rule_value
            else:
                t -= rule_value

            if c >= self.count:
                return True
            elif t < self.count:
                break

        return c >= self.count

    def call_evaluate_while_simplifying_cached(self, rule: StardewRule, state: CollectionState) -> bool:
        """Evaluate ``rule``, simplifying it on first use and reusing the simplified form afterwards."""
        # A mapping table with the original rule is used here because two rules could resolve to the same rule.
        # This would require changing the counter to merge both rules, and would quickly become complicated.
        try:
            simplified = self.rule_mapping[rule]
        except KeyError:
            self.rule_mapping[rule], value = rule.evaluate_while_simplifying(state)
            return value

        # The rule call is kept outside the try, so a KeyError raised by the rule itself is not mistaken for a cache miss.
        return simplified(state)

    def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]:
        return self, self(state)

    @cached_property
    def rules_count(self):
        return len(self.rules)

    def __repr__(self):
        return f"Received {self.count} [{', '.join(f'{value}x {repr(rule)}' for rule, value in self.counter.items())}]"
@dataclass(frozen=True)
class Has(BaseStardewRule):
    """
    Rule delegating to the rule registered for ``item`` in ``other_rules``.

    The lookup in ``other_rules`` happens on every evaluation, not at construction.
    """
    item: str
    # For sure there is a better way than just passing all the rules every time
    other_rules: Dict[str, StardewRule] = field(repr=False, hash=False, compare=False)
    group: str = "item"

    def __call__(self, state: CollectionState) -> bool:
        _, value = self.evaluate_while_simplifying(state)
        return value

    def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]:
        # Raises KeyError when no rule is registered for ``item``.
        target = self.other_rules[self.item]
        return target.evaluate_while_simplifying(state)

    def __str__(self):
        if self.item in self.other_rules:
            return f"Has {self.item} ({self.group})"
        return f"Has {self.item} ({self.group}) -> {MISSING_ITEM}"

    def __repr__(self):
        if self.item in self.other_rules:
            return f"Has {self.item} ({self.group}) -> {repr(self.other_rules[self.item])}"
        return f"Has {self.item} ({self.group}) -> {MISSING_ITEM}"
class RepeatableChain(Iterable, Sized):
    """
    Re-iterable concatenation of several iterables: every ``iter()`` starts over from the first one.

    Essentially a copy of the core's helper, with proper type hinting.
    """

    def __init__(self, *iterable: Union[Iterable, Sized]):
        self.iterables = iterable

    def __iter__(self):
        return chain(*self.iterables)

    def __bool__(self):
        return any(self.iterables)

    def __len__(self):
        return sum(map(len, self.iterables))

    def __contains__(self, item):
        for sub_iterable in self.iterables:
            if item in sub_iterable:
                return True
        return False