553 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			553 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
"""Setlist class definitions."""
 | 
						|
import random as random_

try:
	from collections.abc import (
		Sequence,
		Set,
		MutableSequence,
		MutableSet,
		Hashable,
		)
except ImportError:  # Python 2: the ABCs live directly in collections
	from collections import (
		Sequence,
		Set,
		MutableSequence,
		MutableSet,
		Hashable,
		)

from . import _util
 | 
						|
 | 
						|
 | 
						|
class _basesetlist(Sequence, Set):
 | 
						|
	"""A setlist is an ordered Collection of unique elements.
 | 
						|
 | 
						|
	_basesetlist is the superclass of setlist and frozensetlist.  It is immutable
 | 
						|
	and unhashable.
 | 
						|
	"""
 | 
						|
 | 
						|
	def __init__(self, iterable=None, raise_on_duplicate=False):
 | 
						|
		"""Create a setlist.
 | 
						|
 | 
						|
		Args:
 | 
						|
			iterable (Iterable): Values to initialize the setlist with.
 | 
						|
		"""
 | 
						|
		self._list = list()
 | 
						|
		self._dict = dict()
 | 
						|
		if iterable:
 | 
						|
			if raise_on_duplicate:
 | 
						|
				self._extend(iterable)
 | 
						|
			else:
 | 
						|
				self._update(iterable)
 | 
						|
 | 
						|
	def __repr__(self):
 | 
						|
		if len(self) == 0:
 | 
						|
			return '{0}()'.format(self.__class__.__name__)
 | 
						|
		else:
 | 
						|
			repr_format = '{class_name}({values!r})'
 | 
						|
			return repr_format.format(
 | 
						|
				class_name=self.__class__.__name__,
 | 
						|
				values=tuple(self),
 | 
						|
				)
 | 
						|
 | 
						|
	# Convenience methods
 | 
						|
	def _fix_neg_index(self, index):
 | 
						|
		if index < 0:
 | 
						|
			index += len(self)
 | 
						|
		if index < 0:
 | 
						|
			raise IndexError('index is out of range')
 | 
						|
		return index
 | 
						|
 | 
						|
	def _fix_end_index(self, index):
 | 
						|
		if index is None:
 | 
						|
			return len(self)
 | 
						|
		else:
 | 
						|
			return self._fix_neg_index(index)
 | 
						|
 | 
						|
	def _append(self, value):
 | 
						|
		# Checking value in self will check that value is Hashable
 | 
						|
		if value in self:
 | 
						|
			raise ValueError('Value "%s" already present' % str(value))
 | 
						|
		else:
 | 
						|
			self._dict[value] = len(self)
 | 
						|
			self._list.append(value)
 | 
						|
 | 
						|
	def _extend(self, values):
 | 
						|
		new_values = set()
 | 
						|
		for value in values:
 | 
						|
			if value in new_values:
 | 
						|
				raise ValueError('New values contain duplicates')
 | 
						|
			elif value in self:
 | 
						|
				raise ValueError('New values contain elements already present in self')
 | 
						|
			else:
 | 
						|
				new_values.add(value)
 | 
						|
		for value in values:
 | 
						|
			self._dict[value] = len(self)
 | 
						|
			self._list.append(value)
 | 
						|
 | 
						|
	def _add(self, item):
 | 
						|
		if item not in self:
 | 
						|
			self._dict[item] = len(self)
 | 
						|
			self._list.append(item)
 | 
						|
 | 
						|
	def _update(self, values):
 | 
						|
		for value in values:
 | 
						|
			if value not in self:
 | 
						|
				self._dict[value] = len(self)
 | 
						|
				self._list.append(value)
 | 
						|
 | 
						|
	@classmethod
 | 
						|
	def _from_iterable(cls, it, **kwargs):
 | 
						|
		return cls(it, **kwargs)
 | 
						|
 | 
						|
	# Implement Container
 | 
						|
	def __contains__(self, value):
 | 
						|
		return value in self._dict
 | 
						|
 | 
						|
	# Iterable we get by inheriting from Sequence
 | 
						|
 | 
						|
	# Implement Sized
 | 
						|
	def __len__(self):
 | 
						|
		return len(self._list)
 | 
						|
 | 
						|
	# Implement Sequence
 | 
						|
	def __getitem__(self, index):
 | 
						|
		if isinstance(index, slice):
 | 
						|
			return self._from_iterable(self._list[index])
 | 
						|
		return self._list[index]
 | 
						|
 | 
						|
	def count(self, value):
 | 
						|
		"""Return the number of occurences of value in self.
 | 
						|
 | 
						|
		This runs in O(1)
 | 
						|
 | 
						|
		Args:
 | 
						|
			value: The value to count
 | 
						|
		Returns:
 | 
						|
			int: 1 if the value is in the setlist, otherwise 0
 | 
						|
		"""
 | 
						|
		if value in self:
 | 
						|
			return 1
 | 
						|
		else:
 | 
						|
			return 0
 | 
						|
 | 
						|
	def index(self, value, start=0, end=None):
 | 
						|
		"""Return the index of value between start and end.
 | 
						|
 | 
						|
		By default, the entire setlist is searched.
 | 
						|
 | 
						|
		This runs in O(1)
 | 
						|
 | 
						|
		Args:
 | 
						|
			value: The value to find the index of
 | 
						|
			start (int): The index to start searching at (defaults to 0)
 | 
						|
			end (int): The index to stop searching at (defaults to the end of the list)
 | 
						|
		Returns:
 | 
						|
			int: The index of the value
 | 
						|
		Raises:
 | 
						|
			ValueError: If the value is not in the list or outside of start - end
 | 
						|
			IndexError: If start or end are out of range
 | 
						|
		"""
 | 
						|
		try:
 | 
						|
			index = self._dict[value]
 | 
						|
		except KeyError:
 | 
						|
			raise ValueError
 | 
						|
		else:
 | 
						|
			start = self._fix_neg_index(start)
 | 
						|
			end = self._fix_end_index(end)
 | 
						|
			if start <= index and index < end:
 | 
						|
				return index
 | 
						|
			else:
 | 
						|
				raise ValueError
 | 
						|
 | 
						|
	@classmethod
 | 
						|
	def _check_type(cls, other, operand_name):
 | 
						|
		if not isinstance(other, _basesetlist):
 | 
						|
			message = (
 | 
						|
				"unsupported operand type(s) for {operand_name}: "
 | 
						|
				"'{self_type}' and '{other_type}'").format(
 | 
						|
					operand_name=operand_name,
 | 
						|
					self_type=cls,
 | 
						|
					other_type=type(other),
 | 
						|
					)
 | 
						|
			raise TypeError(message)
 | 
						|
 | 
						|
	def __add__(self, other):
 | 
						|
		self._check_type(other, '+')
 | 
						|
		out = self.copy()
 | 
						|
		out._extend(other)
 | 
						|
		return out
 | 
						|
 | 
						|
	# Implement Set
 | 
						|
 | 
						|
	def issubset(self, other):
 | 
						|
		return self <= other
 | 
						|
 | 
						|
	def issuperset(self, other):
 | 
						|
		return self >= other
 | 
						|
 | 
						|
	def union(self, other):
 | 
						|
		out = self.copy()
 | 
						|
		out.update(other)
 | 
						|
		return out
 | 
						|
 | 
						|
	def intersection(self, other):
 | 
						|
		other = set(other)
 | 
						|
		return self._from_iterable(item for item in self if item in other)
 | 
						|
 | 
						|
	def difference(self, other):
 | 
						|
		other = set(other)
 | 
						|
		return self._from_iterable(item for item in self if item not in other)
 | 
						|
 | 
						|
	def symmetric_difference(self, other):
 | 
						|
		return self.union(other) - self.intersection(other)
 | 
						|
 | 
						|
	def __sub__(self, other):
 | 
						|
		self._check_type(other, '-')
 | 
						|
		return self.difference(other)
 | 
						|
 | 
						|
	def __and__(self, other):
 | 
						|
		self._check_type(other, '&')
 | 
						|
		return self.intersection(other)
 | 
						|
 | 
						|
	def __or__(self, other):
 | 
						|
		self._check_type(other, '|')
 | 
						|
		return self.union(other)
 | 
						|
 | 
						|
	def __xor__(self, other):
 | 
						|
		self._check_type(other, '^')
 | 
						|
		return self.symmetric_difference(other)
 | 
						|
 | 
						|
	# Comparison
 | 
						|
 | 
						|
	def __eq__(self, other):
 | 
						|
		if not isinstance(other, _basesetlist):
 | 
						|
			return False
 | 
						|
		if not len(self) == len(other):
 | 
						|
			return False
 | 
						|
		for self_elem, other_elem in zip(self, other):
 | 
						|
			if self_elem != other_elem:
 | 
						|
				return False
 | 
						|
		return True
 | 
						|
 | 
						|
	def __ne__(self, other):
 | 
						|
		return not (self == other)
 | 
						|
 | 
						|
	# New methods
 | 
						|
 | 
						|
	def sub_index(self, sub, start=0, end=None):
 | 
						|
		"""Return the index of a subsequence.
 | 
						|
 | 
						|
		This runs in O(len(sub))
 | 
						|
 | 
						|
		Args:
 | 
						|
			sub (Sequence): An Iterable to search for
 | 
						|
		Returns:
 | 
						|
			int: The index of the first element of sub
 | 
						|
		Raises:
 | 
						|
			ValueError: If sub isn't a subsequence
 | 
						|
			TypeError: If sub isn't iterable
 | 
						|
			IndexError: If start or end are out of range
 | 
						|
		"""
 | 
						|
		start_index = self.index(sub[0], start, end)
 | 
						|
		end = self._fix_end_index(end)
 | 
						|
		if start_index + len(sub) > end:
 | 
						|
			raise ValueError
 | 
						|
		for i in range(1, len(sub)):
 | 
						|
			if sub[i] != self[start_index + i]:
 | 
						|
				raise ValueError
 | 
						|
		return start_index
 | 
						|
 | 
						|
	def copy(self):
 | 
						|
		return self.__class__(self)
 | 
						|
 | 
						|
 | 
						|
