diff --git a/src/cpl_query/extension/iterable_abc.py b/src/cpl_query/extension/iterable_abc.py index 45135c6e..e45d7815 100644 --- a/src/cpl_query/extension/iterable_abc.py +++ b/src/cpl_query/extension/iterable_abc.py @@ -158,7 +158,7 @@ class IterableABC(ABC, list): """ pass - def extend(self, __iterable: Iterable) -> None: + def extend(self, __iterable: Iterable) -> 'IterableABC': r"""Adds elements of given list to list Parameter @@ -169,6 +169,8 @@ class IterableABC(ABC, list): for value in __iterable: self.append(value) + return self + @abstractmethod def last(self) -> any: r"""Returns last element @@ -290,6 +292,24 @@ class IterableABC(ABC, list): """ pass + def select(self, _f: Callable) -> 'IterableABC': + r"""Formats each element of list to a given format + + Returns + ------- + :class: `cpl_query.extension.iterable_abc.IterableABC` + """ + pass + + def select_many(self, _f: Callable) -> 'IterableABC': + r"""Flattens resulting lists to one + + Returns + ------- + :class: `cpl_query.extension.iterable_abc.IterableABC` + """ + pass + @abstractmethod def single(self) -> any: r"""Returns one single element of list diff --git a/unittests/unittests_query/query_test_case.py b/unittests/unittests_query/query_test_case.py index 00d4fed8..74fff280 100644 --- a/unittests/unittests_query/query_test_case.py +++ b/unittests/unittests_query/query_test_case.py @@ -161,9 +161,9 @@ class QueryTestCase(unittest.TestCase): def test_for_each(self): users = [] self._tests.for_each(lambda user: ( - users.append(user) - ) + users.append(user) ) + ) self.assertEqual(len(users), len(self._tests)) @@ -239,6 +239,35 @@ class QueryTestCase(unittest.TestCase): self.assertEqual(l_res, res) + def test_select(self): + range_list = List(int, range(0, 100)) + selected_range = range_list.select(lambda x: x + 1) + + modulo_range = [] + for x in range(0, 100): + if x % 2 == 0: + modulo_range.append(x) + self.assertEqual(selected_range.to_list(), list(range(1, 101))) + self.assertEqual(range_list.where(lambda x: x % 2 == 0).to_list(), modulo_range) + + def test_select_many(self): + range_list = List(int, range(0, 100)) + selected_range = range_list.select(lambda x: [x, x]) + + self.assertEqual(selected_range, [[x, x] for x in range(0, 100)]) + self.assertEqual(selected_range.select_many(lambda x: x).to_list(), [_x for _l in [2 * [x] for x in range(0, 100)] for _x in _l]) + + class TestClass: + def __init__(self, i, is_sub=False): + self.i = i + if is_sub: + return + self.elements = [TestClass(x, True) for x in range(0, 10)] + + elements = List(TestClass, [TestClass(i) for i in range(0, 100)]) + selected_elements = elements.select_many(lambda x: x.elements).select(lambda x: x.i) + self.assertEqual(selected_elements.where(lambda x: x == 0).count(), 100) + def test_single(self): res = self._tests.where(lambda u: u.address.nr == self._t_user.address.nr) s_res = self._tests.where(lambda u: u.address.nr == self._t_user.address.nr).single()