480 lines
18 KiB
Python
480 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from collections import deque, Counter
|
|
from dataclasses import dataclass, field
|
|
from functools import cached_property
|
|
from itertools import chain
|
|
from threading import Lock
|
|
from typing import Iterable, Dict, List, Union, Sized, Hashable, Callable, Tuple, Set, Optional
|
|
|
|
from BaseClasses import CollectionState
|
|
from .literal import true_, false_, LiteralStardewRule
|
|
from .protocol import StardewRule
|
|
|
|
# Marker text appended by Has.__str__ / Has.__repr__ when Has.item is not a key of Has.other_rules.
MISSING_ITEM = "THIS ITEM IS MISSING"
|
|
|
|
|
|
class BaseStardewRule(StardewRule, ABC):
    """
    Default ``|`` / ``&`` operators shared by concrete rules.

    By default the operators wrap both operands in a new ``Or`` / ``And``. When the right-hand operand is one of the
    literals ``true_`` / ``false_``, or is already an aggregate of the matching kind, the operation is delegated to that
    operand instead, so its own operator decides how to combine (e.g. ``Or.__or__`` merges into an existing ``Or``).
    """

    def __or__(self, other) -> StardewRule:
        delegate_to_other = other is true_ or other is false_ or type(other) is Or
        return other | self if delegate_to_other else Or(self, other)

    def __and__(self, other) -> StardewRule:
        delegate_to_other = other is true_ or other is false_ or type(other) is And
        return other & self if delegate_to_other else And(self, other)
|
|
|
|
|
|
class CombinableStardewRule(BaseStardewRule, ABC):
    """
    A rule that can be merged with another rule sharing the same ``combination_key``.

    Two such rules are not aggregated side by side: ``And.combine`` / ``Or.combine`` pick a single survivor,
    using ``value`` to decide which one.
    """

    @property
    @abstractmethod
    def combination_key(self) -> Hashable:
        raise NotImplementedError

    @property
    @abstractmethod
    def value(self):
        raise NotImplementedError

    def is_same_rule(self, other: CombinableStardewRule):
        """Tell whether ``other`` shares this rule's combination key."""
        return self.combination_key == other.combination_key

    def add_into(self, rules: Dict[Hashable, CombinableStardewRule],
                 reducer: Callable[[CombinableStardewRule, CombinableStardewRule], CombinableStardewRule]) \
            -> Dict[Hashable, CombinableStardewRule]:
        """
        Return a new dict holding ``rules`` plus this rule; ``rules`` itself is left untouched.

        If a rule with the same key is already present, it is replaced by ``reducer(self, existing_rule)``.
        """
        key = self.combination_key
        merged = dict(rules)
        merged[key] = reducer(self, merged[key]) if key in merged else self
        return merged

    def __and__(self, other):
        same_rule = isinstance(other, CombinableStardewRule) and self.is_same_rule(other)
        return And.combine(self, other) if same_rule else super().__and__(other)

    def __or__(self, other):
        same_rule = isinstance(other, CombinableStardewRule) and self.is_same_rule(other)
        return Or.combine(self, other) if same_rule else super().__or__(other)
