159 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			159 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
from __future__ import annotations
 | 
						|
 | 
						|
from collections import OrderedDict
 | 
						|
from threading import RLock
 | 
						|
from typing import TYPE_CHECKING, Any, Iterable, Union
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    from .bot_ai import BotAI
 | 
						|
 | 
						|
 | 
						|
class ExpiringDict(OrderedDict):
    """
    An expiring dict that uses the bot.state.game_loop to only return items that are valid for a specific amount of time.

    Internally each value is stored as a ``(value, frame_added)`` tuple. An entry is valid while
    ``bot.state.game_loop - frame_added < max_age_frames``. Expired entries are not removed eagerly:
    iteration, ``len()`` and ``repr()`` skip them, while ``key in d``, ``d[key]`` and ``pop`` delete
    an expired entry when they encounter it.

    Example usages::

        async def on_start(self):
            # Create the dict once (not every step) so its contents persist between steps.
            # It only returns values that have been added less than 20 frames ago.
            self.my_dict = ExpiringDict(self, max_age_frames=20)

        async def on_step(self, iteration: int):
            if iteration == 0:
                # Add item
                self.my_dict["test"] = "something"
            if iteration == 2:
                # On default, one iteration is called every 8 frames
                if "test" in self.my_dict:
                    print("test is in dict")
            if iteration == 20:
                if "test" not in self.my_dict:
                    print("test is not anymore in dict")
    """

    # Sentinel that distinguishes "no default given" from an explicit ``default=None``
    # in ``get`` and ``pop``.
    _MISSING = object()

    def __init__(self, bot: BotAI, max_age_frames: int = 1):
        """
        :param bot: object whose ``state.game_loop`` supplies the current frame
        :param max_age_frames: an entry is valid while it is fewer than this many frames old
        """
        assert max_age_frames >= -1
        assert bot

        OrderedDict.__init__(self)
        self.bot: BotAI = bot
        self.max_age: Union[int, float] = max_age_frames
        # Reentrant lock held by the methods below while they touch the underlying storage.
        self.lock: RLock = RLock()

    @property
    def frame(self) -> int:
        """ Current game loop of the bot, used as the clock for expiry """
        return self.bot.state.game_loop

    def _is_valid(self, added_frame: int) -> bool:
        """ Return True if an entry stored at ``added_frame`` has not expired yet """
        return self.frame - added_frame < self.max_age

    def _valid_items(self) -> list:
        """Snapshot the non-expired ``(key, value)`` pairs while holding the lock.

        Returning a list (instead of yielding under the lock) means callers never keep the
        lock held across iteration and may mutate the dict while iterating the snapshot.
        """
        with self.lock:
            now = self.frame
            # OrderedDict.items bypasses our overridden accessors and yields the raw tuples.
            return [(key, item[0]) for key, item in OrderedDict.items(self) if now - item[1] < self.max_age]

    def __contains__(self, key) -> bool:
        """ Return True if dict has a non-expired key, else False, e.g. 'key in dict'.

        An expired entry found here is deleted as a side effect.
        """
        with self.lock:
            if not OrderedDict.__contains__(self, key):
                return False
            # Each item is a tuple of (value, frame time)
            item = OrderedDict.__getitem__(self, key)
            if self._is_valid(item[1]):
                return True
            OrderedDict.__delitem__(self, key)
            return False

    def __getitem__(self, key, with_age=False) -> Any:
        """ Return the item of the dict using d[key].

        :param with_age: if True, return ``(value, frame_added)`` instead of just the value
        :raises KeyError: if the key is missing or expired (an expired entry is deleted)
        """
        with self.lock:
            # Each item is a tuple of (value, frame time); raises KeyError if absent
            item = OrderedDict.__getitem__(self, key)
            if self._is_valid(item[1]):
                return (item[0], item[1]) if with_age else item[0]
            OrderedDict.__delitem__(self, key)
        raise KeyError(key)

    def __setitem__(self, key, value):
        """ Set d[key] = value, stamping it with the current frame """
        with self.lock:
            OrderedDict.__setitem__(self, key, (value, self.frame))

    def __repr__(self):
        """ Printable version of the dict instead of getting memory address.

        Only non-expired entries are shown, each as its stored ``(value, frame_added)`` tuple.
        """
        print_list = []
        with self.lock:
            now = self.frame
            for key, value in OrderedDict.items(self):
                if now - value[1] < self.max_age:
                    print_list.append(f"{repr(key)}: {repr(value)}")
        print_str = ", ".join(print_list)
        return f"ExpiringDict({print_str})"

    def __str__(self):
        return self.__repr__()

    def __iter__(self):
        """ Override 'for key in dict:' to iterate a snapshot of the non-expired keys """
        return self.keys()

    def __len__(self):
        """Return the number of non-expired entries.

        Expired key/value pairs are not deleted eagerly (only lazily by ``in``, ``d[key]`` and
        ``pop``), so this has to check every stored element and is O(n)."""
        return len(self._valid_items())

    def pop(self, key, default=_MISSING, with_age=False):
        """ Return the item and remove it.

        :param default: returned when the key is missing or expired; if omitted, KeyError is raised
        :param with_age: if True, return ``(value, frame_added)``; for the default this is
            ``(default, current_frame)``
        """
        with self.lock:
            if OrderedDict.__contains__(self, key):
                item = OrderedDict.__getitem__(self, key)
                # The entry is removed whether it was still valid or already expired.
                OrderedDict.__delitem__(self, key)
                if self._is_valid(item[1]):
                    return (item[0], item[1]) if with_age else item[0]
            if default is self._MISSING:
                raise KeyError(key)
            return (default, self.frame) if with_age else default

    def get(self, key, default=_MISSING, with_age=False):
        """ Return the value for key if key is in dict and not expired, else default.

        Unlike ``dict.get``, a KeyError is raised when no default is passed (kept for backward
        compatibility); pass ``default=None`` explicitly to get ``None`` instead.
        Expired entries are not deleted by this method.

        :param with_age: if True, return ``(value, frame_added)``; for the default this is
            ``(default, current_frame)``
        """
        with self.lock:
            if OrderedDict.__contains__(self, key):
                item = OrderedDict.__getitem__(self, key)
                if self._is_valid(item[1]):
                    return (item[0], item[1]) if with_age else item[0]
            if default is self._MISSING:
                raise KeyError(key)
            return (default, self.frame) if with_age else default

    def update(self, other_dict: dict):
        """ Insert every pair of ``other_dict``, stamping each with the current frame """
        with self.lock:
            for key, value in other_dict.items():
                self[key] = value

    def items(self) -> Iterable:
        """ Return iterator over a snapshot of the non-expired (key, value) pairs """
        return iter(self._valid_items())

    def keys(self) -> Iterable:
        """ Return iterator over a snapshot of the non-expired keys """
        return (key for key, _value in self._valid_items())

    def values(self) -> Iterable:
        """ Return iterator over a snapshot of the non-expired values """
        return (value for _key, value in self._valid_items())