class setlist(_basesetlist, MutableSequence, MutableSet):
    """A mutable (unhashable) setlist."""

    def __str__(self):
        # e.g. setlist([1, 2]) -> "{[1, 2]}"
        return '{[%s]}' % ', '.join(repr(v) for v in self)

    # Helper methods

    def _delete_all(self, elems_to_delete, raise_errors):
        """Delete every element of elems_to_delete in a single pass.

        Args:
            elems_to_delete (Iterable): Elements to delete.
            raise_errors (bool): If True, raise ValueError for missing or
                repeated elements; otherwise ignore them.

        Raises:
            ValueError: If raise_errors is True and elems_to_delete contains
                an element not in self or the same element twice.  Nothing
                is deleted in that case, because deletion happens only after
                the whole input has been checked.
        """
        indices_to_delete = set()
        for elem in elems_to_delete:
            try:
                elem_index = self._dict[elem]
            except KeyError:
                if raise_errors:
                    raise ValueError('Passed values contain elements not in self')
            else:
                if elem_index in indices_to_delete:
                    if raise_errors:
                        raise ValueError('Passed values contain duplicates')
                indices_to_delete.add(elem_index)
        self._delete_values_by_index(indices_to_delete)

    def _delete_values_by_index(self, indices_to_delete):
        """Remove the elements at the given indices in O(len(self)).

        Surviving elements are compacted toward the front of _list in their
        original order, and their _dict entries are updated to match.
        """
        deleted_count = 0
        for i, elem in enumerate(self._list):
            if i in indices_to_delete:
                deleted_count += 1
                del self._dict[elem]
            else:
                # Writes only go to positions <= i, which enumerate has
                # already read, so compacting in place is safe.
                new_index = i - deleted_count
                self._list[new_index] = elem
                self._dict[elem] = new_index
        # Now remove deleted_count items from the end of the list
        if deleted_count:
            self._list = self._list[:-deleted_count]

    # Set/Sequence agnostic

    def pop(self, index=-1):
        """Remove and return the item at index (default: the last item).

        Raises:
            IndexError: If self is empty or index is out of range.
        """
        position = index + len(self._list) if index < 0 else index
        value = self._list.pop(index)
        del self._dict[value]
        # Every element after the popped one moved left by one; without this
        # _dict disagreed with _list after popping anything but the last item.
        for elem in self._list[position:]:
            self._dict[elem] -= 1
        return value

    def clear(self):
        """Remove all elements from self."""
        self._dict = dict()
        self._list = list()

    # Implement MutableSequence

    def __setitem__(self, index, value):
        """Replace the element(s) at index.

        Raises:
            ValueError: If a new value is already elsewhere in self, or (for
                slices) the new values contain duplicates.
            IndexError: If an int index is out of range.
        """
        if isinstance(index, slice):
            # Materialize: value is checked and then assigned, so an iterator
            # would otherwise be exhausted by the check.
            value = list(value)
            old_values = self[index]
            seen = set()
            for v in value:
                # Reject values that would collide with an element outside the
                # slice, or with each other.
                if v in seen or (v in self and v not in old_values):
                    raise ValueError
                seen.add(v)
            self._list[index] = value
            self._dict = {}
            for i, v in enumerate(self._list):
                self._dict[v] = i
        else:
            index = self._fix_neg_index(index)
            old_value = self._list[index]
            if value in self:
                if value == old_value:
                    return
                else:
                    raise ValueError
            del self._dict[old_value]
            self._list[index] = value
            self._dict[value] = index

    def __delitem__(self, index):
        if isinstance(index, slice):
            indices_to_delete = set(self._dict[e] for e in self._list[index])
            self._delete_values_by_index(indices_to_delete)
        else:
            index = self._fix_neg_index(index)
            value = self._list[index]
            del self._dict[value]
            # Elements after the deleted one shift left by one.
            for elem in self._list[index + 1:]:
                self._dict[elem] -= 1
            del self._list[index]

    def insert(self, index, value):
        """Insert value at index.

        As with list.insert, an index past the end appends.

        Args:
            index (int): Index to insert value at
            value: Value to insert
        Raises:
            ValueError: If value already in self
            IndexError: If index is negative beyond -len(self)
        """
        if value in self:
            raise ValueError
        index = self._fix_neg_index(index)
        # list.insert clamps to len(self); do the same so _dict records the
        # position the value actually ends up at.
        index = min(index, len(self))
        self._dict[value] = index
        for elem in self._list[index:]:
            self._dict[elem] += 1
        self._list.insert(index, value)

    def append(self, value):
        """Append value to the end.

        Args:
            value: Value to append
        Raises:
            ValueError: If value already in self
            TypeError: If value isn't hashable
        """
        self._append(value)

    def extend(self, values):
        """Append all values to the end.

        If any of the values are present, ValueError will
        be raised and none of the values will be appended.

        Args:
            values (Iterable): Values to append
        Raises:
            ValueError: If any values are already present or there are duplicates
                in the passed values.
            TypeError: If any of the values aren't hashable.
        """
        self._extend(values)

    def __iadd__(self, values):
        """Add all values to the end of self.

        Args:
            values (Iterable): Values to append
        Raises:
            TypeError: If values is not a setlist or frozensetlist
            ValueError: If any values are already present
        """
        self._check_type(values, '+=')
        self.extend(values)
        return self

    def remove(self, value):
        """Remove value from self.

        Args:
            value: Element to remove from self
        Raises:
            ValueError: if element is not present
        """
        try:
            index = self._dict[value]
        except KeyError:
            raise ValueError('Value "%s" is not present.' % str(value))
        else:
            del self[index]

    def remove_all(self, elems_to_delete):
        """Remove all elements from elems_to_delete, raises ValueErrors.

        See Also:
            discard_all
        Args:
            elems_to_delete (Iterable): Elements to remove.
        Raises:
            ValueError: If the count of any element is greater in
                elems_to_delete than self.
            TypeError: If any of the values aren't hashable.
        """
        self._delete_all(elems_to_delete, raise_errors=True)

    # Implement MutableSet

    def add(self, item):
        """Add an item.

        Note:
            This does not raise a ValueError for an already present value like
            append does. This is to match the behavior of set.add
        Args:
            item: Item to add
        Raises:
            TypeError: If item isn't hashable.
        """
        self._add(item)

    def update(self, values):
        """Add all values to the end.

        If any of the values are present, silently ignore
        them (as opposed to extend which raises an Error).

        See also:
            extend
        Args:
            values (Iterable): Values to add
        Raises:
            TypeError: If any of the values are unhashable.
        """
        self._update(values)

    def discard_all(self, elems_to_delete):
        """Discard all the elements from elems_to_delete.

        This is much faster than removing them one by one.
        This runs in O(len(self) + len(elems_to_delete))

        Args:
            elems_to_delete (Iterable): Elements to discard.
        Raises:
            TypeError: If any of the values aren't hashable.
        """
        self._delete_all(elems_to_delete, raise_errors=False)

    def discard(self, value):
        """Discard an item.

        Note:
            This does not raise a ValueError for a missing value like remove does.
            This is to match the behavior of set.discard
        """
        try:
            self.remove(value)
        except ValueError:
            pass

    def difference_update(self, other):
        """Remove from self every element that is also in other."""
        other = set(other)
        indices_to_delete = set()
        for i, elem in enumerate(self):
            if elem in other:
                indices_to_delete.add(i)
        if indices_to_delete:
            self._delete_values_by_index(indices_to_delete)

    def intersection_update(self, other):
        """Update self to include only the intersection with other."""
        other = set(other)
        indices_to_delete = set()
        for i, elem in enumerate(self):
            if elem not in other:
                indices_to_delete.add(i)
        if indices_to_delete:
            self._delete_values_by_index(indices_to_delete)

    def symmetric_difference_update(self, other):
        """Update self to include only the symmetric difference with other.

        Elements of other not in self are appended in other's order.
        """
        other = setlist(other)
        indices_to_delete = set()
        for i, item in enumerate(self):
            if item in other:
                indices_to_delete.add(i)
        # New items land after the existing ones, so the indices collected
        # above remain valid for the deletion below.
        for item in other:
            self.add(item)
        self._delete_values_by_index(indices_to_delete)

    def __isub__(self, other):
        self._check_type(other, '-=')
        self.difference_update(other)
        return self

    def __iand__(self, other):
        self._check_type(other, '&=')
        self.intersection_update(other)
        return self

    def __ior__(self, other):
        self._check_type(other, '|=')
        self.update(other)
        return self

    def __ixor__(self, other):
        self._check_type(other, '^=')
        self.symmetric_difference_update(other)
        return self

    # New methods

    def shuffle(self, random=None):
        """Shuffle all of the elements in self randomly.

        Args:
            random: Optional zero-argument callable returning a float in
                [0.0, 1.0), the contract of the ``random`` argument that
                ``random.shuffle`` accepted before Python 3.11.  When None,
                the ``random`` module's shuffle is used.
        """
        if random is None:
            random_.shuffle(self._list)
        else:
            # random.shuffle dropped its ``random`` parameter in Python 3.11
            # (passing it is a TypeError), so run the same Fisher-Yates pass
            # it used to perform with a caller-supplied function.
            items = self._list
            for i in reversed(range(1, len(items))):
                j = int(random() * (i + 1))
                items[i], items[j] = items[j], items[i]
        for i, elem in enumerate(self._list):
            self._dict[elem] = i

    def sort(self, *args, **kwargs):
        """Sort this setlist in place; arguments are passed to list.sort."""
        self._list.sort(*args, **kwargs)
        for index, value in enumerate(self._list):
            self._dict[value] = index

    def reverse(self):
        """Reverse the order of the elements in place."""
        # The MutableSequence mixin swaps elements through __setitem__, which
        # rejects assigning a value already present, so it always raised.
        self._list.reverse()
        for index, value in enumerate(self._list):
            self._dict[value] = index
 | 
						|
 | 
						|
 | 
						|
class frozensetlist(_basesetlist, Hashable):
    """An immutable (hashable) setlist."""

    def __hash__(self):
        """Return the hash of this frozensetlist.

        The value is computed by ``_util.hash_iterable`` on first use and
        stored on the instance as ``_hash_value``, so later calls just
        return the stored value.  frozensetlist exposes no public mutators,
        which is what makes caching acceptable.
        """
        try:
            return self._hash_value
        except AttributeError:
            # First call: compute once, remember, and return.
            computed = _util.hash_iterable(self)
            self._hash_value = computed
            return computed
 |