Use a proper multiset for progression items

This can cut generation times in half in some cases
This commit is contained in:
Kevin Cathcart 2019-07-13 18:17:16 -04:00
parent dcca15eda7
commit 6d5a0a004d
11 changed files with 1909 additions and 29 deletions

View File

@ -3,6 +3,7 @@ from enum import Enum, unique
import logging
import json
from collections import OrderedDict
from _vendor.collections_extended import bag, setlist
from Utils import int16_as_bytes
class World(object):
@ -135,34 +136,34 @@ class World(object):
if ret.has('Golden Sword', item.player):
pass
elif ret.has('Tempered Sword', item.player) and self.difficulty_requirements.progressive_sword_limit >= 4:
ret.prog_items.append(('Golden Sword', item.player))
ret.prog_items.add(('Golden Sword', item.player))
elif ret.has('Master Sword', item.player) and self.difficulty_requirements.progressive_sword_limit >= 3:
ret.prog_items.append(('Tempered Sword', item.player))
ret.prog_items.add(('Tempered Sword', item.player))
elif ret.has('Fighter Sword', item.player) and self.difficulty_requirements.progressive_sword_limit >= 2:
ret.prog_items.append(('Master Sword', item.player))
ret.prog_items.add(('Master Sword', item.player))
elif self.difficulty_requirements.progressive_sword_limit >= 1:
ret.prog_items.append(('Fighter Sword', item.player))
ret.prog_items.add(('Fighter Sword', item.player))
elif 'Glove' in item.name:
if ret.has('Titans Mitts', item.player):
pass
elif ret.has('Power Glove', item.player):
ret.prog_items.append(('Titans Mitts', item.player))
ret.prog_items.add(('Titans Mitts', item.player))
else:
ret.prog_items.append(('Power Glove', item.player))
ret.prog_items.add(('Power Glove', item.player))
elif 'Shield' in item.name:
if ret.has('Mirror Shield', item.player):
pass
elif ret.has('Red Shield', item.player) and self.difficulty_requirements.progressive_shield_limit >= 3:
ret.prog_items.append(('Mirror Shield', item.player))
ret.prog_items.add(('Mirror Shield', item.player))
elif ret.has('Blue Shield', item.player) and self.difficulty_requirements.progressive_shield_limit >= 2:
ret.prog_items.append(('Red Shield', item.player))
ret.prog_items.add(('Red Shield', item.player))
elif self.difficulty_requirements.progressive_shield_limit >= 1:
ret.prog_items.append(('Blue Shield', item.player))
ret.prog_items.add(('Blue Shield', item.player))
elif item.name.startswith('Bottle'):
if ret.bottle_count(item.player) < self.difficulty_requirements.progressive_bottle_limit:
ret.prog_items.append((item.name, item.player))
ret.prog_items.add((item.name, item.player))
elif item.advancement or item.key:
ret.prog_items.append((item.name, item.player))
ret.prog_items.add((item.name, item.player))
for item in self.itempool:
soft_collect(item)
@ -284,7 +285,7 @@ class World(object):
class CollectionState(object):
def __init__(self, parent):
self.prog_items = []
self.prog_items = bag()
self.world = parent
self.reachable_regions = {player: set() for player in range(1, parent.players + 1)}
self.events = []
@ -293,12 +294,13 @@ class CollectionState(object):
self.stale = {player: True for player in range(1, parent.players + 1)}
def update_reachable_regions(self, player):
player_regions = [region for region in self.world.regions if region.player == player]
self.stale[player] = False
rrp = self.reachable_regions[player]
new_regions = True
reachable_regions_count = len(rrp)
while new_regions:
possible = [region for region in self.world.regions if region not in rrp and region.player == player]
possible = [region for region in player_regions if region not in rrp]
for candidate in possible:
if candidate.can_reach_private(self):
rrp.add(candidate)
@ -307,7 +309,7 @@ class CollectionState(object):
def copy(self):
ret = CollectionState(self.world)
ret.prog_items = copy.copy(self.prog_items)
ret.prog_items = self.prog_items.copy()
ret.reachable_regions = {player: copy.copy(self.reachable_regions[player]) for player in range(1, self.world.players + 1)}
ret.events = copy.copy(self.events)
ret.path = copy.copy(self.path)
@ -343,17 +345,18 @@ class CollectionState(object):
self.collect(event.item, True, event)
new_locations = len(reachable_events) > checked_locations
checked_locations = len(reachable_events)
def has(self, item, player, count=1):
if count == 1:
return (item, player) in self.prog_items
return self.item_count(item, player) >= count
return self.prog_items.count((item, player)) >= count
def has_key(self, item, player, count=1):
if self.world.retro:
return self.can_buy_unlimited('Small Key (Universal)', player)
if count == 1:
return (item, player) in self.prog_items
return self.item_count(item, player) >= count
return self.prog_items.count((item, player)) >= count
def can_buy_unlimited(self, item, player):
for shop in self.world.shops:
@ -362,7 +365,7 @@ class CollectionState(object):
return False
def item_count(self, item, player):
return len([pritem for pritem in self.prog_items if pritem == (item, player)])
return self.prog_items.count((item, player))
def can_lift_rocks(self, player):
return self.has('Power Glove', player) or self.has('Titans Mitts', player)
@ -467,44 +470,44 @@ class CollectionState(object):
if self.has('Golden Sword', item.player):
pass
elif self.has('Tempered Sword', item.player) and self.world.difficulty_requirements.progressive_sword_limit >= 4:
self.prog_items.append(('Golden Sword', item.player))
self.prog_items.add(('Golden Sword', item.player))
changed = True
elif self.has('Master Sword', item.player) and self.world.difficulty_requirements.progressive_sword_limit >= 3:
self.prog_items.append(('Tempered Sword', item.player))
self.prog_items.add(('Tempered Sword', item.player))
changed = True
elif self.has('Fighter Sword', item.player) and self.world.difficulty_requirements.progressive_sword_limit >= 2:
self.prog_items.append(('Master Sword', item.player))
self.prog_items.add(('Master Sword', item.player))
changed = True
elif self.world.difficulty_requirements.progressive_sword_limit >= 1:
self.prog_items.append(('Fighter Sword', item.player))
self.prog_items.add(('Fighter Sword', item.player))
changed = True
elif 'Glove' in item.name:
if self.has('Titans Mitts', item.player):
pass
elif self.has('Power Glove', item.player):
self.prog_items.append(('Titans Mitts', item.player))
self.prog_items.add(('Titans Mitts', item.player))
changed = True
else:
self.prog_items.append(('Power Glove', item.player))
self.prog_items.add(('Power Glove', item.player))
changed = True
elif 'Shield' in item.name:
if self.has('Mirror Shield', item.player):
pass
elif self.has('Red Shield', item.player) and self.world.difficulty_requirements.progressive_shield_limit >= 3:
self.prog_items.append(('Mirror Shield', item.player))
self.prog_items.add(('Mirror Shield', item.player))
changed = True
elif self.has('Blue Shield', item.player) and self.world.difficulty_requirements.progressive_shield_limit >= 2:
self.prog_items.append(('Red Shield', item.player))
self.prog_items.add(('Red Shield', item.player))
changed = True
elif self.world.difficulty_requirements.progressive_shield_limit >= 1:
self.prog_items.append(('Blue Shield', item.player))
self.prog_items.add(('Blue Shield', item.player))
changed = True
elif item.name.startswith('Bottle'):
if self.bottle_count(item.player) < self.world.difficulty_requirements.progressive_bottle_limit:
self.prog_items.append((item.name, item.player))
self.prog_items.add((item.name, item.player))
changed = True
elif event or item.advancement:
self.prog_items.append((item.name, item.player))
self.prog_items.add((item.name, item.player))
changed = True
self.stale[item.player] = True

