# Tests for _speedups.LocationStore and NetUtils._LocationStore
import typing
 | 
						|
import unittest
 | 
						|
import warnings
 | 
						|
from NetUtils import LocationStore, _LocationStore
 | 
						|
 | 
						|
# Checked-location state: key is presumably (team, slot) -- the first element is always 0
# in this file -- and the value is the set of checked location ids for that slot.
State = typing.Dict[typing.Tuple[int, int], typing.Set[int]]

# Per-location payload: (item id, receiving player, third field -- presumably flags; TODO confirm).
_LocationInfo = typing.Tuple[int, int, int]

# slot -> location id -> payload; the raw mapping the store constructors accept.
RawLocations = typing.Dict[int, typing.Dict[int, _LocationInfo]]

# Shared fixture. Slots are not in ascending order (4 precedes 3), slot 2 lists its
# locations in descending order, and slots 3 and 4 both own a location 9 holding item 99
# for different receiving players.
sample_data: RawLocations = {
    1: {11: (21, 2, 7), 12: (22, 2, 0), 13: (13, 1, 0)},
    2: {23: (11, 1, 0), 22: (12, 1, 0), 21: (23, 2, 0)},
    4: {9: (99, 3, 0)},
    3: {9: (99, 4, 0)},
}

# Every slot present, nothing checked yet (each slot gets its own empty set).
empty_state: State = {(0, slot_id): set() for slot_id in sample_data.keys()}

# Every slot present, with all of its locations checked.
full_state: State = {(0, slot_id): set(sample_data[slot_id]) for slot_id in sample_data}

# Only slot 1 present, with just location 12 checked.
one_state: State = {(0, 1): {12}}
class Base:
    """Container for abstract test cases.

    The nested TestCase classes are attributes of ``Base`` rather than module-level
    names, so test runners do not collect them directly; concrete subclasses below
    bind them to a specific store implementation.
    """

    class TestLocationStore(unittest.TestCase):
        """Method tests run against a store loaded from ``sample_data``.

        Subclasses must assign ``self.store`` in ``setUp``.
        """
        store: typing.Union[LocationStore, _LocationStore]

        def _assert_slot_queries(
            self,
            query: typing.Callable[[State, int, int], typing.List[int]],
            cases: typing.Iterable[typing.Tuple[State, int, typing.List[int]]],
        ) -> None:
            """Assert ``query(state, 0, slot) == expected`` for each (state, slot, expected), in order."""
            for state, slot, expected in cases:
                self.assertEqual(query(state, 0, slot), expected)

        def test_len(self) -> None:
            slot_count = len(self.store)
            self.assertEqual(slot_count, 4)
            location_count = len(self.store[1])
            self.assertEqual(location_count, 3)

        def test_key_error(self) -> None:
            # Slot 0 and slot 5 do not exist in the fixture.
            for missing_slot in (0, 5):
                with self.assertRaises(KeyError):
                    _ = self.store[missing_slot]
            slot_one = self.store[1]  # existing slot: must not raise
            with self.assertRaises(KeyError):
                _ = slot_one[7]
            _ = slot_one[11]  # existing location: must not raise

        def test_getitem(self) -> None:
            expected_entries = (
                (1, 11, (21, 2, 7)),
                (1, 13, (13, 1, 0)),
                (2, 22, (12, 1, 0)),
                (4, 9, (99, 3, 0)),
            )
            for slot, location, payload in expected_entries:
                self.assertEqual(self.store[slot][location], payload)

        def test_get(self) -> None:
            no_payload = (None, None, None)
            self.assertEqual(self.store.get(1, None), self.store[1])
            self.assertEqual(self.store.get(0, None), None)
            self.assertEqual(self.store[1].get(11, no_payload), self.store[1][11])
            self.assertEqual(self.store[1].get(10, no_payload), no_payload)

        def test_iter(self) -> None:
            all_slots = [1, 2, 3, 4]
            self.assertEqual(sorted(self.store), all_slots)
            self.assertEqual(len(self.store), len(sample_data))
            slot_one_locations = [11, 12, 13]
            self.assertEqual(list(self.store[1]), slot_one_locations)
            self.assertEqual(len(self.store[1]), len(sample_data[1]))

        def test_items(self) -> None:
            # items() keys must agree with plain iteration, both for the store and one slot.
            slot_keys = sorted(slot for slot, _ in self.store.items())
            self.assertEqual(slot_keys, sorted(self.store))
            location_keys = sorted(location for location, _ in self.store[1].items())
            self.assertEqual(location_keys, sorted(self.store[1]))

            first_slot_entry = sorted(self.store.items())[0]
            self.assertEqual(first_slot_entry[0], 1)
            self.assertEqual(first_slot_entry[1], self.store[1])

            first_location_entry = sorted(self.store[1].items())[0]
            self.assertEqual(first_location_entry[0], 11)
            self.assertEqual(first_location_entry[1], self.store[1][11])

        def test_find_item(self) -> None:
            # (receiving players, item id, expected sorted (slot, location, item, player, flags) rows)
            cases = (
                (set(), 99, []),
                ({3}, 1, []),
                ({5}, 99, []),
                ({3}, 99, [(4, 9, 99, 3, 0)]),
                ({3, 4}, 99, [(3, 9, 99, 4, 0), (4, 9, 99, 3, 0)]),
            )
            for players, item, expected in cases:
                self.assertEqual(sorted(self.store.find_item(players, item)), expected)

        def test_get_for_player(self) -> None:
            cases = (
                (3, {4: {9}}),
                (1, {1: {13}, 2: {22, 23}}),
            )
            for player, expected in cases:
                self.assertEqual(self.store.get_for_player(player), expected)

        def test_get_checked(self) -> None:
            self._assert_slot_queries(self.store.get_checked, (
                (full_state, 1, [11, 12, 13]),
                (one_state, 1, [12]),
                (empty_state, 1, []),
                (full_state, 3, [9]),
            ))

        def test_get_missing(self) -> None:
            self._assert_slot_queries(self.store.get_missing, (
                (full_state, 1, []),
                (one_state, 1, [11, 13]),
                (empty_state, 1, [11, 12, 13]),
                (empty_state, 3, [9]),
            ))

        def test_get_remaining(self) -> None:
            # Results are item ids, not location ids.
            self._assert_slot_queries(self.store.get_remaining, (
                (full_state, 1, []),
                (one_state, 1, [13, 21]),
                (empty_state, 1, [13, 21, 22]),
                (empty_state, 3, [99]),
            ))

        def test_location_set_intersection(self) -> None:
            # A slot view must be usable as the iterable argument of set.intersection_update.
            candidates = {10, 11, 12}
            candidates.intersection_update(self.store[1])
            self.assertEqual(candidates, {11, 12})

    class TestLocationStoreConstructor(unittest.TestCase):
        """Constructor validation tests for a store class.

        Subclasses must assign ``self.type`` (the store class under test) in ``setUp``.
        """
        type: type

        def test_hole(self) -> None:
            # Slot 2 is missing between slots 1 and 3.
            gapped = {
                1: {1: (1, 1, 1)},
                3: {1: (1, 1, 1)},
            }
            with self.assertRaises(Exception):
                self.type(gapped)

        def test_no_slot1(self) -> None:
            missing_first = {
                2: {1: (1, 1, 1)},
                3: {1: (1, 1, 1)},
            }
            with self.assertRaises(Exception):
                self.type(missing_first)

        def test_slot0(self) -> None:
            # Slot 0 must be rejected whether or not the remaining slots are contiguous from 1.
            for other_slot in (1, 2):
                with self.assertRaises(ValueError):
                    self.type({
                        0: {1: (1, 1, 1)},
                        other_slot: {1: (1, 1, 1)},
                    })

        def test_no_players(self) -> None:
            no_slots = {}
            with self.assertRaises(Exception):
                _ = self.type(no_slots)

        def test_no_locations(self) -> None:
            with warnings.catch_warnings():
                # Any warning raised while building a location-less store is suppressed here.
                warnings.simplefilter("ignore")
                store = self.type({1: {}})
                self.assertEqual(len(store), 1)
                self.assertEqual(len(store[1]), 0)

        def test_no_locations_for_1(self) -> None:
            store = self.type({1: {}, 2: {1: (1, 2, 3)}})
            self.assertEqual(len(store), 2)
            for slot, size in ((1, 0), (2, 1)):
                self.assertEqual(len(store[slot]), size)

        def test_no_locations_for_last(self) -> None:
            store = self.type({1: {1: (1, 2, 3)}, 2: {}})
            self.assertEqual(len(store), 2)
            for slot, size in ((1, 1), (2, 0)):
                self.assertEqual(len(store[slot]), size)