|
|
|
|
|
|
class _SimplificationState:
    """
    Progress of the lazy simplification of an aggregating rule's non-combinable sub-rules.

    ``rules_to_simplify`` is ``None`` until simplification starts. It then holds the rules still waiting to be
    simplified, while ``simplified_rules`` collects the results. The state counts as simplified once that queue
    exists and is empty.
    """
    # Rules given at construction time; simplification never modifies this tuple.
    original_simplifiable_rules: Tuple[StardewRule, ...]

    # None until simplification starts, then the queue of rules not simplified yet.
    rules_to_simplify: Optional[deque[StardewRule]]
    # Simplified rules kept for re-evaluation.
    simplified_rules: Set[StardewRule]
    # Re-entrancy flag handled by acquire()/release(). It is a plain boolean, not a threading.Lock.
    locked: bool

    def __init__(self, simplifiable_rules: Tuple[StardewRule, ...], rules_to_simplify: Optional[deque[StardewRule]] = None,
                 simplified_rules: Optional[Set[StardewRule]] = None):
        if simplified_rules is None:
            simplified_rules = set()

        self.original_simplifiable_rules = simplifiable_rules
        self.rules_to_simplify = rules_to_simplify
        self.simplified_rules = simplified_rules
        self.locked = False

    @property
    def is_simplified(self):
        """True once simplification has started and no rule is left in the queue."""
        return self.rules_to_simplify is not None and not self.rules_to_simplify

    def short_circuit(self, complement: LiteralStardewRule):
        """Mark the state as fully simplified down to the single ``complement`` rule."""
        self.rules_to_simplify = deque()
        self.simplified_rules = {complement}

    def try_popleft(self):
        """Drop the head of the queue. The queue may already have been emptied by short_circuit()."""
        try:
            self.rules_to_simplify.popleft()
        except IndexError:
            pass

    def acquire_copy(self):
        """
        Return an already-acquired copy of this state that can be advanced independently.

        The queue and the set are shallow-copied. A queue that is still ``None`` (simplification not started)
        stays ``None`` instead of failing on ``.copy()``.
        """
        rules_to_simplify = None if self.rules_to_simplify is None else self.rules_to_simplify.copy()
        state = _SimplificationState(self.original_simplifiable_rules, rules_to_simplify, self.simplified_rules.copy())
        state.acquire()
        return state

    def merge(self, other: _SimplificationState):
        """Return a fresh (not started) state over both states' original rules."""
        return _SimplificationState(self.original_simplifiable_rules + other.original_simplifiable_rules)

    def add(self, rule: StardewRule):
        """Return a fresh (not started) state over this state's original rules plus ``rule``."""
        return _SimplificationState(self.original_simplifiable_rules + (rule,))

    def acquire(self):
        """
        This just sets a boolean to True and is absolutely not thread safe. It just works because AP is single-threaded.

        :return: True if the state was free and is now acquired, False if it was already acquired.
        """
        if self.locked is True:
            return False

        self.locked = True
        return True

    def release(self):
        """Release a state previously acquired with acquire() or acquire_copy()."""
        assert self.locked
        self.locked = False
