159 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			159 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
| from __future__ import annotations
 | |
| 
 | |
| from collections import OrderedDict
 | |
| from threading import RLock
 | |
| from typing import TYPE_CHECKING, Any, Iterable, Union
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from .bot_ai import BotAI
 | |
| 
 | |
| 
 | |
class ExpiringDict(OrderedDict):
    """
    An expiring dict that uses ``bot.state.game_loop`` to only return items that were set
    less than ``max_age_frames`` game loops ago.

    Internally every value is stored as a ``(value, frame_set)`` tuple. Expired entries are
    not removed eagerly: iteration, ``len()`` and ``repr()`` skip them, and lookups via
    ``in``, ``d[key]``, ``get()`` and ``pop()`` delete them lazily. There is no size limit.
    Operations that read or mutate the underlying storage hold ``self.lock`` (an ``RLock``).

    Example usage::

        async def on_start(self):
            # Create the dict once; values are only returned for 20 frames after being set
            self.my_dict = ExpiringDict(self, max_age_frames=20)

        async def on_step(self, iteration: int):
            if iteration == 0:
                # Add item
                self.my_dict["test"] = "something"
            if iteration == 2:
                # Assuming a game step of 8 frames per iteration, this is 16 frames later
                if "test" in self.my_dict:
                    print("test is in dict")
            if iteration == 20:
                if "test" not in self.my_dict:
                    print("test is not anymore in dict")
    """

    # Private sentinel so that ``pop(key, None)`` can return None instead of raising.
    _MISSING = object()

    def __init__(self, bot: BotAI, max_age_frames: int = 1):
        """
        :param bot: object exposing ``bot.state.game_loop`` (the current frame)
        :param max_age_frames: an entry is valid while ``current_frame - frame_set < max_age_frames``
        """
        assert max_age_frames >= -1
        assert bot

        OrderedDict.__init__(self)
        self.bot: BotAI = bot
        self.max_age: Union[int, float] = max_age_frames
        self.lock: RLock = RLock()

    @property
    def frame(self) -> int:
        """Current game loop, read from the bot on every access."""
        return self.bot.state.game_loop

    def _is_fresh(self, item) -> bool:
        """Return True if a stored ``(value, frame_set)`` tuple has not expired yet."""
        return self.frame - item[1] < self.max_age

    def _valid_items(self) -> list:
        """Return a snapshot list of ``(key, (value, frame_set))`` pairs that have not expired.

        The snapshot is taken under the lock and returned as a list (instead of yielding
        while holding the lock). Callers may therefore mutate the dict while iterating it,
        including the lazy deletions performed by ``in``/``get``/``pop``, without triggering
        "OrderedDict mutated during iteration".
        """
        with self.lock:
            now = self.frame
            return [(key, item) for key, item in OrderedDict.items(self) if now - item[1] < self.max_age]

    def __contains__(self, key) -> bool:
        """Return True if dict has a non-expired entry for key, e.g. 'key in dict'.

        An expired entry found here is deleted.
        """
        with self.lock:
            if OrderedDict.__contains__(self, key):
                if self._is_fresh(OrderedDict.__getitem__(self, key)):
                    return True
                # Lazily drop the expired entry
                OrderedDict.__delitem__(self, key)
        return False

    def __getitem__(self, key, with_age=False) -> Any:
        """Return the value for key using d[key].

        :param with_age: only usable via a direct call; if True return ``(value, frame_set)``
        :raises KeyError: if key is missing or its entry has expired (the expired entry is deleted)
        """
        with self.lock:
            # Each item is a tuple of (value, frame time)
            item = OrderedDict.__getitem__(self, key)
            if self._is_fresh(item):
                if with_age:
                    return item[0], item[1]
                return item[0]
            OrderedDict.__delitem__(self, key)
        raise KeyError(key)

    def __setitem__(self, key, value):
        """Set d[key] = value, stamping it with the current frame."""
        with self.lock:
            OrderedDict.__setitem__(self, key, (value, self.frame))

    def __repr__(self):
        """Printable version of the dict (non-expired entries only) instead of a memory address."""
        body = ", ".join(f"{key!r}: {item!r}" for key, item in self._valid_items())
        return f"ExpiringDict({body})"

    def __str__(self):
        return self.__repr__()

    def __iter__(self):
        """Override 'for key in dict:' to iterate over a snapshot of non-expired keys."""
        return self.keys()

    # TODO find a way to improve len
    def __len__(self):
        """Return the number of non-expired entries.

        Expired entries are only deleted lazily, so this has to check every stored entry (O(n)).
        """
        return len(self._valid_items())

    def pop(self, key, default=_MISSING, with_age=False):
        """Remove key and return its value.

        An expired entry for key is removed as well, but treated as missing.

        :param default: returned (instead of raising) if key is missing or expired; may be None
        :param with_age: if True return ``(value, frame_set)``, or ``(default, current_frame)``
        :raises KeyError: if key is missing or expired and no default was given
        """
        with self.lock:
            if OrderedDict.__contains__(self, key):
                item = OrderedDict.__getitem__(self, key)
                # The entry is removed whether it is still fresh or already expired
                OrderedDict.__delitem__(self, key)
                if self._is_fresh(item):
                    return (item[0], item[1]) if with_age else item[0]
            if default is self._MISSING:
                raise KeyError(key)
            return (default, self.frame) if with_age else default

    def get(self, key, default=None, with_age=False):
        """Return the value for key if it is present and not expired, else default (never raises).

        :param with_age: if True return ``(value, frame_set)``, or ``(default, current_frame)``
        """
        with self.lock:
            if OrderedDict.__contains__(self, key):
                item = OrderedDict.__getitem__(self, key)
                if self._is_fresh(item):
                    return (item[0], item[1]) if with_age else item[0]
                # Lazily drop the expired entry, consistent with __contains__/__getitem__/pop
                OrderedDict.__delitem__(self, key)
            return (default, self.frame) if with_age else default

    def update(self, other_dict: dict):
        """Set every key/value pair of other_dict, all stamped with the current frame."""
        with self.lock:
            for key, value in other_dict.items():
                self[key] = value

    def items(self) -> Iterable:
        """Return an iterator over a snapshot of non-expired (key, value) pairs."""
        return iter([(key, item[0]) for key, item in self._valid_items()])

    def keys(self) -> Iterable:
        """Return an iterator over a snapshot of non-expired keys."""
        return iter([key for key, _ in self._valid_items()])

    def values(self) -> Iterable:
        """Return an iterator over a snapshot of non-expired values."""
        return iter([item[0] for _, item in self._valid_items()])