class TestPurePythonLocationStore(Base.TestLocationStore):
    """Shared method tests, bound to the pure-Python NetUtils._LocationStore."""

    def setUp(self) -> None:
        super().setUp()
        # Build a fresh store from the shared fixture for every test.
        self.store = _LocationStore(sample_data)
class TestPurePythonLocationStoreConstructor(Base.TestLocationStoreConstructor):
    """Shared constructor tests, bound to the pure-Python NetUtils._LocationStore."""

    def setUp(self) -> None:
        super().setUp()
        # The class (not an instance) under test; each test constructs its own store.
        self.type = _LocationStore
@unittest.skipIf(LocationStore is _LocationStore, "_speedups not available")
class TestSpeedupsLocationStore(Base.TestLocationStore):
    """Shared method tests, bound to the cython LocationStore.

    Skipped when LocationStore is merely the pure-Python fallback.
    """

    def setUp(self) -> None:
        super().setUp()
        # Build a fresh store from the shared fixture for every test.
        self.store = LocationStore(sample_data)
@unittest.skipIf(LocationStore is _LocationStore, "_speedups not available")
class TestSpeedupsLocationStoreConstructor(Base.TestLocationStoreConstructor):
    """Shared constructor tests for the cython LocationStore, plus the extra input
    constraints that implementation enforces (key types, player range, payload type)."""

    def setUp(self) -> None:
        super().setUp()
        self.type = LocationStore

    def _assert_rejected(self, raw: typing.Any) -> None:
        """Assert that constructing a store from ``raw`` raises some exception."""
        with self.assertRaises(Exception):
            self.type(raw)

    def test_float_key(self) -> None:
        self._assert_rejected({
            1: {1: (1, 1, 1)},
            1.1: {1: (1, 1, 1)},
            3: {1: (1, 1, 1)},
        })

    def test_string_key(self) -> None:
        self._assert_rejected({
            "1": {1: (1, 1, 1)},
        })

    def test_high_player_number(self) -> None:
        # A player id that does not fit in 32 bits.
        self._assert_rejected({
            1 << 32: {1: (1, 1, 1)},
        })

    def test_not_a_tuple(self) -> None:
        self._assert_rejected({
            1: {1: None},
        })