|
|
|
|
|
|
class AggregatingStardewRule(BaseStardewRule, ABC):
    """
    Logic for both "And" and "Or" rules.

    Subclasses provide:
    - ``identity``: the literal that does not affect the result and is dropped during simplification;
    - ``complement``: the literal whose value short-circuits evaluation;
    - ``symbol``: the separator used by ``__str__`` / ``__repr__``;
    - ``combine``: how two combinable rules sharing a key are merged.

    Sub-rules are stored in two groups:
    - ``combinable_rules``: merged per ``combination_key`` at construction and evaluated directly;
    - the other rules: held by a ``_SimplificationState`` and simplified lazily while being evaluated.
    """
    identity: LiteralStardewRule
    complement: LiteralStardewRule
    symbol: str

    combinable_rules: Dict[Hashable, CombinableStardewRule]
    simplification_state: _SimplificationState
    # Sub-rule that produced the last short-circuit. It is re-checked first on the next evaluation.
    _last_short_circuiting_rule: Optional[StardewRule] = None

    def __init__(self, *rules: StardewRule, _combinable_rules=None, _simplification_state=None):
        """
        :param rules: sub-rules to aggregate; required unless the private parameters are given.
        :param _combinable_rules: internal - already merged combinable rules, used with ``_simplification_state``.
        :param _simplification_state: internal - state holding the non-combinable rules.
        """
        if _combinable_rules is None:
            assert rules, "Can't create an aggregating condition without rules"
            rules, _combinable_rules = self.split_rules(rules)
            _simplification_state = _SimplificationState(rules)

        self.combinable_rules = _combinable_rules
        self.simplification_state = _simplification_state

    @property
    def original_rules(self):
        """Combinable rules followed by the non-combinable rules as originally given (before simplification)."""
        return RepeatableChain(self.combinable_rules.values(), self.simplification_state.original_simplifiable_rules)

    @property
    def current_rules(self):
        """
        The rules as they currently stand.

        Before simplification starts, this is the original rules. Afterwards it is the combinable rules, then the
        already-simplified rules, then the rules still waiting to be simplified.
        """
        if self.simplification_state.rules_to_simplify is None:
            return self.original_rules

        return RepeatableChain(self.combinable_rules.values(), self.simplification_state.simplified_rules, self.simplification_state.rules_to_simplify)

    @classmethod
    def split_rules(cls, rules: Iterable[StardewRule]) -> Tuple[Tuple[StardewRule, ...], Dict[Hashable, CombinableStardewRule]]:
        """
        Separate ``rules`` into non-combinable rules and combinable rules merged per ``combination_key``.

        Nested rules of exactly this class are flattened into the result.
        """
        other_rules = []
        reduced_rules = {}
        for rule in rules:
            if isinstance(rule, CombinableStardewRule):
                key = rule.combination_key
                if key not in reduced_rules:
                    reduced_rules[key] = rule
                    continue

                reduced_rules[key] = cls.combine(reduced_rules[key], rule)
                continue

            if type(rule) is cls:
                other_rules.extend(rule.simplification_state.original_simplifiable_rules)  # noqa
                reduced_rules = cls.merge(reduced_rules, rule.combinable_rules)  # noqa
                continue

            other_rules.append(rule)

        return tuple(other_rules), reduced_rules

    @classmethod
    def merge(cls, left: Dict[Hashable, CombinableStardewRule], right: Dict[Hashable, CombinableStardewRule]) -> Dict[Hashable, CombinableStardewRule]:
        """Return a new dict with both mappings' rules, combining the rules whose keys appear on both sides."""
        reduced_rules = dict(left)
        for key, rule in right.items():
            if key not in reduced_rules:
                reduced_rules[key] = rule
                continue

            reduced_rules[key] = cls.combine(reduced_rules[key], rule)

        return reduced_rules

    @staticmethod
    @abstractmethod
    def combine(left: CombinableStardewRule, right: CombinableStardewRule) -> CombinableStardewRule:
        raise NotImplementedError

    def short_circuit_simplification(self):
        """
        Collapse this rule into its complement.

        Called once the complement was found among the sub-rules, meaning the rule always short-circuits.
        """
        self.simplification_state.short_circuit(self.complement)
        self.combinable_rules = {}
        return self.complement, self.complement.value

    def short_circuit_evaluation(self, rule):
        """Remember ``rule`` as the one that short-circuited and return the complement's value."""
        self._last_short_circuiting_rule = rule
        return self, self.complement.value

    def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]:
        """
        The global idea here is the same as short-circuiting operators, applied to evaluation and rule simplification.
        """
        # Directly checking last rule that short-circuited, in case state has not changed.
        # Compared against None explicitly: a rule's own truthiness must not decide whether it is re-checked.
        if self._last_short_circuiting_rule is not None:
            if self._last_short_circuiting_rule(state) is self.complement.value:
                return self.short_circuit_evaluation(self._last_short_circuiting_rule)
            self._last_short_circuiting_rule = None

        # Combinable rules are considered already simplified, so we evaluate them right away to go faster.
        for rule in self.combinable_rules.values():
            if rule(state) is self.complement.value:
                return self.short_circuit_evaluation(rule)

        if self.simplification_state.is_simplified:
            # The rule is fully simplified, so now we can only evaluate.
            for rule in self.simplification_state.simplified_rules:
                if rule(state) is self.complement.value:
                    return self.short_circuit_evaluation(rule)
            return self, self.identity.value

        return self.evaluate_while_simplifying_stateful(state)

    def evaluate_while_simplifying_stateful(self, state):
        # Creating a new copy, so we don't modify the rules while we're already evaluating it. This can happen if a rule is used for an entrance and a
        # location. When evaluating a given rule what requires access to a region, the region cache can get an update. If it does, we could enter this rule
        # again. Since the simplification is stateful, the set of simplified rules can be modified while it's being iterated on, and cause a crash.
        #
        # After investigation, for millions of call to this method, copy were acquired 425 times.
        # Merging simplification state in parent call was deemed useless.
        local_state = self.simplification_state
        if not local_state.acquire():
            local_state = local_state.acquire_copy()
            self.simplification_state = local_state

        # Acquisition happens before the try so that only a state actually acquired here is released. If acquiring the
        # copy failed, the original state still belongs to the outer call and must stay locked.
        try:
            # Evaluating what has already been simplified. First it will be faster than simplifying "new" rules, but we also assume that if we reach this point
            # and there are already simplified rules, one of these rules has short-circuited, and might again, so we can leave early.
            for rule in local_state.simplified_rules:
                if rule(state) is self.complement.value:
                    return self.short_circuit_evaluation(rule)

            # If the queue is None, it means we have not started simplifying. Otherwise, we will continue simplification where we left off.
            if local_state.rules_to_simplify is None:
                rules_to_simplify = frozenset(local_state.original_simplifiable_rules)
                if self.complement in rules_to_simplify:
                    return self.short_circuit_simplification()
                local_state.rules_to_simplify = deque(rules_to_simplify)

            # Start simplification where we left off.
            while local_state.rules_to_simplify:
                result = self.evaluate_rule_while_simplifying_stateful(local_state, state)
                local_state.try_popleft()
                if result is not None:
                    return result

            # The whole rule has been simplified and evaluated without short-circuit.
            return self, self.identity.value
        finally:
            local_state.release()

    def evaluate_rule_while_simplifying_stateful(self, local_state, state):
        """
        Simplify and evaluate the rule at the head of the queue (the caller pops it afterwards).

        :return: a ``(rule, value)`` result when this aggregate short-circuits, otherwise None.
        """
        simplified, value = local_state.rules_to_simplify[0].evaluate_while_simplifying(state)

        # Identity is removed from the resulting simplification since it does not affect the result.
        if simplified is self.identity:
            return

        # If we find a complement here, we know the rule will always short-circuit, whatever the state.
        if simplified is self.complement:
            return self.short_circuit_simplification()
        # Keep the simplified rule to be reevaluated later.
        local_state.simplified_rules.add(simplified)

        # Now we use the value to short-circuit if it is the complement.
        if value is self.complement.value:
            return self.short_circuit_evaluation(simplified)

    def __str__(self):
        return f"({self.symbol.join(str(rule) for rule in self.original_rules)})"

    def __repr__(self):
        return f"({self.symbol.join(repr(rule) for rule in self.original_rules)})"

    def __eq__(self, other):
        return (isinstance(other, type(self)) and self.combinable_rules == other.combinable_rules and
                self.simplification_state.original_simplifiable_rules == other.simplification_state.original_simplifiable_rules)

    def __hash__(self):
        # Large rules hash by identity, presumably to avoid costly recursive hashing.
        # NOTE(review): two large rules that are equal per __eq__ may then hash differently, so sets and dicts only
        # deduplicate them by identity.
        if len(self.combinable_rules) + len(self.simplification_state.original_simplifiable_rules) > 5:
            return id(self)

        # __eq__ compares combinable_rules as dicts, which ignores insertion order, so hash them as an unordered set
        # too. Hashing the values as an ordered tuple made equal rules built in a different order hash differently.
        return hash((frozenset(self.combinable_rules.items()), self.simplification_state.original_simplifiable_rules))