View File

@ -227,7 +227,7 @@ def copy_world(world):
ret.itempool.append(Item(item.name, item.advancement, item.priority, item.type, player = item.player))
# copy progress items in state
ret.state.prog_items = list(world.state.prog_items)
ret.state.prog_items = world.state.prog_items.copy()
ret.state.stale = {player: True for player in range(1, world.players + 1)}
for player in range(1, world.players + 1):

View File

@ -0,0 +1,5 @@
Mike Lenzen <m.lenzen@gmail.com> https://github.com/mlenzen
Caleb Levy <caleb.levy@berkeley.edu> https://github.com/caleblevy
Marein Könings <mail@marein.org> https://github.com/MareinK
Jad Kik <jadkik94@gmail.com> https://github.com/jadkik
Kuba Marek <blue.cube@seznam.cz> https://github.com/bluecube

View File

@ -0,0 +1,191 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and
distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright
owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities
that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or
indirect, to cause the direction or management of such entity, whether by
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising
permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including
but not limited to software source code, documentation source, and configuration
files.
"Object" form shall mean any form resulting from mechanical transformation or
translation of a Source form, including but not limited to compiled object code,
generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made
available under the License, as indicated by a copyright notice that is included
in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that
is based on (or derived from) the Work and for which the editorial revisions,
annotations, elaborations, or other modifications represent, as a whole, an
original work of authorship. For the purposes of this License, Derivative Works
shall not include works that remain separable from, or merely link (or bind by
name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version
of the Work and any modifications or additions to that Work or Derivative Works
thereof, that is intentionally submitted to Licensor for inclusion in the Work
by the copyright owner or by an individual or Legal Entity authorized to submit
on behalf of the copyright owner. For the purposes of this definition,
"submitted" means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems, and
issue tracking systems that are managed by, or on behalf of, the Licensor for
the purpose of discussing and improving the Work, but excluding communication
that is conspicuously marked or otherwise designated in writing by the copyright
owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
of whom a Contribution has been received by Licensor and subsequently
incorporated within the Work.
2. Grant of Copyright License.
Subject to the terms and conditions of this License, each Contributor hereby
grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the Work and such
Derivative Works in Source or Object form.
3. Grant of Patent License.
Subject to the terms and conditions of this License, each Contributor hereby
grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable (except as stated in this section) patent license to make, have
made, use, offer to sell, sell, import, and otherwise transfer the Work, where
such license applies only to those patent claims licensable by such Contributor
that are necessarily infringed by their Contribution(s) alone or by combination
of their Contribution(s) with the Work to which such Contribution(s) was
submitted. If You institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work or a
Contribution incorporated within the Work constitutes direct or contributory
patent infringement, then any patent licenses granted to You under this License
for that Work shall terminate as of the date such litigation is filed.
4. Redistribution.
You may reproduce and distribute copies of the Work or Derivative Works thereof
in any medium, with or without modifications, and in Source or Object form,
provided that You meet the following conditions:
You must give any other recipients of the Work or Derivative Works a copy of
this License; and
You must cause any modified files to carry prominent notices stating that You
changed the files; and
You must retain, in the Source form of any Derivative Works that You distribute,
all copyright, patent, trademark, and attribution notices from the Source form
of the Work, excluding those notices that do not pertain to any part of the
Derivative Works; and
If the Work includes a "NOTICE" text file as part of its distribution, then any
Derivative Works that You distribute must include a readable copy of the
attribution notices contained within such NOTICE file, excluding those notices
that do not pertain to any part of the Derivative Works, in at least one of the
following places: within a NOTICE text file distributed as part of the
Derivative Works; within the Source form or documentation, if provided along
with the Derivative Works; or, within a display generated by the Derivative
Works, if and wherever such third-party notices normally appear. The contents of
the NOTICE file are for informational purposes only and do not modify the
License. You may add Your own attribution notices within Derivative Works that
You distribute, alongside or as an addendum to the NOTICE text from the Work,
provided that such additional attribution notices cannot be construed as
modifying the License.
You may add Your own copyright statement to Your modifications and may provide
additional or different license terms and conditions for use, reproduction, or
distribution of Your modifications, or for any such Derivative Works as a whole,
provided Your use, reproduction, and distribution of the Work otherwise complies
with the conditions stated in this License.
5. Submission of Contributions.
Unless You explicitly state otherwise, any Contribution intentionally submitted
for inclusion in the Work by You to the Licensor shall be under the terms and
conditions of this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify the terms of
any separate license agreement you may have executed with Licensor regarding
such Contributions.
6. Trademarks.
This License does not grant permission to use the trade names, trademarks,
service marks, or product names of the Licensor, except as required for
reasonable and customary use in describing the origin of the Work and
reproducing the content of the NOTICE file.
7. Disclaimer of Warranty.
Unless required by applicable law or agreed to in writing, Licensor provides the
Work (and each Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
including, without limitation, any warranties or conditions of TITLE,
NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
solely responsible for determining the appropriateness of using or
redistributing the Work and assume any risks associated with Your exercise of
permissions under this License.
8. Limitation of Liability.
In no event and under no legal theory, whether in tort (including negligence),
contract, or otherwise, unless required by applicable law (such as deliberate
and grossly negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special, incidental,
or consequential damages of any character arising as a result of this License or
out of the use or inability to use the Work (including but not limited to
damages for loss of goodwill, work stoppage, computer failure or malfunction, or
any and all other commercial damages or losses), even if such Contributor has
been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability.
While redistributing the Work or Derivative Works thereof, You may choose to
offer, and charge a fee for, acceptance of support, warranty, indemnity, or
other liability obligations and/or rights consistent with this License. However,
in accepting such obligations, You may act only on Your own behalf and on Your
sole responsibility, not on behalf of any other Contributor, and only if You
agree to indemnify, defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason of your
accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work
To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets "[]" replaced with your own
identifying information. (Don't include the brackets!) The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included on
the same "printed page" as the copyright notice for easier identification within
third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -0,0 +1,55 @@
"""collections_extended contains a few extra basic data structures."""
from ._compat import Collection
from .bags import bag, frozenbag
from .setlists import setlist, frozensetlist
from .bijection import bijection
from .range_map import RangeMap, MappedRange
__version__ = '1.0.2'
__all__ = (
'collection',
'setlist',
'frozensetlist',
'bag',
'frozenbag',
'bijection',
'RangeMap',
'MappedRange',
'Collection',
)
def collection(iterable=None, mutable=True, ordered=False, unique=False):
"""Return a Collection with the specified properties.
Args:
iterable (Iterable): collection to instantiate new collection from.
mutable (bool): Whether or not the new collection is mutable.
ordered (bool): Whether or not the new collection is ordered.
unique (bool): Whether or not the new collection contains only unique values.
"""
if iterable is None:
iterable = tuple()
if unique:
if ordered:
if mutable:
return setlist(iterable)
else:
return frozensetlist(iterable)
else:
if mutable:
return set(iterable)
else:
return frozenset(iterable)
else:
if ordered:
if mutable:
return list(iterable)
else:
return tuple(iterable)
else:
if mutable:
return bag(iterable)
else:
return frozenbag(iterable)

View File

@ -0,0 +1,53 @@
"""Python 2/3 compatibility helpers."""
import sys
is_py2 = sys.version_info[0] == 2
if is_py2:
def keys_set(d):
"""Return a set of passed dictionary's keys."""
return set(d.keys())
else:
keys_set = dict.keys
if sys.version_info < (3, 6):
from collections import Sized, Iterable, Container
def _check_methods(C, *methods):
mro = C.__mro__
for method in methods:
for B in mro:
if method in B.__dict__:
if B.__dict__[method] is None:
return NotImplemented
break
else:
return NotImplemented
return True
class Collection(Sized, Iterable, Container):
"""Backport from Python3.6."""
__slots__ = tuple()
@classmethod
def __subclasshook__(cls, C):
if cls is Collection:
return _check_methods(C, "__len__", "__iter__", "__contains__")
return NotImplemented
else:
from collections.abc import Collection
def handle_rich_comp_not_implemented():
"""Correctly handle unimplemented rich comparisons.
In Python 3, return NotImplemented.
In Python 2, raise a TypeError.
"""
if is_py2:
raise TypeError()
else:
return NotImplemented

View File

@ -0,0 +1,16 @@
"""util functions for collections_extended."""
def hash_iterable(it):
"""Perform a O(1) memory hash of an iterable of arbitrary length.
hash(tuple(it)) creates a temporary tuple containing all values from it
which could be a problem if it is large.
See discussion at:
https://groups.google.com/forum/#!msg/python-ideas/XcuC01a8SYs/e-doB9TbDwAJ
"""
hash_value = hash(type(it))
for value in it:
hash_value = hash((hash_value, value))
return hash_value

View File

@ -0,0 +1,527 @@
"""Bag class definitions."""
import heapq
from operator import itemgetter
from collections import Set, MutableSet, Hashable
from . import _compat
class _basebag(Set):
"""Base class for bag classes.
Base class for bag and frozenbag. Is not mutable and not hashable, so there's
no reason to use this instead of either bag or frozenbag.
"""
# Basic object methods
def __init__(self, iterable=None):
"""Create a new basebag.
If iterable isn't given, is None or is empty then the bag starts empty.
Otherwise each element from iterable will be added to the bag
however many times it appears.
This runs in O(len(iterable))
"""
self._dict = dict()
self._size = 0
if iterable:
if isinstance(iterable, _basebag):
for elem, count in iterable._dict.items():
self._dict[elem] = count
self._size += count
else:
for value in iterable:
self._dict[value] = self._dict.get(value, 0) + 1
self._size += 1
def __repr__(self):
if self._size == 0:
return '{0}()'.format(self.__class__.__name__)
else:
repr_format = '{class_name}({values!r})'
return repr_format.format(
class_name=self.__class__.__name__,
values=tuple(self),
)
def __str__(self):
if self._size == 0:
return '{class_name}()'.format(class_name=self.__class__.__name__)
else:
format_single = '{elem!r}'
format_mult = '{elem!r}^{mult}'
strings = []
for elem, mult in self._dict.items():
if mult > 1:
strings.append(format_mult.format(elem=elem, mult=mult))
else:
strings.append(format_single.format(elem=elem))
return '{%s}' % ', '.join(strings)
# New public methods (not overriding/implementing anything)
def num_unique_elements(self):
"""Return the number of unique elements.
This runs in O(1) time
"""
return len(self._dict)
def unique_elements(self):
"""Return a view of unique elements in this bag.
In Python 3:
This runs in O(1) time and returns a view of the unique elements
In Python 2:
This runs in O(n) and returns set of the current elements.
"""
return _compat.keys_set(self._dict)
def count(self, value):
"""Return the number of value present in this bag.
If value is not in the bag no Error is raised, instead 0 is returned.
This runs in O(1) time
Args:
value: The element of self to get the count of
Returns:
int: The count of value in self
"""
return self._dict.get(value, 0)
def nlargest(self, n=None):
"""List the n most common elements and their counts.
List is from the most
common to the least. If n is None, the list all element counts.
Run time should be O(m log m) where m is len(self)
Args:
n (int): The number of elements to return
"""
if n is None:
return sorted(self._dict.items(), key=itemgetter(1), reverse=True)
else:
return heapq.nlargest(n, self._dict.items(), key=itemgetter(1))
@classmethod
def _from_iterable(cls, it):
return cls(it)
@classmethod
def from_mapping(cls, mapping):
"""Create a bag from a dict of elem->count.
Each key in the dict is added if the value is > 0.
"""
out = cls()
for elem, count in mapping.items():
if count > 0:
out._dict[elem] = count
out._size += count
return out
def copy(self):
"""Create a shallow copy of self.
This runs in O(len(self.num_unique_elements()))
"""
return self.from_mapping(self._dict)
# implementing Sized methods
def __len__(self):
"""Return the cardinality of the bag.
This runs in O(1)
"""
return self._size
# implementing Container methods
def __contains__(self, value):
"""Return the multiplicity of the element.
This runs in O(1)
"""
return self._dict.get(value, 0)
# implementing Iterable methods
def __iter__(self):
"""Iterate through all elements.
Multiple copies will be returned if they exist.
"""
for value, count in self._dict.items():
for i in range(count):
yield(value)
# Comparison methods
def _is_subset(self, other):
"""Check that every element in self has a count <= in other.
Args:
other (Set)
"""
if isinstance(other, _basebag):
for elem, count in self._dict.items():
if not count <= other._dict.get(elem, 0):
return False
else:
for elem in self:
if self._dict.get(elem, 0) > 1 or elem not in other:
return False
return True
def _is_superset(self, other):
"""Check that every element in self has a count >= in other.
Args:
other (Set)
"""
if isinstance(other, _basebag):
for elem, count in other._dict.items():
if not self._dict.get(elem, 0) >= count:
return False
else:
for elem in other:
if elem not in self:
return False
return True
def __le__(self, other):
if not isinstance(other, Set):
return _compat.handle_rich_comp_not_implemented()
return len(self) <= len(other) and self._is_subset(other)
def __lt__(self, other):
if not isinstance(other, Set):
return _compat.handle_rich_comp_not_implemented()
return len(self) < len(other) and self._is_subset(other)
def __gt__(self, other):
if not isinstance(other, Set):
return _compat.handle_rich_comp_not_implemented()
return len(self) > len(other) and self._is_superset(other)
def __ge__(self, other):
if not isinstance(other, Set):
return _compat.handle_rich_comp_not_implemented()
return len(self) >= len(other) and self._is_superset(other)
def __eq__(self, other):
if not isinstance(other, Set):
return False
if isinstance(other, _basebag):
return self._dict == other._dict
if not len(self) == len(other):
return False
for elem in other:
if self._dict.get(elem, 0) != 1:
return False
return True
def __ne__(self, other):
return not (self == other)
# Operations - &, |, +, -, ^, * and isdisjoint
def __and__(self, other):
"""Intersection is the minimum of corresponding counts.
This runs in O(l + n) where:
n is self.num_unique_elements()
if other is a bag:
l = 1
else:
l = len(other)
"""
if not isinstance(other, _basebag):
other = self._from_iterable(other)
values = dict()
for elem in self._dict:
values[elem] = min(other._dict.get(elem, 0), self._dict.get(elem, 0))
return self.from_mapping(values)
def isdisjoint(self, other):
"""Return if this bag is disjoint with the passed collection.
This runs in O(len(other))
TODO move isdisjoint somewhere more appropriate
"""
for value in other:
if value in self:
return False
return True
def __or__(self, other):
"""Union is the maximum of all elements.
This runs in O(m + n) where:
n is self.num_unique_elements()
if other is a bag:
m = other.num_unique_elements()
else:
m = len(other)
"""
if not isinstance(other, _basebag):
other = self._from_iterable(other)
values = dict()
for elem in self.unique_elements() | other.unique_elements():
values[elem] = max(self._dict.get(elem, 0), other._dict.get(elem, 0))
return self.from_mapping(values)
def __add__(self, other):
"""Return a new bag also containing all the elements of other.
self + other = self & other + self | other
This runs in O(m + n) where:
n is self.num_unique_elements()
m is len(other)
Args:
other (Iterable): elements to add to self
"""
out = self.copy()
for value in other:
out._dict[value] = out._dict.get(value, 0) + 1
out._size += 1
return out
def __sub__(self, other):
"""Difference between the sets.
For normal sets this is all x s.t. x in self and x not in other.
For bags this is count(x) = max(0, self.count(x)-other.count(x))
This runs in O(m + n) where:
n is self.num_unique_elements()
m is len(other)
Args:
other (Iterable): elements to remove
"""
out = self.copy()
for value in other:
old_count = out._dict.get(value, 0)
if old_count == 1:
del out._dict[value]
out._size -= 1
elif old_count > 1:
out._dict[value] = old_count - 1
out._size -= 1
return out
def __mul__(self, other):
"""Cartesian product of the two sets.
other can be any iterable.
Both self and other must contain elements that can be added together.
This should run in O(m*n+l) where:
m is the number of unique elements in self
n is the number of unique elements in other
if other is a bag:
l is 0
else:
l is the len(other)
The +l will only really matter when other is an iterable with MANY
repeated elements.
For example: {'a'^2} * 'bbbbbbbbbbbbbbbbbbbbbbbbbb'
The algorithm will be dominated by counting the 'b's
"""
if not isinstance(other, _basebag):
other = self._from_iterable(other)
values = dict()
for elem, count in self._dict.items():
for other_elem, other_count in other._dict.items():
new_elem = elem + other_elem
new_count = count * other_count
values[new_elem] = new_count
return self.from_mapping(values)
def __xor__(self, other):
"""Symmetric difference between the sets.
other can be any iterable.
This runs in O(m + n) where:
m = len(self)
n = len(other)
"""
return (self - other) | (other - self)
class bag(_basebag, MutableSet):
"""bag is a mutable unhashable bag."""
def pop(self):
"""Remove and return an element of self."""
# TODO can this be done more efficiently (no need to create an iterator)?
it = iter(self)
try:
value = next(it)
except StopIteration:
raise KeyError
self.discard(value)
return value
def add(self, elem):
"""Add elem to self."""
self._dict[elem] = self._dict.get(elem, 0) + 1
self._size += 1
def discard(self, elem):
"""Remove elem from this bag, silent if it isn't present."""
try:
self.remove(elem)
except ValueError:
pass
def remove(self, elem):
"""Remove elem from this bag, raising a ValueError if it isn't present.
Args:
elem: object to remove from self
Raises:
ValueError: if the elem isn't present
"""
old_count = self._dict.get(elem, 0)
if old_count == 0:
raise ValueError
elif old_count == 1:
del self._dict[elem]
else:
self._dict[elem] -= 1
self._size -= 1
def discard_all(self, other):
"""Discard all of the elems from other."""
if not isinstance(other, _basebag):
other = self._from_iterable(other)
for elem, other_count in other._dict.items():
old_count = self._dict.get(elem, 0)
new_count = old_count - other_count
if new_count >= 0:
if new_count == 0:
if elem in self:
del self._dict[elem]
else:
self._dict[elem] = new_count
self._size += new_count - old_count
def remove_all(self, other):
"""Remove all of the elems from other.
Raises a ValueError if the multiplicity of any elem in other is greater
than in self.
"""
if not self._is_superset(other):
raise ValueError
self.discard_all(other)
def clear(self):
"""Remove all elements from this bag."""
self._dict = dict()
self._size = 0
# In-place operations
def __ior__(self, other):
"""Set multiplicity of each element to the maximum of the two collections.
if isinstance(other, _basebag):
This runs in O(other.num_unique_elements())
else:
This runs in O(len(other))
"""
if not isinstance(other, _basebag):
other = self._from_iterable(other)
for elem, other_count in other._dict.items():
old_count = self._dict.get(elem, 0)
new_count = max(other_count, old_count)
self._dict[elem] = new_count
self._size += new_count - old_count
return self
def __iand__(self, other):
"""Set multiplicity of each element to the minimum of the two collections.
if isinstance(other, _basebag):
This runs in O(other.num_unique_elements())
else:
This runs in O(len(other))
"""
if not isinstance(other, _basebag):
other = self._from_iterable(other)
for elem, old_count in set(self._dict.items()):
other_count = other._dict.get(elem, 0)
new_count = min(other_count, old_count)
if new_count == 0:
del self._dict[elem]
else:
self._dict[elem] = new_count
self._size += new_count - old_count
return self
def __ixor__(self, other):
"""Set self to the symmetric difference between the sets.
if isinstance(other, _basebag):
This runs in O(other.num_unique_elements())
else:
This runs in O(len(other))
"""
if not isinstance(other, _basebag):
other = self._from_iterable(other)
other_minus_self = other - self
self -= other
self |= other_minus_self
return self
def __isub__(self, other):
"""Discard the elements of other from self.
if isinstance(it, _basebag):
This runs in O(it.num_unique_elements())
else:
This runs in O(len(it))
"""
self.discard_all(other)
return self
def __iadd__(self, other):
"""Add all of the elements of other to self.
if isinstance(it, _basebag):
This runs in O(it.num_unique_elements())
else:
This runs in O(len(it))
"""
if not isinstance(other, _basebag):
other = self._from_iterable(other)
for elem, other_count in other._dict.items():
self._dict[elem] = self._dict.get(elem, 0) + other_count
self._size += other_count
return self
class frozenbag(_basebag, Hashable):
"""frozenbag is an immutable, hashable bab."""
def __hash__(self):
"""Compute the hash value of a frozenbag.
This was copied directly from _collections_abc.Set._hash in Python3 which
is identical to _abcoll.Set._hash
We can't call it directly because Python2 raises a TypeError.
"""
if not hasattr(self, '_hash_value'):
self._hash_value = self._hash()
return self._hash_value

View File

@ -0,0 +1,94 @@
"""Class definition for bijection."""
from collections import MutableMapping, Mapping
class bijection(MutableMapping):
"""A one-to-one onto mapping, a dict with unique values."""
def __init__(self, iterable=None, **kwarg):
"""Create a bijection from an iterable.
Matches dict.__init__.
"""
self._data = {}
self.__inverse = self.__new__(bijection)
self.__inverse._data = {}
self.__inverse.__inverse = self
if iterable is not None:
if isinstance(iterable, Mapping):
for key, value in iterable.items():
self[key] = value
else:
for pair in iterable:
self[pair[0]] = pair[1]
for key, value in kwarg.items():
self[key] = value
def __repr__(self):
if len(self._data) == 0:
return '{0}()'.format(self.__class__.__name__)
else:
repr_format = '{class_name}({values!r})'
return repr_format.format(
class_name=self.__class__.__name__,
values=self._data,
)
@property
def inverse(self):
"""Return the inverse of this bijection."""
return self.__inverse
# Required for MutableMapping
def __len__(self):
return len(self._data)
# Required for MutableMapping
def __getitem__(self, key):
return self._data[key]
# Required for MutableMapping
def __setitem__(self, key, value):
if key in self:
del self.inverse._data[self[key]]
if value in self.inverse:
del self._data[self.inverse[value]]
self._data[key] = value
self.inverse._data[value] = key
# Required for MutableMapping
def __delitem__(self, key):
value = self._data.pop(key)
del self.inverse._data[value]
# Required for MutableMapping
def __iter__(self):
return iter(self._data)
def __contains__(self, key):
return key in self._data
def clear(self):
"""Remove everything from this bijection."""
self._data.clear()
self.inverse._data.clear()
def copy(self):
"""Return a copy of this bijection."""
return bijection(self)
def items(self):
"""See Mapping.items."""
return self._data.items()
def keys(self):
"""See Mapping.keys."""
return self._data.keys()
def values(self):
"""See Mapping.values."""
return self.inverse.keys()
def __eq__(self, other):
return isinstance(other, bijection) and self._data == other._data

View File

@ -0,0 +1,384 @@
"""RangeMap class definition."""
from bisect import bisect_left, bisect_right
from collections import namedtuple, Mapping, MappingView, Set
# Used to mark unmapped ranges
_empty = object()
MappedRange = namedtuple('MappedRange', ('start', 'stop', 'value'))
class KeysView(MappingView, Set):
"""A view of the keys that mark the starts of subranges.
Since iterating over all the keys is impossible, the KeysView only
contains the keys that start each subrange.
"""
__slots__ = ()
@classmethod
def _from_iterable(self, it):
return set(it)
def __contains__(self, key):
loc = self._mapping._bisect_left(key)
return self._mapping._keys[loc] == key and \
self._mapping._values[loc] is not _empty
def __iter__(self):
for item in self._mapping.ranges():
yield item.start
class ItemsView(MappingView, Set):
"""A view of the items that mark the starts of subranges.
Since iterating over all the keys is impossible, the ItemsView only
contains the items that start each subrange.
"""
__slots__ = ()
@classmethod
def _from_iterable(self, it):
return set(it)
def __contains__(self, item):
key, value = item
loc = self._mapping._bisect_left(key)
return self._mapping._keys[loc] == key and \
self._mapping._values[loc] == value
def __iter__(self):
for mapped_range in self._mapping.ranges():
yield (mapped_range.start, mapped_range.value)
class ValuesView(MappingView):
"""A view on the values of a Mapping."""
__slots__ = ()
def __contains__(self, value):
return value in self._mapping._values
def __iter__(self):
for value in self._mapping._values:
if value is not _empty:
yield value
def _check_start_stop(start, stop):
"""Check that start and stop are valid - orderable and in the right order.
Raises:
ValueError: if stop <= start
TypeError: if unorderable
"""
if start is not None and stop is not None and stop <= start:
raise ValueError('stop must be > start')
def _check_key_slice(key):
if not isinstance(key, slice):
raise TypeError('Can only set and delete slices')
if key.step is not None:
raise ValueError('Cannot set or delete slices with steps')
class RangeMap(Mapping):
"""Map ranges of orderable elements to values."""
def __init__(self, iterable=None, **kwargs):
"""Create a RangeMap.
A mapping or other iterable can be passed to initialize the RangeMap.
If mapping is passed, it is interpreted as a mapping from range start
indices to values.
If an iterable is passed, each element will define a range in the
RangeMap and should be formatted (start, stop, value).
default_value is a an optional keyword argument that will initialize the
entire RangeMap to that value. Any missing ranges will be mapped to that
value. However, if ranges are subsequently deleted they will be removed
and *not* mapped to the default_value.
Args:
iterable: A Mapping or an Iterable to initialize from.
default_value: If passed, the return value for all keys less than the
least key in mapping or missing ranges in iterable. If no mapping
or iterable, the return value for all keys.
"""
default_value = kwargs.pop('default_value', _empty)
if kwargs:
raise TypeError('Unknown keyword arguments: %s' % ', '.join(kwargs.keys()))
self._keys = [None]
self._values = [default_value]
if iterable:
if isinstance(iterable, Mapping):
self._init_from_mapping(iterable)
else:
self._init_from_iterable(iterable)
@classmethod
def from_mapping(cls, mapping):
"""Create a RangeMap from a mapping of interval starts to values."""
obj = cls()
obj._init_from_mapping(mapping)
return obj
def _init_from_mapping(self, mapping):
for key, value in sorted(mapping.items()):
self.set(value, key)
@classmethod
def from_iterable(cls, iterable):
"""Create a RangeMap from an iterable of tuples defining each range.
Each element of the iterable is a tuple (start, stop, value).
"""
obj = cls()
obj._init_from_iterable(iterable)
return obj
def _init_from_iterable(self, iterable):
for start, stop, value in iterable:
self.set(value, start=start, stop=stop)
def __str__(self):
range_format = '({range.start}, {range.stop}): {range.value}'
values = ', '.join([range_format.format(range=r) for r in self.ranges()])
return 'RangeMap(%s)' % values
def __repr__(self):
range_format = '({range.start!r}, {range.stop!r}, {range.value!r})'
values = ', '.join([range_format.format(range=r) for r in self.ranges()])
return 'RangeMap([%s])' % values
def _bisect_left(self, key):
"""Return the index of the key or the last key < key."""
if key is None:
return 0
else:
return bisect_left(self._keys, key, lo=1)
def _bisect_right(self, key):
"""Return the index of the first key > key."""
if key is None:
return 1
else:
return bisect_right(self._keys, key, lo=1)
def ranges(self, start=None, stop=None):
"""Generate MappedRanges for all mapped ranges.
Yields:
MappedRange
"""
_check_start_stop(start, stop)
start_loc = self._bisect_right(start)
if stop is None:
stop_loc = len(self._keys)
else:
stop_loc = self._bisect_left(stop)
start_val = self._values[start_loc - 1]
candidate_keys = [start] + self._keys[start_loc:stop_loc] + [stop]
candidate_values = [start_val] + self._values[start_loc:stop_loc]
for i, value in enumerate(candidate_values):
if value is not _empty:
start_key = candidate_keys[i]
stop_key = candidate_keys[i + 1]
yield MappedRange(start_key, stop_key, value)
def __contains__(self, value):
try:
self.__getitem(value) is not _empty
except KeyError:
return False
else:
return True
def __iter__(self):
for key, value in zip(self._keys, self._values):
if value is not _empty:
yield key
def __bool__(self):
if len(self._keys) > 1:
return True
else:
return self._values[0] != _empty
__nonzero__ = __bool__
def __getitem(self, key):
"""Get the value for a key (not a slice)."""
loc = self._bisect_right(key) - 1
value = self._values[loc]
if value is _empty:
raise KeyError(key)
else:
return value
def get(self, key, restval=None):
"""Get the value of the range containing key, otherwise return restval."""
try:
return self.__getitem(key)
except KeyError:
return restval
def get_range(self, start=None, stop=None):
"""Return a RangeMap for the range start to stop.
Returns:
A RangeMap
"""
return self.from_iterable(self.ranges(start, stop))
def set(self, value, start=None, stop=None):
"""Set the range from start to stop to value."""
_check_start_stop(start, stop)
# start_index, stop_index will denote the sections we are replacing
start_index = self._bisect_left(start)
if start is not None: # start_index == 0
prev_value = self._values[start_index - 1]
if prev_value == value:
# We're setting a range where the left range has the same
# value, so create one big range
start_index -= 1
start = self._keys[start_index]
if stop is None:
new_keys = [start]
new_values = [value]
stop_index = len(self._keys)
else:
stop_index = self._bisect_right(stop)
stop_value = self._values[stop_index - 1]
stop_key = self._keys[stop_index - 1]
if stop_key == stop and stop_value == value:
new_keys = [start]
new_values = [value]
else:
new_keys = [start, stop]
new_values = [value, stop_value]
self._keys[start_index:stop_index] = new_keys
self._values[start_index:stop_index] = new_values
def delete(self, start=None, stop=None):
"""Delete the range from start to stop from self.
Raises:
KeyError: If part of the passed range isn't mapped.
"""
_check_start_stop(start, stop)
start_loc = self._bisect_right(start) - 1
if stop is None:
stop_loc = len(self._keys)
else:
stop_loc = self._bisect_left(stop)
for value in self._values[start_loc:stop_loc]:
if value is _empty:
raise KeyError((start, stop))
# this is inefficient, we've already found the sub ranges
self.set(_empty, start=start, stop=stop)
def empty(self, start=None, stop=None):
"""Empty the range from start to stop.
Like delete, but no Error is raised if the entire range isn't mapped.
"""
self.set(_empty, start=start, stop=stop)
def clear(self):
"""Remove all elements."""
self._keys = [None]
self._values = [_empty]
@property
def start(self):
"""Get the start key of the first range.
None if RangeMap is empty or unbounded to the left.
"""
if self._values[0] is _empty:
try:
return self._keys[1]
except IndexError:
# This is empty or everything is mapped to a single value
return None
else:
# This is unbounded to the left
return self._keys[0]
@property
def end(self):
"""Get the stop key of the last range.
None if RangeMap is empty or unbounded to the right.
"""
if self._values[-1] is _empty:
return self._keys[-1]
else:
# This is unbounded to the right
return None
def __eq__(self, other):
if isinstance(other, RangeMap):
return (
self._keys == other._keys and
self._values == other._values
)
else:
return False
def __getitem__(self, key):
try:
_check_key_slice(key)
except TypeError:
return self.__getitem(key)
else:
return self.get_range(key.start, key.stop)
def __setitem__(self, key, value):
_check_key_slice(key)
self.set(value, key.start, key.stop)
def __delitem__(self, key):
_check_key_slice(key)
self.delete(key.start, key.stop)
def __len__(self):
count = 0
for v in self._values:
if v is not _empty:
count += 1
return count
def keys(self):
"""Return a view of the keys."""
return KeysView(self)
def values(self):
"""Return a view of the values."""
return ValuesView(self)
def items(self):
"""Return a view of the item pairs."""
return ItemsView(self)
# Python2 - override slice methods
def __setslice__(self, i, j, value):
"""Implement __setslice__ to override behavior in Python 2.
This is required because empty slices pass integers in python2 as opposed
to None in python 3.
"""
raise SyntaxError('Assigning slices doesn\t work in Python 2, use set')
def __delslice__(self, i, j):
raise SyntaxError('Deleting slices doesn\t work in Python 2, use delete')
def __getslice__(self, i, j):
raise SyntaxError('Getting slices doesn\t work in Python 2, use get_range.')

View File

@ -0,0 +1,552 @@
"""Setlist class definitions."""
import random as random_
from collections import (
Sequence,
Set,
MutableSequence,
MutableSet,
Hashable,
)
from . import _util
class _basesetlist(Sequence, Set):
"""A setlist is an ordered Collection of unique elements.
_basesetlist is the superclass of setlist and frozensetlist. It is immutable
and unhashable.
"""
def __init__(self, iterable=None, raise_on_duplicate=False):
"""Create a setlist.
Args:
iterable (Iterable): Values to initialize the setlist with.
"""
self._list = list()
self._dict = dict()
if iterable:
if raise_on_duplicate:
self._extend(iterable)
else:
self._update(iterable)
def __repr__(self):
if len(self) == 0:
return '{0}()'.format(self.__class__.__name__)
else:
repr_format = '{class_name}({values!r})'
return repr_format.format(
class_name=self.__class__.__name__,
values=tuple(self),
)
# Convenience methods
def _fix_neg_index(self, index):
if index < 0:
index += len(self)
if index < 0:
raise IndexError('index is out of range')
return index
def _fix_end_index(self, index):
if index is None:
return len(self)
else:
return self._fix_neg_index(index)
def _append(self, value):
# Checking value in self will check that value is Hashable
if value in self:
raise ValueError('Value "%s" already present' % str(value))
else:
self._dict[value] = len(self)
self._list.append(value)
def _extend(self, values):
new_values = set()
for value in values:
if value in new_values:
raise ValueError('New values contain duplicates')
elif value in self:
raise ValueError('New values contain elements already present in self')
else:
new_values.add(value)
for value in values:
self._dict[value] = len(self)
self._list.append(value)
def _add(self, item):
if item not in self:
self._dict[item] = len(self)
self._list.append(item)
def _update(self, values):
for value in values:
if value not in self:
self._dict[value] = len(self)
self._list.append(value)
@classmethod
def _from_iterable(cls, it, **kwargs):
return cls(it, **kwargs)
# Implement Container
def __contains__(self, value):
return value in self._dict
# Iterable we get by inheriting from Sequence
# Implement Sized
def __len__(self):
return len(self._list)
# Implement Sequence
def __getitem__(self, index):
if isinstance(index, slice):
return self._from_iterable(self._list[index])
return self._list[index]
def count(self, value):
"""Return the number of occurences of value in self.
This runs in O(1)
Args:
value: The value to count
Returns:
int: 1 if the value is in the setlist, otherwise 0
"""
if value in self:
return 1
else:
return 0
def index(self, value, start=0, end=None):
"""Return the index of value between start and end.
By default, the entire setlist is searched.
This runs in O(1)
Args:
value: The value to find the index of
start (int): The index to start searching at (defaults to 0)
end (int): The index to stop searching at (defaults to the end of the list)
Returns:
int: The index of the value
Raises:
ValueError: If the value is not in the list or outside of start - end
IndexError: If start or end are out of range
"""
try:
index = self._dict[value]
except KeyError:
raise ValueError
else:
start = self._fix_neg_index(start)
end = self._fix_end_index(end)
if start <= index and index < end:
return index
else:
raise ValueError
@classmethod
def _check_type(cls, other, operand_name):
if not isinstance(other, _basesetlist):
message = (
"unsupported operand type(s) for {operand_name}: "
"'{self_type}' and '{other_type}'").format(
operand_name=operand_name,
self_type=cls,
other_type=type(other),
)
raise TypeError(message)
def __add__(self, other):
self._check_type(other, '+')
out = self.copy()
out._extend(other)
return out
# Implement Set
def issubset(self, other):
return self <= other
def issuperset(self, other):
return self >= other
def union(self, other):
out = self.copy()
out.update(other)
return out
def intersection(self, other):
other = set(other)
return self._from_iterable(item for item in self if item in other)
def difference(self, other):
other = set(other)
return self._from_iterable(item for item in self if item not in other)
def symmetric_difference(self, other):
return self.union(other) - self.intersection(other)
def __sub__(self, other):
self._check_type(other, '-')
return self.difference(other)
def __and__(self, other):
self._check_type(other, '&')
return self.intersection(other)
def __or__(self, other):
self._check_type(other, '|')
return self.union(other)
def __xor__(self, other):
self._check_type(other, '^')
return self.symmetric_difference(other)
# Comparison
def __eq__(self, other):
if not isinstance(other, _basesetlist):
return False
if not len(self) == len(other):
return False
for self_elem, other_elem in zip(self, other):
if self_elem != other_elem:
return False
return True
def __ne__(self, other):
return not (self == other)
# New methods
def sub_index(self, sub, start=0, end=None):
"""Return the index of a subsequence.
This runs in O(len(sub))
Args:
sub (Sequence): An Iterable to search for
Returns:
int: The index of the first element of sub
Raises:
ValueError: If sub isn't a subsequence
TypeError: If sub isn't iterable
IndexError: If start or end are out of range
"""
start_index = self.index(sub[0], start, end)
end = self._fix_end_index(end)
if start_index + len(sub) > end:
raise ValueError
for i in range(1, len(sub)):
if sub[i] != self[start_index + i]:
raise ValueError
return start_index
def copy(self):
return self.__class__(self)
class setlist(_basesetlist, MutableSequence, MutableSet):
"""A mutable (unhashable) setlist."""
def __str__(self):
return '{[%s}]' % ', '.join(repr(v) for v in self)
# Helper methods
def _delete_all(self, elems_to_delete, raise_errors):
indices_to_delete = set()
for elem in elems_to_delete:
try:
elem_index = self._dict[elem]
except KeyError:
if raise_errors:
raise ValueError('Passed values contain elements not in self')
else:
if elem_index in indices_to_delete:
if raise_errors:
raise ValueError('Passed vales contain duplicates')
indices_to_delete.add(elem_index)
self._delete_values_by_index(indices_to_delete)
def _delete_values_by_index(self, indices_to_delete):
deleted_count = 0
for i, elem in enumerate(self._list):
if i in indices_to_delete:
deleted_count += 1
del self._dict[elem]
else:
new_index = i - deleted_count
self._list[new_index] = elem
self._dict[elem] = new_index
# Now remove deleted_count items from the end of the list
if deleted_count:
self._list = self._list[:-deleted_count]
# Set/Sequence agnostic
def pop(self, index=-1):
"""Remove and return the item at index."""
value = self._list.pop(index)
del self._dict[value]
return value
def clear(self):
"""Remove all elements from self."""
self._dict = dict()
self._list = list()
# Implement MutableSequence
def __setitem__(self, index, value):
if isinstance(index, slice):
old_values = self[index]
for v in value:
if v in self and v not in old_values:
raise ValueError
self._list[index] = value
self._dict = {}
for i, v in enumerate(self._list):
self._dict[v] = i
else:
index = self._fix_neg_index(index)
old_value = self._list[index]
if value in self:
if value == old_value:
return
else:
raise ValueError
del self._dict[old_value]
self._list[index] = value
self._dict[value] = index
def __delitem__(self, index):
if isinstance(index, slice):
indices_to_delete = set(self.index(e) for e in self._list[index])
self._delete_values_by_index(indices_to_delete)
else:
index = self._fix_neg_index(index)
value = self._list[index]
del self._dict[value]
for elem in self._list[index + 1:]:
self._dict[elem] -= 1
del self._list[index]
def insert(self, index, value):
"""Insert value at index.
Args:
index (int): Index to insert value at
value: Value to insert
Raises:
ValueError: If value already in self
IndexError: If start or end are out of range
"""
if value in self:
raise ValueError
index = self._fix_neg_index(index)
self._dict[value] = index
for elem in self._list[index:]:
self._dict[elem] += 1
self._list.insert(index, value)
def append(self, value):
"""Append value to the end.
Args:
value: Value to append
Raises:
ValueError: If value alread in self
TypeError: If value isn't hashable
"""
self._append(value)
def extend(self, values):
"""Append all values to the end.
If any of the values are present, ValueError will
be raised and none of the values will be appended.
Args:
values (Iterable): Values to append
Raises:
ValueError: If any values are already present or there are duplicates
in the passed values.
TypeError: If any of the values aren't hashable.
"""
self._extend(values)
def __iadd__(self, values):
"""Add all values to the end of self.
Args:
values (Iterable): Values to append
Raises:
ValueError: If any values are already present
"""
self._check_type(values, '+=')
self.extend(values)
return self
def remove(self, value):
"""Remove value from self.
Args:
value: Element to remove from self
Raises:
ValueError: if element is already present
"""
try:
index = self._dict[value]
except KeyError:
raise ValueError('Value "%s" is not present.')
else:
del self[index]
def remove_all(self, elems_to_delete):
"""Remove all elements from elems_to_delete, raises ValueErrors.
See Also:
discard_all
Args:
elems_to_delete (Iterable): Elements to remove.
Raises:
ValueError: If the count of any element is greater in
elems_to_delete than self.
TypeError: If any of the values aren't hashable.
"""
self._delete_all(elems_to_delete, raise_errors=True)
# Implement MutableSet
def add(self, item):
"""Add an item.
Note:
This does not raise a ValueError for an already present value like
append does. This is to match the behavior of set.add
Args:
item: Item to add
Raises:
TypeError: If item isn't hashable.
"""
self._add(item)
def update(self, values):
"""Add all values to the end.
If any of the values are present, silently ignore
them (as opposed to extend which raises an Error).
See also:
extend
Args:
values (Iterable): Values to add
Raises:
TypeError: If any of the values are unhashable.
"""
self._update(values)
def discard_all(self, elems_to_delete):
"""Discard all the elements from elems_to_delete.
This is much faster than removing them one by one.
This runs in O(len(self) + len(elems_to_delete))
Args:
elems_to_delete (Iterable): Elements to discard.
Raises:
TypeError: If any of the values aren't hashable.
"""
self._delete_all(elems_to_delete, raise_errors=False)
def discard(self, value):
"""Discard an item.
Note:
This does not raise a ValueError for a missing value like remove does.
This is to match the behavior of set.discard
"""
try:
self.remove(value)
except ValueError:
pass
def difference_update(self, other):
"""Update self to include only the differene with other."""
other = set(other)
indices_to_delete = set()
for i, elem in enumerate(self):
if elem in other:
indices_to_delete.add(i)
if indices_to_delete:
self._delete_values_by_index(indices_to_delete)
def intersection_update(self, other):
"""Update self to include only the intersection with other."""
other = set(other)
indices_to_delete = set()
for i, elem in enumerate(self):
if elem not in other:
indices_to_delete.add(i)
if indices_to_delete:
self._delete_values_by_index(indices_to_delete)
def symmetric_difference_update(self, other):
"""Update self to include only the symmetric difference with other."""
other = setlist(other)
indices_to_delete = set()
for i, item in enumerate(self):
if item in other:
indices_to_delete.add(i)
for item in other:
self.add(item)
self._delete_values_by_index(indices_to_delete)
def __isub__(self, other):
self._check_type(other, '-=')
self.difference_update(other)
return self
def __iand__(self, other):
self._check_type(other, '&=')
self.intersection_update(other)
return self
def __ior__(self, other):
self._check_type(other, '|=')
self.update(other)
return self
def __ixor__(self, other):
self._check_type(other, '^=')
self.symmetric_difference_update(other)
return self
# New methods
def shuffle(self, random=None):
"""Shuffle all of the elements in self randomly."""
random_.shuffle(self._list, random=random)
for i, elem in enumerate(self._list):
self._dict[elem] = i
def sort(self, *args, **kwargs):
"""Sort this setlist in place."""
self._list.sort(*args, **kwargs)
for index, value in enumerate(self._list):
self._dict[value] = index
class frozensetlist(_basesetlist, Hashable):
"""An immutable (hashable) setlist."""
def __hash__(self):
if not hasattr(self, '_hash_value'):
self._hash_value = _util.hash_iterable(self)
return self._hash_value