From 6d5a0a004debc31ddb19690b5620cae8c4bb79dc Mon Sep 17 00:00:00 2001 From: Kevin Cathcart Date: Sat, 13 Jul 2019 18:17:16 -0400 Subject: [PATCH] Use a proper multiset for progression items This can cut generation times in half in some cases --- BaseClasses.py | 59 +-- Main.py | 2 +- _vendor/collections_extended/CONTRIBUTERS | 5 + _vendor/collections_extended/LICENSE | 191 ++++++++ _vendor/collections_extended/__init__.py | 55 +++ _vendor/collections_extended/_compat.py | 53 +++ _vendor/collections_extended/_util.py | 16 + _vendor/collections_extended/bags.py | 527 +++++++++++++++++++++ _vendor/collections_extended/bijection.py | 94 ++++ _vendor/collections_extended/range_map.py | 384 +++++++++++++++ _vendor/collections_extended/setlists.py | 552 ++++++++++++++++++++++ 11 files changed, 1909 insertions(+), 29 deletions(-) create mode 100644 _vendor/collections_extended/CONTRIBUTERS create mode 100644 _vendor/collections_extended/LICENSE create mode 100644 _vendor/collections_extended/__init__.py create mode 100644 _vendor/collections_extended/_compat.py create mode 100644 _vendor/collections_extended/_util.py create mode 100644 _vendor/collections_extended/bags.py create mode 100644 _vendor/collections_extended/bijection.py create mode 100644 _vendor/collections_extended/range_map.py create mode 100644 _vendor/collections_extended/setlists.py diff --git a/BaseClasses.py b/BaseClasses.py index f41ed2b9..29b33a88 100644 --- a/BaseClasses.py +++ b/BaseClasses.py @@ -3,6 +3,7 @@ from enum import Enum, unique import logging import json from collections import OrderedDict +from _vendor.collections_extended import bag, setlist from Utils import int16_as_bytes class World(object): @@ -135,34 +136,34 @@ class World(object): if ret.has('Golden Sword', item.player): pass elif ret.has('Tempered Sword', item.player) and self.difficulty_requirements.progressive_sword_limit >= 4: - ret.prog_items.append(('Golden Sword', item.player)) + ret.prog_items.add(('Golden Sword', item.player)) elif ret.has('Master Sword', item.player) and self.difficulty_requirements.progressive_sword_limit >= 3: - ret.prog_items.append(('Tempered Sword', item.player)) + ret.prog_items.add(('Tempered Sword', item.player)) elif ret.has('Fighter Sword', item.player) and self.difficulty_requirements.progressive_sword_limit >= 2: - ret.prog_items.append(('Master Sword', item.player)) + ret.prog_items.add(('Master Sword', item.player)) elif self.difficulty_requirements.progressive_sword_limit >= 1: - ret.prog_items.append(('Fighter Sword', item.player)) + ret.prog_items.add(('Fighter Sword', item.player)) elif 'Glove' in item.name: if ret.has('Titans Mitts', item.player): pass elif ret.has('Power Glove', item.player): - ret.prog_items.append(('Titans Mitts', item.player)) + ret.prog_items.add(('Titans Mitts', item.player)) else: - ret.prog_items.append(('Power Glove', item.player)) + ret.prog_items.add(('Power Glove', item.player)) elif 'Shield' in item.name: if ret.has('Mirror Shield', item.player): pass elif ret.has('Red Shield', item.player) and self.difficulty_requirements.progressive_shield_limit >= 3: - ret.prog_items.append(('Mirror Shield', item.player)) + ret.prog_items.add(('Mirror Shield', item.player)) elif ret.has('Blue Shield', item.player) and self.difficulty_requirements.progressive_shield_limit >= 2: - ret.prog_items.append(('Red Shield', item.player)) + ret.prog_items.add(('Red Shield', item.player)) elif self.difficulty_requirements.progressive_shield_limit >= 1: - ret.prog_items.append(('Blue Shield', item.player)) + ret.prog_items.add(('Blue Shield', item.player)) elif item.name.startswith('Bottle'): if ret.bottle_count(item.player) < self.difficulty_requirements.progressive_bottle_limit: - ret.prog_items.append((item.name, item.player)) + ret.prog_items.add((item.name, item.player)) elif item.advancement or item.key: - ret.prog_items.append((item.name, item.player)) + ret.prog_items.add((item.name, item.player)) for item in self.itempool: soft_collect(item) @@ -284,7 +285,7 @@ class World(object): class CollectionState(object): def __init__(self, parent): - self.prog_items = [] + self.prog_items = bag() self.world = parent self.reachable_regions = {player: set() for player in range(1, parent.players + 1)} self.events = [] @@ -293,12 +294,13 @@ class CollectionState(object): self.stale = {player: True for player in range(1, parent.players + 1)} def update_reachable_regions(self, player): + player_regions = [region for region in self.world.regions if region.player == player] self.stale[player] = False rrp = self.reachable_regions[player] new_regions = True reachable_regions_count = len(rrp) while new_regions: - possible = [region for region in self.world.regions if region not in rrp and region.player == player] + possible = [region for region in player_regions if region not in rrp] for candidate in possible: if candidate.can_reach_private(self): rrp.add(candidate) @@ -307,7 +309,7 @@ class CollectionState(object): def copy(self): ret = CollectionState(self.world) - ret.prog_items = copy.copy(self.prog_items) + ret.prog_items = self.prog_items.copy() ret.reachable_regions = {player: copy.copy(self.reachable_regions[player]) for player in range(1, self.world.players + 1)} ret.events = copy.copy(self.events) ret.path = copy.copy(self.path) @@ -343,17 +345,18 @@ class CollectionState(object): self.collect(event.item, True, event) new_locations = len(reachable_events) > checked_locations checked_locations = len(reachable_events) + def has(self, item, player, count=1): if count == 1: return (item, player) in self.prog_items - return self.item_count(item, player) >= count + return self.prog_items.count((item, player)) >= count def has_key(self, item, player, count=1): if self.world.retro: return self.can_buy_unlimited('Small Key (Universal)', player) if count == 1: return (item, player) in self.prog_items - return self.item_count(item, player) >= count + return self.prog_items.count((item, player)) >= count def can_buy_unlimited(self, item, player): for shop in self.world.shops: @@ -362,7 +365,7 @@ class CollectionState(object): return False def item_count(self, item, player): - return len([pritem for pritem in self.prog_items if pritem == (item, player)]) + return self.prog_items.count((item, player)) def can_lift_rocks(self, player): return self.has('Power Glove', player) or self.has('Titans Mitts', player) @@ -467,44 +470,44 @@ class CollectionState(object): if self.has('Golden Sword', item.player): pass elif self.has('Tempered Sword', item.player) and self.world.difficulty_requirements.progressive_sword_limit >= 4: - self.prog_items.append(('Golden Sword', item.player)) + self.prog_items.add(('Golden Sword', item.player)) changed = True elif self.has('Master Sword', item.player) and self.world.difficulty_requirements.progressive_sword_limit >= 3: - self.prog_items.append(('Tempered Sword', item.player)) + self.prog_items.add(('Tempered Sword', item.player)) changed = True elif self.has('Fighter Sword', item.player) and self.world.difficulty_requirements.progressive_sword_limit >= 2: - self.prog_items.append(('Master Sword', item.player)) + self.prog_items.add(('Master Sword', item.player)) changed = True elif self.world.difficulty_requirements.progressive_sword_limit >= 1: - self.prog_items.append(('Fighter Sword', item.player)) + self.prog_items.add(('Fighter Sword', item.player)) changed = True elif 'Glove' in item.name: if self.has('Titans Mitts', item.player): pass elif self.has('Power Glove', item.player): - self.prog_items.append(('Titans Mitts', item.player)) + self.prog_items.add(('Titans Mitts', item.player)) changed = True else: - self.prog_items.append(('Power Glove', item.player)) + self.prog_items.add(('Power Glove', item.player)) changed = True elif 'Shield' in item.name: if self.has('Mirror Shield', item.player): pass elif self.has('Red Shield', item.player) and self.world.difficulty_requirements.progressive_shield_limit >= 3: - self.prog_items.append(('Mirror Shield', item.player)) + self.prog_items.add(('Mirror Shield', item.player)) changed = True elif self.has('Blue Shield', item.player) and self.world.difficulty_requirements.progressive_shield_limit >= 2: - self.prog_items.append(('Red Shield', item.player)) + self.prog_items.add(('Red Shield', item.player)) changed = True elif self.world.difficulty_requirements.progressive_shield_limit >= 1: - self.prog_items.append(('Blue Shield', item.player)) + self.prog_items.add(('Blue Shield', item.player)) changed = True elif item.name.startswith('Bottle'): if self.bottle_count(item.player) < self.world.difficulty_requirements.progressive_bottle_limit: - self.prog_items.append((item.name, item.player)) + self.prog_items.add((item.name, item.player)) changed = True elif event or item.advancement: - self.prog_items.append((item.name, item.player)) + self.prog_items.add((item.name, item.player)) changed = True self.stale[item.player] = True diff --git a/Main.py b/Main.py index c2152e55..bef644a9 100644 --- a/Main.py +++ b/Main.py @@ -227,7 +227,7 @@ def copy_world(world): ret.itempool.append(Item(item.name, item.advancement, item.priority, item.type, player = item.player)) # copy progress items in state - ret.state.prog_items = list(world.state.prog_items) + ret.state.prog_items = world.state.prog_items.copy() ret.state.stale = {player: True for player in range(1, world.players + 1)} for player in range(1, world.players + 1): diff --git a/_vendor/collections_extended/CONTRIBUTERS b/_vendor/collections_extended/CONTRIBUTERS new file mode 100644 index 00000000..333a02bf --- /dev/null +++ b/_vendor/collections_extended/CONTRIBUTERS @@ -0,0 +1,5 @@ +Mike Lenzen https://github.com/mlenzen +Caleb Levy https://github.com/caleblevy +Marein Könings https://github.com/MareinK +Jad Kik https://github.com/jadkik +Kuba Marek https://github.com/bluecube \ No newline at end of file diff --git a/_vendor/collections_extended/LICENSE b/_vendor/collections_extended/LICENSE new file mode 100644 index 00000000..8405e89a --- /dev/null +++ b/_vendor/collections_extended/LICENSE @@ -0,0 +1,191 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and +distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright +owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities +that control, are controlled by, or are under common control with that entity. +For the purposes of this definition, "control" means (i) the power, direct or +indirect, to cause the direction or management of such entity, whether by +contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the +outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising +permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including +but not limited to software source code, documentation source, and configuration +files. + +"Object" form shall mean any form resulting from mechanical transformation or +translation of a Source form, including but not limited to compiled object code, +generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made +available under the License, as indicated by a copyright notice that is included +in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that +is based on (or derived from) the Work and for which the editorial revisions, +annotations, elaborations, or other modifications represent, as a whole, an +original work of authorship. For the purposes of this License, Derivative Works +shall not include works that remain separable from, or merely link (or bind by +name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version +of the Work and any modifications or additions to that Work or Derivative Works +thereof, that is intentionally submitted to Licensor for inclusion in the Work +by the copyright owner or by an individual or Legal Entity authorized to submit +on behalf of the copyright owner. For the purposes of this definition, +"submitted" means any form of electronic, verbal, or written communication sent +to the Licensor or its representatives, including but not limited to +communication on electronic mailing lists, source code control systems, and +issue tracking systems that are managed by, or on behalf of, the Licensor for +the purpose of discussing and improving the Work, but excluding communication +that is conspicuously marked or otherwise designated in writing by the copyright +owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf +of whom a Contribution has been received by Licensor and subsequently +incorporated within the Work. + +2. Grant of Copyright License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable copyright license to reproduce, prepare Derivative Works of, +publicly display, publicly perform, sublicense, and distribute the Work and such +Derivative Works in Source or Object form. + +3. Grant of Patent License. + +Subject to the terms and conditions of this License, each Contributor hereby +grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable (except as stated in this section) patent license to make, have +made, use, offer to sell, sell, import, and otherwise transfer the Work, where +such license applies only to those patent claims licensable by such Contributor +that are necessarily infringed by their Contribution(s) alone or by combination +of their Contribution(s) with the Work to which such Contribution(s) was +submitted. If You institute patent litigation against any entity (including a +cross-claim or counterclaim in a lawsuit) alleging that the Work or a +Contribution incorporated within the Work constitutes direct or contributory +patent infringement, then any patent licenses granted to You under this License +for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. + +You may reproduce and distribute copies of the Work or Derivative Works thereof +in any medium, with or without modifications, and in Source or Object form, +provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of +this License; and +You must cause any modified files to carry prominent notices stating that You +changed the files; and +You must retain, in the Source form of any Derivative Works that You distribute, +all copyright, patent, trademark, and attribution notices from the Source form +of the Work, excluding those notices that do not pertain to any part of the +Derivative Works; and +If the Work includes a "NOTICE" text file as part of its distribution, then any +Derivative Works that You distribute must include a readable copy of the +attribution notices contained within such NOTICE file, excluding those notices +that do not pertain to any part of the Derivative Works, in at least one of the +following places: within a NOTICE text file distributed as part of the +Derivative Works; within the Source form or documentation, if provided along +with the Derivative Works; or, within a display generated by the Derivative +Works, if and wherever such third-party notices normally appear. The contents of +the NOTICE file are for informational purposes only and do not modify the +License. You may add Your own attribution notices within Derivative Works that +You distribute, alongside or as an addendum to the NOTICE text from the Work, +provided that such additional attribution notices cannot be construed as +modifying the License. +You may add Your own copyright statement to Your modifications and may provide +additional or different license terms and conditions for use, reproduction, or +distribution of Your modifications, or for any such Derivative Works as a whole, +provided Your use, reproduction, and distribution of the Work otherwise complies +with the conditions stated in this License. + +5. Submission of Contributions. + +Unless You explicitly state otherwise, any Contribution intentionally submitted +for inclusion in the Work by You to the Licensor shall be under the terms and +conditions of this License, without any additional terms or conditions. +Notwithstanding the above, nothing herein shall supersede or modify the terms of +any separate license agreement you may have executed with Licensor regarding +such Contributions. + +6. Trademarks. + +This License does not grant permission to use the trade names, trademarks, +service marks, or product names of the Licensor, except as required for +reasonable and customary use in describing the origin of the Work and +reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. + +Unless required by applicable law or agreed to in writing, Licensor provides the +Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, +including, without limitation, any warranties or conditions of TITLE, +NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are +solely responsible for determining the appropriateness of using or +redistributing the Work and assume any risks associated with Your exercise of +permissions under this License. + +8. Limitation of Liability. + +In no event and under no legal theory, whether in tort (including negligence), +contract, or otherwise, unless required by applicable law (such as deliberate +and grossly negligent acts) or agreed to in writing, shall any Contributor be +liable to You for damages, including any direct, indirect, special, incidental, +or consequential damages of any character arising as a result of this License or +out of the use or inability to use the Work (including but not limited to +damages for loss of goodwill, work stoppage, computer failure or malfunction, or +any and all other commercial damages or losses), even if such Contributor has +been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. + +While redistributing the Work or Derivative Works thereof, You may choose to +offer, and charge a fee for, acceptance of support, warranty, indemnity, or +other liability obligations and/or rights consistent with this License. However, +in accepting such obligations, You may act only on Your own behalf and on Your +sole responsibility, not on behalf of any other Contributor, and only if You +agree to indemnify, defend, and hold each Contributor harmless for any liability +incurred by, or claims asserted against, such Contributor by reason of your +accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work + +To apply the Apache License to your work, attach the following boilerplate +notice, with the fields enclosed by brackets "[]" replaced with your own +identifying information. (Don't include the brackets!) The text should be +enclosed in the appropriate comment syntax for the file format. We also +recommend that a file or class name and description of purpose be included on +the same "printed page" as the copyright notice for easier identification within +third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/_vendor/collections_extended/__init__.py b/_vendor/collections_extended/__init__.py new file mode 100644 index 00000000..039fee03 --- /dev/null +++ b/_vendor/collections_extended/__init__.py @@ -0,0 +1,55 @@ +"""collections_extended contains a few extra basic data structures.""" +from ._compat import Collection +from .bags import bag, frozenbag +from .setlists import setlist, frozensetlist +from .bijection import bijection +from .range_map import RangeMap, MappedRange + +__version__ = '1.0.2' + +__all__ = ( + 'collection', + 'setlist', + 'frozensetlist', + 'bag', + 'frozenbag', + 'bijection', + 'RangeMap', + 'MappedRange', + 'Collection', + ) + + +def collection(iterable=None, mutable=True, ordered=False, unique=False): + """Return a Collection with the specified properties. + + Args: + iterable (Iterable): collection to instantiate new collection from. + mutable (bool): Whether or not the new collection is mutable. + ordered (bool): Whether or not the new collection is ordered. + unique (bool): Whether or not the new collection contains only unique values. + """ + if iterable is None: + iterable = tuple() + if unique: + if ordered: + if mutable: + return setlist(iterable) + else: + return frozensetlist(iterable) + else: + if mutable: + return set(iterable) + else: + return frozenset(iterable) + else: + if ordered: + if mutable: + return list(iterable) + else: + return tuple(iterable) + else: + if mutable: + return bag(iterable) + else: + return frozenbag(iterable) diff --git a/_vendor/collections_extended/_compat.py b/_vendor/collections_extended/_compat.py new file mode 100644 index 00000000..bbf0fbd9 --- /dev/null +++ b/_vendor/collections_extended/_compat.py @@ -0,0 +1,53 @@ +"""Python 2/3 compatibility helpers.""" +import sys + +is_py2 = sys.version_info[0] == 2 + +if is_py2: + def keys_set(d): + """Return a set of passed dictionary's keys.""" + return set(d.keys()) +else: + keys_set = dict.keys + + +if sys.version_info < (3, 6): + from collections import Sized, Iterable, Container + + def _check_methods(C, *methods): + mro = C.__mro__ + for method in methods: + for B in mro: + if method in B.__dict__: + if B.__dict__[method] is None: + return NotImplemented + break + else: + return NotImplemented + return True + + class Collection(Sized, Iterable, Container): + """Backport from Python3.6.""" + + __slots__ = tuple() + + @classmethod + def __subclasshook__(cls, C): + if cls is Collection: + return _check_methods(C, "__len__", "__iter__", "__contains__") + return NotImplemented + +else: + from collections.abc import Collection + + +def handle_rich_comp_not_implemented(): + """Correctly handle unimplemented rich comparisons. + + In Python 3, return NotImplemented. + In Python 2, raise a TypeError. + """ + if is_py2: + raise TypeError() + else: + return NotImplemented diff --git a/_vendor/collections_extended/_util.py b/_vendor/collections_extended/_util.py new file mode 100644 index 00000000..58a83f45 --- /dev/null +++ b/_vendor/collections_extended/_util.py @@ -0,0 +1,16 @@ +"""util functions for collections_extended.""" + + +def hash_iterable(it): + """Perform a O(1) memory hash of an iterable of arbitrary length. + + hash(tuple(it)) creates a temporary tuple containing all values from it + which could be a problem if it is large. + + See discussion at: + https://groups.google.com/forum/#!msg/python-ideas/XcuC01a8SYs/e-doB9TbDwAJ + """ + hash_value = hash(type(it)) + for value in it: + hash_value = hash((hash_value, value)) + return hash_value diff --git a/_vendor/collections_extended/bags.py b/_vendor/collections_extended/bags.py new file mode 100644 index 00000000..cce132fe --- /dev/null +++ b/_vendor/collections_extended/bags.py @@ -0,0 +1,527 @@ +"""Bag class definitions.""" +import heapq +from operator import itemgetter +from collections import Set, MutableSet, Hashable + +from . import _compat + + +class _basebag(Set): + """Base class for bag classes. + + Base class for bag and frozenbag. Is not mutable and not hashable, so there's + no reason to use this instead of either bag or frozenbag. + """ + + # Basic object methods + + def __init__(self, iterable=None): + """Create a new basebag. + + If iterable isn't given, is None or is empty then the bag starts empty. + Otherwise each element from iterable will be added to the bag + however many times it appears. + + This runs in O(len(iterable)) + """ + self._dict = dict() + self._size = 0 + if iterable: + if isinstance(iterable, _basebag): + for elem, count in iterable._dict.items(): + self._dict[elem] = count + self._size += count + else: + for value in iterable: + self._dict[value] = self._dict.get(value, 0) + 1 + self._size += 1 + + def __repr__(self): + if self._size == 0: + return '{0}()'.format(self.__class__.__name__) + else: + repr_format = '{class_name}({values!r})' + return repr_format.format( + class_name=self.__class__.__name__, + values=tuple(self), + ) + + def __str__(self): + if self._size == 0: + return '{class_name}()'.format(class_name=self.__class__.__name__) + else: + format_single = '{elem!r}' + format_mult = '{elem!r}^{mult}' + strings = [] + for elem, mult in self._dict.items(): + if mult > 1: + strings.append(format_mult.format(elem=elem, mult=mult)) + else: + strings.append(format_single.format(elem=elem)) + return '{%s}' % ', '.join(strings) + + # New public methods (not overriding/implementing anything) + + def num_unique_elements(self): + """Return the number of unique elements. + + This runs in O(1) time + """ + return len(self._dict) + + def unique_elements(self): + """Return a view of unique elements in this bag. + + In Python 3: + This runs in O(1) time and returns a view of the unique elements + In Python 2: + This runs in O(n) and returns set of the current elements. + """ + return _compat.keys_set(self._dict) + + def count(self, value): + """Return the number of value present in this bag. + + If value is not in the bag no Error is raised, instead 0 is returned. + + This runs in O(1) time + + Args: + value: The element of self to get the count of + Returns: + int: The count of value in self + """ + return self._dict.get(value, 0) + + def nlargest(self, n=None): + """List the n most common elements and their counts. + + List is from the most + common to the least. If n is None, the list all element counts. + + Run time should be O(m log m) where m is len(self) + Args: + n (int): The number of elements to return + """ + if n is None: + return sorted(self._dict.items(), key=itemgetter(1), reverse=True) + else: + return heapq.nlargest(n, self._dict.items(), key=itemgetter(1)) + + @classmethod + def _from_iterable(cls, it): + return cls(it) + + @classmethod + def from_mapping(cls, mapping): + """Create a bag from a dict of elem->count. + + Each key in the dict is added if the value is > 0. + """ + out = cls() + for elem, count in mapping.items(): + if count > 0: + out._dict[elem] = count + out._size += count + return out + + def copy(self): + """Create a shallow copy of self. + + This runs in O(len(self.num_unique_elements())) + """ + return self.from_mapping(self._dict) + + # implementing Sized methods + + def __len__(self): + """Return the cardinality of the bag. + + This runs in O(1) + """ + return self._size + + # implementing Container methods + + def __contains__(self, value): + """Return the multiplicity of the element. + + This runs in O(1) + """ + return self._dict.get(value, 0) + + # implementing Iterable methods + + def __iter__(self): + """Iterate through all elements. + + Multiple copies will be returned if they exist. + """ + for value, count in self._dict.items(): + for i in range(count): + yield(value) + + # Comparison methods + + def _is_subset(self, other): + """Check that every element in self has a count <= in other. + + Args: + other (Set) + """ + if isinstance(other, _basebag): + for elem, count in self._dict.items(): + if not count <= other._dict.get(elem, 0): + return False + else: + for elem in self: + if self._dict.get(elem, 0) > 1 or elem not in other: + return False + return True + + def _is_superset(self, other): + """Check that every element in self has a count >= in other. + + Args: + other (Set) + """ + if isinstance(other, _basebag): + for elem, count in other._dict.items(): + if not self._dict.get(elem, 0) >= count: + return False + else: + for elem in other: + if elem not in self: + return False + return True + + def __le__(self, other): + if not isinstance(other, Set): + return _compat.handle_rich_comp_not_implemented() + return len(self) <= len(other) and self._is_subset(other) + + def __lt__(self, other): + if not isinstance(other, Set): + return _compat.handle_rich_comp_not_implemented() + return len(self) < len(other) and self._is_subset(other) + + def __gt__(self, other): + if not isinstance(other, Set): + return _compat.handle_rich_comp_not_implemented() + return len(self) > len(other) and self._is_superset(other) + + def __ge__(self, other): + if not isinstance(other, Set): + return _compat.handle_rich_comp_not_implemented() + return len(self) >= len(other) and self._is_superset(other) + + def __eq__(self, other): + if not isinstance(other, Set): + return False + if isinstance(other, _basebag): + return self._dict == other._dict + if not len(self) == len(other): + return False + for elem in other: + if self._dict.get(elem, 0) != 1: + return False + return True + + def __ne__(self, other): + return not (self == other) + + # Operations - &, |, +, -, ^, * and isdisjoint + + def __and__(self, other): + """Intersection is the minimum of corresponding counts. + + This runs in O(l + n) where: + n is self.num_unique_elements() + if other is a bag: + l = 1 + else: + l = len(other) + """ + if not isinstance(other, _basebag): + other = self._from_iterable(other) + values = dict() + for elem in self._dict: + values[elem] = min(other._dict.get(elem, 0), self._dict.get(elem, 0)) + return self.from_mapping(values) + + def isdisjoint(self, other): + """Return if this bag is disjoint with the passed collection. + + This runs in O(len(other)) + + TODO move isdisjoint somewhere more appropriate + """ + for value in other: + if value in self: + return False + return True + + def __or__(self, other): + """Union is the maximum of all elements. + + This runs in O(m + n) where: + n is self.num_unique_elements() + if other is a bag: + m = other.num_unique_elements() + else: + m = len(other) + """ + if not isinstance(other, _basebag): + other = self._from_iterable(other) + values = dict() + for elem in self.unique_elements() | other.unique_elements(): + values[elem] = max(self._dict.get(elem, 0), other._dict.get(elem, 0)) + return self.from_mapping(values) + + def __add__(self, other): + """Return a new bag also containing all the elements of other. + + self + other = self & other + self | other + + This runs in O(m + n) where: + n is self.num_unique_elements() + m is len(other) + Args: + other (Iterable): elements to add to self + """ + out = self.copy() + for value in other: + out._dict[value] = out._dict.get(value, 0) + 1 + out._size += 1 + return out + + def __sub__(self, other): + """Difference between the sets. + + For normal sets this is all x s.t. x in self and x not in other. + For bags this is count(x) = max(0, self.count(x)-other.count(x)) + + This runs in O(m + n) where: + n is self.num_unique_elements() + m is len(other) + Args: + other (Iterable): elements to remove + """ + out = self.copy() + for value in other: + old_count = out._dict.get(value, 0) + if old_count == 1: + del out._dict[value] + out._size -= 1 + elif old_count > 1: + out._dict[value] = old_count - 1 + out._size -= 1 + return out + + def __mul__(self, other): + """Cartesian product of the two sets. + + other can be any iterable. + Both self and other must contain elements that can be added together. + + This should run in O(m*n+l) where: + m is the number of unique elements in self + n is the number of unique elements in other + if other is a bag: + l is 0 + else: + l is the len(other) + The +l will only really matter when other is an iterable with MANY + repeated elements. + For example: {'a'^2} * 'bbbbbbbbbbbbbbbbbbbbbbbbbb' + The algorithm will be dominated by counting the 'b's + """ + if not isinstance(other, _basebag): + other = self._from_iterable(other) + values = dict() + for elem, count in self._dict.items(): + for other_elem, other_count in other._dict.items(): + new_elem = elem + other_elem + new_count = count * other_count + values[new_elem] = new_count + return self.from_mapping(values) + + def __xor__(self, other): + """Symmetric difference between the sets. + + other can be any iterable. + + This runs in O(m + n) where: + m = len(self) + n = len(other) + """ + return (self - other) | (other - self) + + +class bag(_basebag, MutableSet): + """bag is a mutable unhashable bag.""" + + def pop(self): + """Remove and return an element of self.""" + # TODO can this be done more efficiently (no need to create an iterator)? + it = iter(self) + try: + value = next(it) + except StopIteration: + raise KeyError + self.discard(value) + return value + + def add(self, elem): + """Add elem to self.""" + self._dict[elem] = self._dict.get(elem, 0) + 1 + self._size += 1 + + def discard(self, elem): + """Remove elem from this bag, silent if it isn't present.""" + try: + self.remove(elem) + except ValueError: + pass + + def remove(self, elem): + """Remove elem from this bag, raising a ValueError if it isn't present. + + Args: + elem: object to remove from self + Raises: + ValueError: if the elem isn't present + """ + old_count = self._dict.get(elem, 0) + if old_count == 0: + raise ValueError + elif old_count == 1: + del self._dict[elem] + else: + self._dict[elem] -= 1 + self._size -= 1 + + def discard_all(self, other): + """Discard all of the elems from other.""" + if not isinstance(other, _basebag): + other = self._from_iterable(other) + for elem, other_count in other._dict.items(): + old_count = self._dict.get(elem, 0) + new_count = old_count - other_count + if new_count >= 0: + if new_count == 0: + if elem in self: + del self._dict[elem] + else: + self._dict[elem] = new_count + self._size += new_count - old_count + + def remove_all(self, other): + """Remove all of the elems from other. + + Raises a ValueError if the multiplicity of any elem in other is greater + than in self. + """ + if not self._is_superset(other): + raise ValueError + self.discard_all(other) + + def clear(self): + """Remove all elements from this bag.""" + self._dict = dict() + self._size = 0 + + # In-place operations + + def __ior__(self, other): + """Set multiplicity of each element to the maximum of the two collections. + + if isinstance(other, _basebag): + This runs in O(other.num_unique_elements()) + else: + This runs in O(len(other)) + """ + if not isinstance(other, _basebag): + other = self._from_iterable(other) + for elem, other_count in other._dict.items(): + old_count = self._dict.get(elem, 0) + new_count = max(other_count, old_count) + self._dict[elem] = new_count + self._size += new_count - old_count + return self + + def __iand__(self, other): + """Set multiplicity of each element to the minimum of the two collections. + + if isinstance(other, _basebag): + This runs in O(other.num_unique_elements()) + else: + This runs in O(len(other)) + """ + if not isinstance(other, _basebag): + other = self._from_iterable(other) + for elem, old_count in set(self._dict.items()): + other_count = other._dict.get(elem, 0) + new_count = min(other_count, old_count) + if new_count == 0: + del self._dict[elem] + else: + self._dict[elem] = new_count + self._size += new_count - old_count + return self + + def __ixor__(self, other): + """Set self to the symmetric difference between the sets. + + if isinstance(other, _basebag): + This runs in O(other.num_unique_elements()) + else: + This runs in O(len(other)) + """ + if not isinstance(other, _basebag): + other = self._from_iterable(other) + other_minus_self = other - self + self -= other + self |= other_minus_self + return self + + def __isub__(self, other): + """Discard the elements of other from self. + + if isinstance(it, _basebag): + This runs in O(it.num_unique_elements()) + else: + This runs in O(len(it)) + """ + self.discard_all(other) + return self + + def __iadd__(self, other): + """Add all of the elements of other to self. + + if isinstance(it, _basebag): + This runs in O(it.num_unique_elements()) + else: + This runs in O(len(it)) + """ + if not isinstance(other, _basebag): + other = self._from_iterable(other) + for elem, other_count in other._dict.items(): + self._dict[elem] = self._dict.get(elem, 0) + other_count + self._size += other_count + return self + + +class frozenbag(_basebag, Hashable): + """frozenbag is an immutable, hashable bab.""" + + def __hash__(self): + """Compute the hash value of a frozenbag. + + This was copied directly from _collections_abc.Set._hash in Python3 which + is identical to _abcoll.Set._hash + We can't call it directly because Python2 raises a TypeError. + """ + if not hasattr(self, '_hash_value'): + self._hash_value = self._hash() + return self._hash_value diff --git a/_vendor/collections_extended/bijection.py b/_vendor/collections_extended/bijection.py new file mode 100644 index 00000000..f9641de4 --- /dev/null +++ b/_vendor/collections_extended/bijection.py @@ -0,0 +1,94 @@ +"""Class definition for bijection.""" + +from collections import MutableMapping, Mapping + + +class bijection(MutableMapping): + """A one-to-one onto mapping, a dict with unique values.""" + + def __init__(self, iterable=None, **kwarg): + """Create a bijection from an iterable. + + Matches dict.__init__. + """ + self._data = {} + self.__inverse = self.__new__(bijection) + self.__inverse._data = {} + self.__inverse.__inverse = self + if iterable is not None: + if isinstance(iterable, Mapping): + for key, value in iterable.items(): + self[key] = value + else: + for pair in iterable: + self[pair[0]] = pair[1] + for key, value in kwarg.items(): + self[key] = value + + def __repr__(self): + if len(self._data) == 0: + return '{0}()'.format(self.__class__.__name__) + else: + repr_format = '{class_name}({values!r})' + return repr_format.format( + class_name=self.__class__.__name__, + values=self._data, + ) + + @property + def inverse(self): + """Return the inverse of this bijection.""" + return self.__inverse + + # Required for MutableMapping + def __len__(self): + return len(self._data) + + # Required for MutableMapping + def __getitem__(self, key): + return self._data[key] + + # Required for MutableMapping + def __setitem__(self, key, value): + if key in self: + del self.inverse._data[self[key]] + if value in self.inverse: + del self._data[self.inverse[value]] + self._data[key] = value + self.inverse._data[value] = key + + # Required for MutableMapping + def __delitem__(self, key): + value = self._data.pop(key) + del self.inverse._data[value] + + # Required for MutableMapping + def __iter__(self): + return iter(self._data) + + def __contains__(self, key): + return key in self._data + + def clear(self): + """Remove everything from this bijection.""" + self._data.clear() + self.inverse._data.clear() + + def copy(self): + """Return a copy of this bijection.""" + return bijection(self) + + def items(self): + """See Mapping.items.""" + return self._data.items() + + def keys(self): + """See Mapping.keys.""" + return self._data.keys() + + def values(self): + """See Mapping.values.""" + return self.inverse.keys() + + def __eq__(self, other): + return isinstance(other, bijection) and self._data == other._data diff --git a/_vendor/collections_extended/range_map.py b/_vendor/collections_extended/range_map.py new file mode 100644 index 00000000..19a61238 --- /dev/null +++ b/_vendor/collections_extended/range_map.py @@ -0,0 +1,384 @@ +"""RangeMap class definition.""" +from bisect import bisect_left, bisect_right +from collections import namedtuple, Mapping, MappingView, Set + + +# Used to mark unmapped ranges +_empty = object() + +MappedRange = namedtuple('MappedRange', ('start', 'stop', 'value')) + + +class KeysView(MappingView, Set): + """A view of the keys that mark the starts of subranges. + + Since iterating over all the keys is impossible, the KeysView only + contains the keys that start each subrange. + """ + + __slots__ = () + + @classmethod + def _from_iterable(self, it): + return set(it) + + def __contains__(self, key): + loc = self._mapping._bisect_left(key) + return self._mapping._keys[loc] == key and \ + self._mapping._values[loc] is not _empty + + def __iter__(self): + for item in self._mapping.ranges(): + yield item.start + + +class ItemsView(MappingView, Set): + """A view of the items that mark the starts of subranges. + + Since iterating over all the keys is impossible, the ItemsView only + contains the items that start each subrange. + """ + + __slots__ = () + + @classmethod + def _from_iterable(self, it): + return set(it) + + def __contains__(self, item): + key, value = item + loc = self._mapping._bisect_left(key) + return self._mapping._keys[loc] == key and \ + self._mapping._values[loc] == value + + def __iter__(self): + for mapped_range in self._mapping.ranges(): + yield (mapped_range.start, mapped_range.value) + + +class ValuesView(MappingView): + """A view on the values of a Mapping.""" + + __slots__ = () + + def __contains__(self, value): + return value in self._mapping._values + + def __iter__(self): + for value in self._mapping._values: + if value is not _empty: + yield value + + +def _check_start_stop(start, stop): + """Check that start and stop are valid - orderable and in the right order. + + Raises: + ValueError: if stop <= start + TypeError: if unorderable + """ + if start is not None and stop is not None and stop <= start: + raise ValueError('stop must be > start') + + +def _check_key_slice(key): + if not isinstance(key, slice): + raise TypeError('Can only set and delete slices') + if key.step is not None: + raise ValueError('Cannot set or delete slices with steps') + + +class RangeMap(Mapping): + """Map ranges of orderable elements to values.""" + + def __init__(self, iterable=None, **kwargs): + """Create a RangeMap. + + A mapping or other iterable can be passed to initialize the RangeMap. + If mapping is passed, it is interpreted as a mapping from range start + indices to values. + If an iterable is passed, each element will define a range in the + RangeMap and should be formatted (start, stop, value). + + default_value is a an optional keyword argument that will initialize the + entire RangeMap to that value. Any missing ranges will be mapped to that + value. However, if ranges are subsequently deleted they will be removed + and *not* mapped to the default_value. + + Args: + iterable: A Mapping or an Iterable to initialize from. + default_value: If passed, the return value for all keys less than the + least key in mapping or missing ranges in iterable. If no mapping + or iterable, the return value for all keys. + """ + default_value = kwargs.pop('default_value', _empty) + if kwargs: + raise TypeError('Unknown keyword arguments: %s' % ', '.join(kwargs.keys())) + self._keys = [None] + self._values = [default_value] + if iterable: + if isinstance(iterable, Mapping): + self._init_from_mapping(iterable) + else: + self._init_from_iterable(iterable) + + @classmethod + def from_mapping(cls, mapping): + """Create a RangeMap from a mapping of interval starts to values.""" + obj = cls() + obj._init_from_mapping(mapping) + return obj + + def _init_from_mapping(self, mapping): + for key, value in sorted(mapping.items()): + self.set(value, key) + + @classmethod + def from_iterable(cls, iterable): + """Create a RangeMap from an iterable of tuples defining each range. + + Each element of the iterable is a tuple (start, stop, value). + """ + obj = cls() + obj._init_from_iterable(iterable) + return obj + + def _init_from_iterable(self, iterable): + for start, stop, value in iterable: + self.set(value, start=start, stop=stop) + + def __str__(self): + range_format = '({range.start}, {range.stop}): {range.value}' + values = ', '.join([range_format.format(range=r) for r in self.ranges()]) + return 'RangeMap(%s)' % values + + def __repr__(self): + range_format = '({range.start!r}, {range.stop!r}, {range.value!r})' + values = ', '.join([range_format.format(range=r) for r in self.ranges()]) + return 'RangeMap([%s])' % values + + def _bisect_left(self, key): + """Return the index of the key or the last key < key.""" + if key is None: + return 0 + else: + return bisect_left(self._keys, key, lo=1) + + def _bisect_right(self, key): + """Return the index of the first key > key.""" + if key is None: + return 1 + else: + return bisect_right(self._keys, key, lo=1) + + def ranges(self, start=None, stop=None): + """Generate MappedRanges for all mapped ranges. + + Yields: + MappedRange + """ + _check_start_stop(start, stop) + start_loc = self._bisect_right(start) + if stop is None: + stop_loc = len(self._keys) + else: + stop_loc = self._bisect_left(stop) + start_val = self._values[start_loc - 1] + candidate_keys = [start] + self._keys[start_loc:stop_loc] + [stop] + candidate_values = [start_val] + self._values[start_loc:stop_loc] + for i, value in enumerate(candidate_values): + if value is not _empty: + start_key = candidate_keys[i] + stop_key = candidate_keys[i + 1] + yield MappedRange(start_key, stop_key, value) + + def __contains__(self, value): + try: + self.__getitem(value) is not _empty + except KeyError: + return False + else: + return True + + def __iter__(self): + for key, value in zip(self._keys, self._values): + if value is not _empty: + yield key + + def __bool__(self): + if len(self._keys) > 1: + return True + else: + return self._values[0] != _empty + + __nonzero__ = __bool__ + + def __getitem(self, key): + """Get the value for a key (not a slice).""" + loc = self._bisect_right(key) - 1 + value = self._values[loc] + if value is _empty: + raise KeyError(key) + else: + return value + + def get(self, key, restval=None): + """Get the value of the range containing key, otherwise return restval.""" + try: + return self.__getitem(key) + except KeyError: + return restval + + def get_range(self, start=None, stop=None): + """Return a RangeMap for the range start to stop. + + Returns: + A RangeMap + """ + return self.from_iterable(self.ranges(start, stop)) + + def set(self, value, start=None, stop=None): + """Set the range from start to stop to value.""" + _check_start_stop(start, stop) + # start_index, stop_index will denote the sections we are replacing + start_index = self._bisect_left(start) + if start is not None: # start_index == 0 + prev_value = self._values[start_index - 1] + if prev_value == value: + # We're setting a range where the left range has the same + # value, so create one big range + start_index -= 1 + start = self._keys[start_index] + if stop is None: + new_keys = [start] + new_values = [value] + stop_index = len(self._keys) + else: + stop_index = self._bisect_right(stop) + stop_value = self._values[stop_index - 1] + stop_key = self._keys[stop_index - 1] + if stop_key == stop and stop_value == value: + new_keys = [start] + new_values = [value] + else: + new_keys = [start, stop] + new_values = [value, stop_value] + self._keys[start_index:stop_index] = new_keys + self._values[start_index:stop_index] = new_values + + def delete(self, start=None, stop=None): + """Delete the range from start to stop from self. + + Raises: + KeyError: If part of the passed range isn't mapped. + """ + _check_start_stop(start, stop) + start_loc = self._bisect_right(start) - 1 + if stop is None: + stop_loc = len(self._keys) + else: + stop_loc = self._bisect_left(stop) + for value in self._values[start_loc:stop_loc]: + if value is _empty: + raise KeyError((start, stop)) + # this is inefficient, we've already found the sub ranges + self.set(_empty, start=start, stop=stop) + + def empty(self, start=None, stop=None): + """Empty the range from start to stop. + + Like delete, but no Error is raised if the entire range isn't mapped. + """ + self.set(_empty, start=start, stop=stop) + + def clear(self): + """Remove all elements.""" + self._keys = [None] + self._values = [_empty] + + @property + def start(self): + """Get the start key of the first range. + + None if RangeMap is empty or unbounded to the left. + """ + if self._values[0] is _empty: + try: + return self._keys[1] + except IndexError: + # This is empty or everything is mapped to a single value + return None + else: + # This is unbounded to the left + return self._keys[0] + + @property + def end(self): + """Get the stop key of the last range. + + None if RangeMap is empty or unbounded to the right. + """ + if self._values[-1] is _empty: + return self._keys[-1] + else: + # This is unbounded to the right + return None + + def __eq__(self, other): + if isinstance(other, RangeMap): + return ( + self._keys == other._keys and + self._values == other._values + ) + else: + return False + + def __getitem__(self, key): + try: + _check_key_slice(key) + except TypeError: + return self.__getitem(key) + else: + return self.get_range(key.start, key.stop) + + def __setitem__(self, key, value): + _check_key_slice(key) + self.set(value, key.start, key.stop) + + def __delitem__(self, key): + _check_key_slice(key) + self.delete(key.start, key.stop) + + def __len__(self): + count = 0 + for v in self._values: + if v is not _empty: + count += 1 + return count + + def keys(self): + """Return a view of the keys.""" + return KeysView(self) + + def values(self): + """Return a view of the values.""" + return ValuesView(self) + + def items(self): + """Return a view of the item pairs.""" + return ItemsView(self) + + # Python2 - override slice methods + def __setslice__(self, i, j, value): + """Implement __setslice__ to override behavior in Python 2. + + This is required because empty slices pass integers in python2 as opposed + to None in python 3. + """ + raise SyntaxError('Assigning slices doesn\t work in Python 2, use set') + + def __delslice__(self, i, j): + raise SyntaxError('Deleting slices doesn\t work in Python 2, use delete') + + def __getslice__(self, i, j): + raise SyntaxError('Getting slices doesn\t work in Python 2, use get_range.') diff --git a/_vendor/collections_extended/setlists.py b/_vendor/collections_extended/setlists.py new file mode 100644 index 00000000..2976077c --- /dev/null +++ b/_vendor/collections_extended/setlists.py @@ -0,0 +1,552 @@ +"""Setlist class definitions.""" +import random as random_ + +from collections import ( + Sequence, + Set, + MutableSequence, + MutableSet, + Hashable, + ) + +from . import _util + + +class _basesetlist(Sequence, Set): + """A setlist is an ordered Collection of unique elements. + + _basesetlist is the superclass of setlist and frozensetlist. It is immutable + and unhashable. + """ + + def __init__(self, iterable=None, raise_on_duplicate=False): + """Create a setlist. + + Args: + iterable (Iterable): Values to initialize the setlist with. + """ + self._list = list() + self._dict = dict() + if iterable: + if raise_on_duplicate: + self._extend(iterable) + else: + self._update(iterable) + + def __repr__(self): + if len(self) == 0: + return '{0}()'.format(self.__class__.__name__) + else: + repr_format = '{class_name}({values!r})' + return repr_format.format( + class_name=self.__class__.__name__, + values=tuple(self), + ) + + # Convenience methods + def _fix_neg_index(self, index): + if index < 0: + index += len(self) + if index < 0: + raise IndexError('index is out of range') + return index + + def _fix_end_index(self, index): + if index is None: + return len(self) + else: + return self._fix_neg_index(index) + + def _append(self, value): + # Checking value in self will check that value is Hashable + if value in self: + raise ValueError('Value "%s" already present' % str(value)) + else: + self._dict[value] = len(self) + self._list.append(value) + + def _extend(self, values): + new_values = set() + for value in values: + if value in new_values: + raise ValueError('New values contain duplicates') + elif value in self: + raise ValueError('New values contain elements already present in self') + else: + new_values.add(value) + for value in values: + self._dict[value] = len(self) + self._list.append(value) + + def _add(self, item): + if item not in self: + self._dict[item] = len(self) + self._list.append(item) + + def _update(self, values): + for value in values: + if value not in self: + self._dict[value] = len(self) + self._list.append(value) + + @classmethod + def _from_iterable(cls, it, **kwargs): + return cls(it, **kwargs) + + # Implement Container + def __contains__(self, value): + return value in self._dict + + # Iterable we get by inheriting from Sequence + + # Implement Sized + def __len__(self): + return len(self._list) + + # Implement Sequence + def __getitem__(self, index): + if isinstance(index, slice): + return self._from_iterable(self._list[index]) + return self._list[index] + + def count(self, value): + """Return the number of occurences of value in self. + + This runs in O(1) + + Args: + value: The value to count + Returns: + int: 1 if the value is in the setlist, otherwise 0 + """ + if value in self: + return 1 + else: + return 0 + + def index(self, value, start=0, end=None): + """Return the index of value between start and end. + + By default, the entire setlist is searched. + + This runs in O(1) + + Args: + value: The value to find the index of + start (int): The index to start searching at (defaults to 0) + end (int): The index to stop searching at (defaults to the end of the list) + Returns: + int: The index of the value + Raises: + ValueError: If the value is not in the list or outside of start - end + IndexError: If start or end are out of range + """ + try: + index = self._dict[value] + except KeyError: + raise ValueError + else: + start = self._fix_neg_index(start) + end = self._fix_end_index(end) + if start <= index and index < end: + return index + else: + raise ValueError + + @classmethod + def _check_type(cls, other, operand_name): + if not isinstance(other, _basesetlist): + message = ( + "unsupported operand type(s) for {operand_name}: " + "'{self_type}' and '{other_type}'").format( + operand_name=operand_name, + self_type=cls, + other_type=type(other), + ) + raise TypeError(message) + + def __add__(self, other): + self._check_type(other, '+') + out = self.copy() + out._extend(other) + return out + + # Implement Set + + def issubset(self, other): + return self <= other + + def issuperset(self, other): + return self >= other + + def union(self, other): + out = self.copy() + out.update(other) + return out + + def intersection(self, other): + other = set(other) + return self._from_iterable(item for item in self if item in other) + + def difference(self, other): + other = set(other) + return self._from_iterable(item for item in self if item not in other) + + def symmetric_difference(self, other): + return self.union(other) - self.intersection(other) + + def __sub__(self, other): + self._check_type(other, '-') + return self.difference(other) + + def __and__(self, other): + self._check_type(other, '&') + return self.intersection(other) + + def __or__(self, other): + self._check_type(other, '|') + return self.union(other) + + def __xor__(self, other): + self._check_type(other, '^') + return self.symmetric_difference(other) + + # Comparison + + def __eq__(self, other): + if not isinstance(other, _basesetlist): + return False + if not len(self) == len(other): + return False + for self_elem, other_elem in zip(self, other): + if self_elem != other_elem: + return False + return True + + def __ne__(self, other): + return not (self == other) + + # New methods + + def sub_index(self, sub, start=0, end=None): + """Return the index of a subsequence. + + This runs in O(len(sub)) + + Args: + sub (Sequence): An Iterable to search for + Returns: + int: The index of the first element of sub + Raises: + ValueError: If sub isn't a subsequence + TypeError: If sub isn't iterable + IndexError: If start or end are out of range + """ + start_index = self.index(sub[0], start, end) + end = self._fix_end_index(end) + if start_index + len(sub) > end: + raise ValueError + for i in range(1, len(sub)): + if sub[i] != self[start_index + i]: + raise ValueError + return start_index + + def copy(self): + return self.__class__(self) + + +class setlist(_basesetlist, MutableSequence, MutableSet): + """A mutable (unhashable) setlist.""" + + def __str__(self): + return '{[%s}]' % ', '.join(repr(v) for v in self) + + # Helper methods + def _delete_all(self, elems_to_delete, raise_errors): + indices_to_delete = set() + for elem in elems_to_delete: + try: + elem_index = self._dict[elem] + except KeyError: + if raise_errors: + raise ValueError('Passed values contain elements not in self') + else: + if elem_index in indices_to_delete: + if raise_errors: + raise ValueError('Passed vales contain duplicates') + indices_to_delete.add(elem_index) + self._delete_values_by_index(indices_to_delete) + + def _delete_values_by_index(self, indices_to_delete): + deleted_count = 0 + for i, elem in enumerate(self._list): + if i in indices_to_delete: + deleted_count += 1 + del self._dict[elem] + else: + new_index = i - deleted_count + self._list[new_index] = elem + self._dict[elem] = new_index + # Now remove deleted_count items from the end of the list + if deleted_count: + self._list = self._list[:-deleted_count] + + # Set/Sequence agnostic + def pop(self, index=-1): + """Remove and return the item at index.""" + value = self._list.pop(index) + del self._dict[value] + return value + + def clear(self): + """Remove all elements from self.""" + self._dict = dict() + self._list = list() + + # Implement MutableSequence + def __setitem__(self, index, value): + if isinstance(index, slice): + old_values = self[index] + for v in value: + if v in self and v not in old_values: + raise ValueError + self._list[index] = value + self._dict = {} + for i, v in enumerate(self._list): + self._dict[v] = i + else: + index = self._fix_neg_index(index) + old_value = self._list[index] + if value in self: + if value == old_value: + return + else: + raise ValueError + del self._dict[old_value] + self._list[index] = value + self._dict[value] = index + + def __delitem__(self, index): + if isinstance(index, slice): + indices_to_delete = set(self.index(e) for e in self._list[index]) + self._delete_values_by_index(indices_to_delete) + else: + index = self._fix_neg_index(index) + value = self._list[index] + del self._dict[value] + for elem in self._list[index + 1:]: + self._dict[elem] -= 1 + del self._list[index] + + def insert(self, index, value): + """Insert value at index. + + Args: + index (int): Index to insert value at + value: Value to insert + Raises: + ValueError: If value already in self + IndexError: If start or end are out of range + """ + if value in self: + raise ValueError + index = self._fix_neg_index(index) + self._dict[value] = index + for elem in self._list[index:]: + self._dict[elem] += 1 + self._list.insert(index, value) + + def append(self, value): + """Append value to the end. + + Args: + value: Value to append + Raises: + ValueError: If value alread in self + TypeError: If value isn't hashable + """ + self._append(value) + + def extend(self, values): + """Append all values to the end. + + If any of the values are present, ValueError will + be raised and none of the values will be appended. + + Args: + values (Iterable): Values to append + Raises: + ValueError: If any values are already present or there are duplicates + in the passed values. + TypeError: If any of the values aren't hashable. + """ + self._extend(values) + + def __iadd__(self, values): + """Add all values to the end of self. + + Args: + values (Iterable): Values to append + Raises: + ValueError: If any values are already present + """ + self._check_type(values, '+=') + self.extend(values) + return self + + def remove(self, value): + """Remove value from self. + + Args: + value: Element to remove from self + Raises: + ValueError: if element is already present + """ + try: + index = self._dict[value] + except KeyError: + raise ValueError('Value "%s" is not present.') + else: + del self[index] + + def remove_all(self, elems_to_delete): + """Remove all elements from elems_to_delete, raises ValueErrors. + + See Also: + discard_all + Args: + elems_to_delete (Iterable): Elements to remove. + Raises: + ValueError: If the count of any element is greater in + elems_to_delete than self. + TypeError: If any of the values aren't hashable. + """ + self._delete_all(elems_to_delete, raise_errors=True) + + # Implement MutableSet + + def add(self, item): + """Add an item. + + Note: + This does not raise a ValueError for an already present value like + append does. This is to match the behavior of set.add + Args: + item: Item to add + Raises: + TypeError: If item isn't hashable. + """ + self._add(item) + + def update(self, values): + """Add all values to the end. + + If any of the values are present, silently ignore + them (as opposed to extend which raises an Error). + + See also: + extend + Args: + values (Iterable): Values to add + Raises: + TypeError: If any of the values are unhashable. + """ + self._update(values) + + def discard_all(self, elems_to_delete): + """Discard all the elements from elems_to_delete. + + This is much faster than removing them one by one. + This runs in O(len(self) + len(elems_to_delete)) + + Args: + elems_to_delete (Iterable): Elements to discard. + Raises: + TypeError: If any of the values aren't hashable. + """ + self._delete_all(elems_to_delete, raise_errors=False) + + def discard(self, value): + """Discard an item. + + Note: + This does not raise a ValueError for a missing value like remove does. + This is to match the behavior of set.discard + """ + try: + self.remove(value) + except ValueError: + pass + + def difference_update(self, other): + """Update self to include only the differene with other.""" + other = set(other) + indices_to_delete = set() + for i, elem in enumerate(self): + if elem in other: + indices_to_delete.add(i) + if indices_to_delete: + self._delete_values_by_index(indices_to_delete) + + def intersection_update(self, other): + """Update self to include only the intersection with other.""" + other = set(other) + indices_to_delete = set() + for i, elem in enumerate(self): + if elem not in other: + indices_to_delete.add(i) + if indices_to_delete: + self._delete_values_by_index(indices_to_delete) + + def symmetric_difference_update(self, other): + """Update self to include only the symmetric difference with other.""" + other = setlist(other) + indices_to_delete = set() + for i, item in enumerate(self): + if item in other: + indices_to_delete.add(i) + for item in other: + self.add(item) + self._delete_values_by_index(indices_to_delete) + + def __isub__(self, other): + self._check_type(other, '-=') + self.difference_update(other) + return self + + def __iand__(self, other): + self._check_type(other, '&=') + self.intersection_update(other) + return self + + def __ior__(self, other): + self._check_type(other, '|=') + self.update(other) + return self + + def __ixor__(self, other): + self._check_type(other, '^=') + self.symmetric_difference_update(other) + return self + + # New methods + def shuffle(self, random=None): + """Shuffle all of the elements in self randomly.""" + random_.shuffle(self._list, random=random) + for i, elem in enumerate(self._list): + self._dict[elem] = i + + def sort(self, *args, **kwargs): + """Sort this setlist in place.""" + self._list.sort(*args, **kwargs) + for index, value in enumerate(self._list): + self._dict[value] = index + + +class frozensetlist(_basesetlist, Hashable): + """An immutable (hashable) setlist.""" + + def __hash__(self): + if not hasattr(self, '_hash_value'): + self._hash_value = _util.hash_iterable(self) + return self._hash_value