|
|
|
|
|
|
class Or(AggregatingStardewRule):
    """
    Disjunction of rules.

    ``true_`` is the complement (it short-circuits evaluation) and ``false_`` is the identity (it is dropped).
    """
    identity = false_
    complement = true_
    symbol = " | "

    def __call__(self, state: CollectionState) -> bool:
        _, value = self.evaluate_while_simplifying(state)
        return value

    def __or__(self, other):
        # Literals decide for themselves how to combine with this rule.
        if other is true_ or other is false_:
            return other | self

        # A combinable rule joins the combinable rules; the non-combinable state object is shared.
        if isinstance(other, CombinableStardewRule):
            combined = other.add_into(self.combinable_rules, self.combine)
            return Or(_combinable_rules=combined, _simplification_state=self.simplification_state)

        # Two Or rules are flattened into one.
        if type(other) is Or:
            merged_combinables = self.merge(self.combinable_rules, other.combinable_rules)
            merged_state = self.simplification_state.merge(other.simplification_state)
            return Or(_combinable_rules=merged_combinables, _simplification_state=merged_state)

        # Anything else becomes one more rule to simplify.
        extended_state = self.simplification_state.add(other)
        return Or(_combinable_rules=self.combinable_rules, _simplification_state=extended_state)

    @staticmethod
    def combine(left: CombinableStardewRule, right: CombinableStardewRule) -> CombinableStardewRule:
        # Keep the rule with the lowest ``value`` (``left`` on ties).
        return min(left, right, key=lambda candidate: candidate.value)
