Refactoring & added all query

This commit is contained in:
Sven Heidemann 2021-07-27 09:15:19 +02:00
parent a3fff9c7d7
commit 0f85d4b9bc
5 changed files with 28 additions and 5 deletions

View File

@ -2,6 +2,7 @@ from typing import Optional, Callable
from cpl_query._extension.ordered_iterable import OrderedIterable from cpl_query._extension.ordered_iterable import OrderedIterable
from cpl_query.extension.ordered_iterable_abc import OrderedIterableABC from cpl_query.extension.ordered_iterable_abc import OrderedIterableABC
from .._query.all_query import all_query
from .._query.any_query import any_query from .._query.any_query import any_query
from .._query.first_query import first_or_default_query, first_query from .._query.first_query import first_or_default_query, first_query
from .._query.for_each_query import for_each_query from .._query.for_each_query import for_each_query
@ -19,6 +20,9 @@ class Iterable(IterableABC):
def any(self, func: Callable) -> bool: def any(self, func: Callable) -> bool:
return any_query(self, func) return any_query(self, func)
def all(self, func: Callable) -> bool:
return all_query(self, func)
def first(self) -> any: def first(self) -> any:
return first_query(self) return first_query(self)

View File

@ -0,0 +1,9 @@
from collections import Callable
from cpl_query._query.where_query import where_query
from cpl_query.extension.iterable_abc import IterableABC
def all_query(_list: IterableABC, _func: Callable) -> bool:
result = where_query(_list, _func)
return len(result) == len(_list)

View File

@ -6,12 +6,7 @@ from cpl_query.extension.iterable_abc import IterableABC
def where_query(_list: IterableABC, _func: Callable) -> IterableABC: def where_query(_list: IterableABC, _func: Callable) -> IterableABC:
result = IterableABC() result = IterableABC()
for element in _list: for element in _list:
element_type = type(element).__name__
if _func(element): if _func(element):
result.append(element) result.append(element)
# if element_type in _func:
# func = _func.replace(element_type, 'element')
# if eval(func):
# result.append(element)
return result return result

View File

@ -11,6 +11,9 @@ class IterableABC(ABC, list):
@abstractmethod @abstractmethod
def any(self, func: Callable) -> bool: pass def any(self, func: Callable) -> bool: pass
@abstractmethod
def all(self, func: Callable) -> bool: pass
@abstractmethod @abstractmethod
def first(self) -> any: pass def first(self) -> any: pass

View File

@ -47,6 +47,18 @@ class QueryTest(unittest.TestCase):
self.assertTrue(res) self.assertTrue(res)
self.assertFalse(n_res) self.assertFalse(n_res)
def test_all(self):
results = []
for user in self._tests:
if user.address.nr == 10:
results.append(user)
res = self._tests.all(lambda u: u.address is not None)
n_res = self._tests.all(lambda u: u.address.nr == 100)
self.assertTrue(res)
self.assertFalse(n_res)
def test_first(self): def test_first(self):
results = [] results = []
for user in self._tests: for user in self._tests: