553 lines
13 KiB
Python
553 lines
13 KiB
Python
"""Setlist class definitions."""
|
|
import random as random_
|
|
|
|
from collections import (
|
|
Sequence,
|
|
Set,
|
|
MutableSequence,
|
|
MutableSet,
|
|
Hashable,
|
|
)
|
|
|
|
from . import _util
|
|
|
|
|
|
class _basesetlist(Sequence, Set):
|
|
"""A setlist is an ordered Collection of unique elements.
|
|
|
|
_basesetlist is the superclass of setlist and frozensetlist. It is immutable
|
|
and unhashable.
|
|
"""
|
|
|
|
def __init__(self, iterable=None, raise_on_duplicate=False):
|
|
"""Create a setlist.
|
|
|
|
Args:
|
|
iterable (Iterable): Values to initialize the setlist with.
|
|
"""
|
|
self._list = list()
|
|
self._dict = dict()
|
|
if iterable:
|
|
if raise_on_duplicate:
|
|
self._extend(iterable)
|
|
else:
|
|
self._update(iterable)
|
|
|
|
def __repr__(self):
|
|
if len(self) == 0:
|
|
return '{0}()'.format(self.__class__.__name__)
|
|
else:
|
|
repr_format = '{class_name}({values!r})'
|
|
return repr_format.format(
|
|
class_name=self.__class__.__name__,
|
|
values=tuple(self),
|
|
)
|
|
|
|
# Convenience methods
|
|
def _fix_neg_index(self, index):
|
|
if index < 0:
|
|
index += len(self)
|
|
if index < 0:
|
|
raise IndexError('index is out of range')
|
|
return index
|
|
|
|
def _fix_end_index(self, index):
|
|
if index is None:
|
|
return len(self)
|
|
else:
|
|
return self._fix_neg_index(index)
|
|
|
|
def _append(self, value):
|
|
# Checking value in self will check that value is Hashable
|
|
if value in self:
|
|
raise ValueError('Value "%s" already present' % str(value))
|
|
else:
|
|
self._dict[value] = len(self)
|
|
self._list.append(value)
|
|
|
|
def _extend(self, values):
|
|
new_values = set()
|
|
for value in values:
|
|
if value in new_values:
|
|
raise ValueError('New values contain duplicates')
|
|
elif value in self:
|
|
raise ValueError('New values contain elements already present in self')
|
|
else:
|
|
new_values.add(value)
|
|
for value in values:
|
|
self._dict[value] = len(self)
|
|
self._list.append(value)
|
|
|
|
def _add(self, item):
|
|
if item not in self:
|
|
self._dict[item] = len(self)
|
|
self._list.append(item)
|
|
|
|
def _update(self, values):
|
|
for value in values:
|
|
if value not in self:
|
|
self._dict[value] = len(self)
|
|
self._list.append(value)
|
|
|
|
@classmethod
|
|
def _from_iterable(cls, it, **kwargs):
|
|
return cls(it, **kwargs)
|
|
|
|
# Implement Container
|
|
def __contains__(self, value):
|
|
return value in self._dict
|
|
|
|
# Iterable we get by inheriting from Sequence
|
|
|
|
# Implement Sized
|
|
def __len__(self):
|
|
return len(self._list)
|
|
|
|
# Implement Sequence
|
|
def __getitem__(self, index):
|
|
if isinstance(index, slice):
|
|
return self._from_iterable(self._list[index])
|
|
return self._list[index]
|
|
|
|
def count(self, value):
|
|
"""Return the number of occurences of value in self.
|
|
|
|
This runs in O(1)
|
|
|
|
Args:
|
|
value: The value to count
|
|
Returns:
|
|
int: 1 if the value is in the setlist, otherwise 0
|
|
"""
|
|
if value in self:
|
|
return 1
|
|
else:
|
|
return 0
|
|
|
|
def index(self, value, start=0, end=None):
|
|
"""Return the index of value between start and end.
|
|
|
|
By default, the entire setlist is searched.
|
|
|
|
This runs in O(1)
|
|
|
|
Args:
|
|
value: The value to find the index of
|
|
start (int): The index to start searching at (defaults to 0)
|
|
end (int): The index to stop searching at (defaults to the end of the list)
|
|
Returns:
|
|
int: The index of the value
|
|
Raises:
|
|
ValueError: If the value is not in the list or outside of start - end
|
|
IndexError: If start or end are out of range
|
|
"""
|
|
try:
|
|
index = self._dict[value]
|
|
except KeyError:
|
|
raise ValueError
|
|
else:
|
|
start = self._fix_neg_index(start)
|
|
end = self._fix_end_index(end)
|
|
if start <= index and index < end:
|
|
return index
|
|
else:
|
|
raise ValueError
|
|
|
|
@classmethod
|
|
def _check_type(cls, other, operand_name):
|
|
if not isinstance(other, _basesetlist):
|
|
message = (
|
|
"unsupported operand type(s) for {operand_name}: "
|
|
"'{self_type}' and '{other_type}'").format(
|
|
operand_name=operand_name,
|
|
self_type=cls,
|
|
other_type=type(other),
|
|
)
|
|
raise TypeError(message)
|
|
|
|
def __add__(self, other):
|
|
self._check_type(other, '+')
|
|
out = self.copy()
|
|
out._extend(other)
|
|
return out
|
|
|
|
# Implement Set
|
|
|
|
def issubset(self, other):
|
|
return self <= other
|
|
|
|
def issuperset(self, other):
|
|
return self >= other
|
|
|
|
def union(self, other):
|
|
out = self.copy()
|
|
out.update(other)
|
|
return out
|
|
|
|
def intersection(self, other):
|
|
other = set(other)
|
|
return self._from_iterable(item for item in self if item in other)
|
|
|
|
def difference(self, other):
|
|
other = set(other)
|
|
return self._from_iterable(item for item in self if item not in other)
|
|
|
|
def symmetric_difference(self, other):
|
|
return self.union(other) - self.intersection(other)
|
|
|
|
def __sub__(self, other):
|
|
self._check_type(other, '-')
|
|
return self.difference(other)
|
|
|
|
def __and__(self, other):
|
|
self._check_type(other, '&')
|
|
return self.intersection(other)
|
|
|
|
def __or__(self, other):
|
|
self._check_type(other, '|')
|
|
return self.union(other)
|
|
|
|
def __xor__(self, other):
|
|
self._check_type(other, '^')
|
|
return self.symmetric_difference(other)
|
|
|
|
# Comparison
|
|
|
|
def __eq__(self, other):
|
|
if not isinstance(other, _basesetlist):
|
|
return False
|
|
if not len(self) == len(other):
|
|
return False
|
|
for self_elem, other_elem in zip(self, other):
|
|
if self_elem != other_elem:
|
|
return False
|
|
return True
|
|
|
|
def __ne__(self, other):
|
|
return not (self == other)
|
|
|
|
# New methods
|
|
|
|
def sub_index(self, sub, start=0, end=None):
|
|
"""Return the index of a subsequence.
|
|
|
|
This runs in O(len(sub))
|
|
|
|
Args:
|
|
sub (Sequence): An Iterable to search for
|
|
Returns:
|
|
int: The index of the first element of sub
|
|
Raises:
|
|
ValueError: If sub isn't a subsequence
|
|
TypeError: If sub isn't iterable
|
|
IndexError: If start or end are out of range
|
|
"""
|
|
start_index = self.index(sub[0], start, end)
|
|
end = self._fix_end_index(end)
|
|
if start_index + len(sub) > end:
|
|
raise ValueError
|
|
for i in range(1, len(sub)):
|
|
if sub[i] != self[start_index + i]:
|
|
raise ValueError
|
|
return start_index
|
|
|
|
def copy(self):
|
|
return self.__class__(self)
|
|
|
|
|
|
class setlist(_basesetlist, MutableSequence, MutableSet):
    """A mutable (unhashable) setlist.

    Every mutator keeps ``self._list`` (the ordered elements) and
    ``self._dict`` (element -> position in ``self._list``) consistent.
    """

    def __str__(self):
        return '{[%s}]' % ', '.join(repr(v) for v in self)

    # Helper methods
    def _delete_all(self, elems_to_delete, raise_errors):
        """Delete every element of elems_to_delete from self in one pass.

        Args:
            elems_to_delete (Iterable): Elements to delete.
            raise_errors (bool): If true, a missing or repeated element
                raises ValueError; otherwise it is ignored.
        """
        indices_to_delete = set()
        for elem in elems_to_delete:
            try:
                elem_index = self._dict[elem]
            except KeyError:
                if raise_errors:
                    raise ValueError('Passed values contain elements not in self')
            else:
                if elem_index in indices_to_delete:
                    if raise_errors:
                        raise ValueError('Passed values contain duplicates')
                indices_to_delete.add(elem_index)
        self._delete_values_by_index(indices_to_delete)

    def _delete_values_by_index(self, indices_to_delete):
        """Delete the elements at the given positions and re-index the rest."""
        deleted_count = 0
        for i, elem in enumerate(self._list):
            if i in indices_to_delete:
                deleted_count += 1
                del self._dict[elem]
            else:
                # Shift the survivor left over the holes seen so far.
                new_index = i - deleted_count
                self._list[new_index] = elem
                self._dict[elem] = new_index
        # Now remove deleted_count items from the end of the list
        if deleted_count:
            self._list = self._list[:-deleted_count]

    # Set/Sequence agnostic
    def pop(self, index=-1):
        """Remove and return the item at index.

        Args:
            index (int): Position of the item to remove (defaults to the
                last item).
        Raises:
            IndexError: If self is empty or index is out of range.
        """
        value = self._list.pop(index)
        del self._dict[value]
        # The list is one shorter now, so a negative index refers to
        # position index + old_length == index + len(self._list) + 1.
        position = index if index >= 0 else index + len(self._list) + 1
        # Everything after the removed item moved one place to the left.
        for i in range(position, len(self._list)):
            self._dict[self._list[i]] = i
        return value

    def clear(self):
        """Remove all elements from self."""
        self._dict = dict()
        self._list = list()

    # Implement MutableSequence
    def __setitem__(self, index, value):
        """Replace the item (or slice) at index.

        Raises:
            ValueError: If a new value is already present outside of the
                replaced position(s), or the new slice values contain
                duplicates.
            IndexError: If an integer index is out of range.
        """
        if isinstance(index, slice):
            # Materialize once: value is walked for validation and then
            # assigned, which would exhaust a one-shot iterator.
            new_values = list(value)
            old_values = self[index]
            seen = set()
            for v in new_values:
                if v in seen:
                    raise ValueError('New values contain duplicates')
                if v in self and v not in old_values:
                    raise ValueError(
                        'New values contain elements already present in self'
                    )
                seen.add(v)
            self._list[index] = new_values
            # Positions may have shifted arbitrarily; rebuild the mapping.
            self._dict = {}
            for i, v in enumerate(self._list):
                self._dict[v] = i
        else:
            index = self._fix_neg_index(index)
            old_value = self._list[index]
            if value in self:
                if value == old_value:
                    return
                raise ValueError('Value "%s" already present' % str(value))
            del self._dict[old_value]
            self._list[index] = value
            self._dict[value] = index

    def __delitem__(self, index):
        if isinstance(index, slice):
            # slice.indices yields exactly the positions the slice selects.
            indices_to_delete = set(range(*index.indices(len(self._list))))
            self._delete_values_by_index(indices_to_delete)
        else:
            index = self._fix_neg_index(index)
            value = self._list[index]
            del self._dict[value]
            for elem in self._list[index + 1:]:
                self._dict[elem] -= 1
            del self._list[index]

    def insert(self, index, value):
        """Insert value at index.

        An index past the end inserts at the end, like list.insert.

        Args:
            index (int): Index to insert value at
            value: Value to insert
        Raises:
            ValueError: If value already in self
            IndexError: If index is negative and reaches before the
                beginning of the setlist
        """
        if value in self:
            raise ValueError('Value "%s" already present' % str(value))
        index = self._fix_neg_index(index)
        # list.insert clamps large indices; record the clamped position so
        # the mapping matches where the value really ends up.
        if index > len(self):
            index = len(self)
        self._dict[value] = index
        for elem in self._list[index:]:
            self._dict[elem] += 1
        self._list.insert(index, value)

    def append(self, value):
        """Append value to the end.

        Args:
            value: Value to append
        Raises:
            ValueError: If value already in self
            TypeError: If value isn't hashable
        """
        self._append(value)

    def extend(self, values):
        """Append all values to the end.

        If any of the values are present, ValueError will
        be raised and none of the values will be appended.

        Args:
            values (Iterable): Values to append
        Raises:
            ValueError: If any values are already present or there are
                duplicates in the passed values.
            TypeError: If any of the values aren't hashable.
        """
        self._extend(values)

    def __iadd__(self, values):
        """Add all values to the end of self.

        Args:
            values (Iterable): Values to append
        Raises:
            ValueError: If any values are already present
            TypeError: If values is not a setlist
        """
        self._check_type(values, '+=')
        self.extend(values)
        return self

    def remove(self, value):
        """Remove value from self.

        Args:
            value: Element to remove from self
        Raises:
            ValueError: If value is not present
        """
        try:
            index = self._dict[value]
        except KeyError:
            raise ValueError('Value "%s" is not present.' % str(value))
        else:
            del self[index]

    def remove_all(self, elems_to_delete):
        """Remove all elements from elems_to_delete, raises ValueErrors.

        See Also:
            discard_all
        Args:
            elems_to_delete (Iterable): Elements to remove.
        Raises:
            ValueError: If the count of any element is greater in
                elems_to_delete than self.
            TypeError: If any of the values aren't hashable.
        """
        self._delete_all(elems_to_delete, raise_errors=True)

    # Implement MutableSet

    def add(self, item):
        """Add an item.

        Note:
            This does not raise a ValueError for an already present value
            like append does. This is to match the behavior of set.add
        Args:
            item: Item to add
        Raises:
            TypeError: If item isn't hashable.
        """
        self._add(item)

    def update(self, values):
        """Add all values to the end.

        If any of the values are present, silently ignore
        them (as opposed to extend which raises an Error).

        See also:
            extend
        Args:
            values (Iterable): Values to add
        Raises:
            TypeError: If any of the values are unhashable.
        """
        self._update(values)

    def discard_all(self, elems_to_delete):
        """Discard all the elements from elems_to_delete.

        This is much faster than removing them one by one.
        This runs in O(len(self) + len(elems_to_delete))

        Args:
            elems_to_delete (Iterable): Elements to discard.
        Raises:
            TypeError: If any of the values aren't hashable.
        """
        self._delete_all(elems_to_delete, raise_errors=False)

    def discard(self, value):
        """Discard an item.

        Note:
            This does not raise a ValueError for a missing value like
            remove does. This is to match the behavior of set.discard
        """
        try:
            self.remove(value)
        except ValueError:
            pass

    def difference_update(self, other):
        """Update self to include only the difference with other."""
        other = set(other)
        indices_to_delete = set()
        for i, elem in enumerate(self):
            if elem in other:
                indices_to_delete.add(i)
        if indices_to_delete:
            self._delete_values_by_index(indices_to_delete)

    def intersection_update(self, other):
        """Update self to include only the intersection with other."""
        other = set(other)
        indices_to_delete = set()
        for i, elem in enumerate(self):
            if elem not in other:
                indices_to_delete.add(i)
        if indices_to_delete:
            self._delete_values_by_index(indices_to_delete)

    def symmetric_difference_update(self, other):
        """Update self to include only the symmetric difference with other."""
        other = setlist(other)
        indices_to_delete = set()
        # Record the shared elements first; their positions stay valid while
        # new elements are appended because appending never moves old ones.
        for i, item in enumerate(self):
            if item in other:
                indices_to_delete.add(i)
        for item in other:
            self.add(item)
        self._delete_values_by_index(indices_to_delete)

    def __isub__(self, other):
        self._check_type(other, '-=')
        self.difference_update(other)
        return self

    def __iand__(self, other):
        self._check_type(other, '&=')
        self.intersection_update(other)
        return self

    def __ior__(self, other):
        self._check_type(other, '|=')
        self.update(other)
        return self

    def __ixor__(self, other):
        self._check_type(other, '^=')
        self.symmetric_difference_update(other)
        return self

    # New methods
    def shuffle(self, random=None):
        """Shuffle all of the elements in self randomly.

        Args:
            random (callable): Optional zero-argument function returning a
                float in [0.0, 1.0). Defaults to the random module's
                generator.
        """
        if random is None:
            random_.shuffle(self._list)
        else:
            # random.shuffle no longer accepts a ``random`` argument on
            # recent Python versions, so apply the Fisher-Yates shuffle
            # driven by the caller's function directly.
            items = self._list
            for i in reversed(range(1, len(items))):
                # pick an element in items[:i+1] to exchange with items[i]
                j = int(random() * (i + 1))
                items[i], items[j] = items[j], items[i]
        for i, elem in enumerate(self._list):
            self._dict[elem] = i

    def sort(self, *args, **kwargs):
        """Sort this setlist in place.

        All arguments are passed through to list.sort.
        """
        self._list.sort(*args, **kwargs)
        for index, value in enumerate(self._list):
            self._dict[value] = index
|
|
|
|
|
|
class frozensetlist(_basesetlist, Hashable):
    """An immutable (hashable) setlist."""

    def __hash__(self):
        """Return the hash of self, computing it on first use.

        The value comes from ``_util.hash_iterable(self)`` and is cached on
        the instance as ``_hash_value`` so later calls reuse it.
        """
        try:
            return self._hash_value
        except AttributeError:
            # First call on this instance: compute and remember the hash.
            self._hash_value = _util.hash_iterable(self)
            return self._hash_value
|