|
|
|
|
|
|
class And(AggregatingStardewRule):
    """
    Conjunction of rules.

    ``false_`` is the complement (it short-circuits evaluation) and ``true_`` is the identity (it is dropped).
    """
    identity = true_
    complement = false_
    symbol = " & "

    def __call__(self, state: CollectionState) -> bool:
        _, value = self.evaluate_while_simplifying(state)
        return value

    def __and__(self, other):
        # Literals decide for themselves how to combine with this rule.
        if other is true_ or other is false_:
            return other & self

        # A combinable rule joins the combinable rules; the non-combinable state object is shared.
        if isinstance(other, CombinableStardewRule):
            combined = other.add_into(self.combinable_rules, self.combine)
            return And(_combinable_rules=combined, _simplification_state=self.simplification_state)

        # Two And rules are flattened into one.
        if type(other) is And:
            merged_combinables = self.merge(self.combinable_rules, other.combinable_rules)
            merged_state = self.simplification_state.merge(other.simplification_state)
            return And(_combinable_rules=merged_combinables, _simplification_state=merged_state)

        # Anything else becomes one more rule to simplify.
        extended_state = self.simplification_state.add(other)
        return And(_combinable_rules=self.combinable_rules, _simplification_state=extended_state)

    @staticmethod
    def combine(left: CombinableStardewRule, right: CombinableStardewRule) -> CombinableStardewRule:
        # Keep the rule with the highest ``value`` (``left`` on ties).
        return max(left, right, key=lambda candidate: candidate.value)
|
|
|
|
|
|
class Count(BaseStardewRule):
    """
    Rule satisfied when at least ``count`` of ``rules`` are satisfied.

    A rule listed several times counts once per occurrence.
    """
    count: int
    rules: List[StardewRule]
    counter: Counter[StardewRule]
    evaluate: Callable[[CollectionState], bool]

    # Only set to non-None values when the short-circuiting evaluation is selected.
    total: Optional[int]
    rule_mapping: Optional[Dict[StardewRule, StardewRule]]

    def __init__(self, rules: List[StardewRule], count: int):
        # Work on a private list. evaluate_without_shortcircuit writes simplified rules back into self.rules, which
        # must not leak into the caller's list, and a tuple argument must not break that assignment.
        rules = list(rules)
        self.count = count
        self.counter = Counter(rules)

        # Checking if it's worth using the count operation with shortcircuit or not. Value should be fine-tuned when Count has more usage.
        # An empty list would divide by zero here, so it takes the plain path.
        if rules and len(self.counter) / len(rules) < .66:
            self.total = sum(self.counter.values())
            self.rules = sorted(self.counter.keys(), key=lambda x: self.counter[x], reverse=True)
            self.rule_mapping = {}
            self.evaluate = self.evaluate_with_shortcircuit
        else:
            self.total = None
            self.rule_mapping = None
            self.rules = rules
            self.evaluate = self.evaluate_without_shortcircuit

    def __call__(self, state: CollectionState) -> bool:
        return self.evaluate(state)

    def evaluate_without_shortcircuit(self, state: CollectionState) -> bool:
        """Evaluate rules one by one, storing each simplified rule back in place, and stop as soon as the result is known."""
        c = 0
        for i in range(self.rules_count):
            self.rules[i], value = self.rules[i].evaluate_while_simplifying(state)
            if value:
                c += 1

            if c >= self.count:
                return True
            # Rules i+1 .. rules_count-1 are still unevaluated. Stop if all of them passing could not reach the target.
            if c + self.rules_count - i - 1 < self.count:
                break

        # Reached with c < count, except for an empty rule list, where the result is simply whether count <= 0.
        return c >= self.count

    def evaluate_with_shortcircuit(self, state: CollectionState) -> bool:
        """Evaluate distinct rules, most frequent first, weighting each by its number of occurrences."""
        c = 0
        t = self.total

        for rule in self.rules:
            evaluation_value = self.call_evaluate_while_simplifying_cached(rule, state)
            rule_value = self.counter[rule]

            if evaluation_value:
                c += rule_value
            else:
                t -= rule_value

            if c >= self.count:
                return True
            elif t < self.count:
                # t is the best total still reachable: successes so far plus every rule not yet evaluated.
                break

        return False

    def call_evaluate_while_simplifying_cached(self, rule: StardewRule, state: CollectionState) -> bool:
        # A mapping table with the original rule is used here because two rules could resolve to the same rule.
        # This would require to change the counter to merge both rules, and quickly become complicated.
        try:
            simplified = self.rule_mapping[rule]
        except KeyError:
            self.rule_mapping[rule], value = rule.evaluate_while_simplifying(state)
            return value

        # Called outside the try: a KeyError raised while evaluating the rule itself must propagate, not be mistaken
        # for a cache miss.
        return simplified(state)

    def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]:
        return self, self(state)

    @cached_property
    def rules_count(self):
        return len(self.rules)

    def __repr__(self):
        return f"Received {self.count} [{', '.join(f'{value}x {repr(rule)}' for rule, value in self.counter.items())}]"
|
|
|
|
|
|
@dataclass(frozen=True)
class Has(BaseStardewRule):
    """
    Rule delegating to the rule stored under ``item`` in ``other_rules``.

    ``other_rules`` is excluded from the generated hash, comparison and repr, so two ``Has`` rules compare equal when
    their ``item`` and ``group`` match.
    """
    item: str
    # For sure there is a better way than just passing all the rules everytime
    other_rules: Dict[str, StardewRule] = field(repr=False, hash=False, compare=False)
    group: str = "item"

    def __call__(self, state: CollectionState) -> bool:
        _, result = self.evaluate_while_simplifying(state)
        return result

    def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]:
        # Raises KeyError when ``item`` has no entry in ``other_rules``.
        delegate = self.other_rules[self.item]
        return delegate.evaluate_while_simplifying(state)

    def __str__(self):
        if self.item in self.other_rules:
            return f"Has {self.item} ({self.group})"
        return f"Has {self.item} ({self.group}) -> {MISSING_ITEM}"

    def __repr__(self):
        if self.item in self.other_rules:
            return f"Has {self.item} ({self.group}) -> {repr(self.other_rules[self.item])}"
        return f"Has {self.item} ({self.group}) -> {MISSING_ITEM}"
|
|
|
|
|
|
class RepeatableChain(Iterable, Sized):
    """
    Essentially a copy of what's in the core, with proper type hinting

    A chained view over several sub-iterables. Unlike a bare ``itertools.chain``, it can be iterated repeatedly, since
    every ``__iter__`` call builds a new chain. ``len``, truthiness and ``in`` are delegated to the sub-iterables.
    """

    def __init__(self, *iterable: Union[Iterable, Sized]):
        self.iterables = iterable

    def __iter__(self):
        return chain(*self.iterables)

    def __bool__(self):
        # True as soon as one sub-iterable is truthy (non-empty for containers).
        return any(self.iterables)

    def __len__(self):
        return sum(map(len, self.iterables))

    def __contains__(self, item):
        for sub_iterable in self.iterables:
            if item in sub_iterable:
                return True
        return False
|