From b7be43938149f1aa8ff7ddf1e28d57a31e1bdbd8 Mon Sep 17 00:00:00 2001 From: Sven Heidemann Date: Mon, 26 Jul 2021 15:32:28 +0200 Subject: [PATCH] Changed func type from str to Callable --- src/cpl_query/_query/any_query.py | 4 +++- src/cpl_query/_query/where_query.py | 14 +++++++---- src/cpl_query/extension/iterable.py | 4 ++-- src/cpl_query/extension/iterable_abc.py | 4 ++-- src/cpl_query/tests/query_test.py | 32 ++++++++++++------------- 5 files changed, 32 insertions(+), 26 deletions(-) diff --git a/src/cpl_query/_query/any_query.py b/src/cpl_query/_query/any_query.py index 8ff9456e..ff85fa98 100644 --- a/src/cpl_query/_query/any_query.py +++ b/src/cpl_query/_query/any_query.py @@ -1,7 +1,9 @@ +from collections import Callable + from cpl_query._query.where_query import where_query from cpl_query.extension.iterable_abc import IterableABC -def any_query(_list: IterableABC, _func: str) -> bool: +def any_query(_list: IterableABC, _func: Callable) -> bool: result = where_query(_list, _func) return len(result) > 0 diff --git a/src/cpl_query/_query/where_query.py b/src/cpl_query/_query/where_query.py index 97e3eefc..3c73b5b2 100644 --- a/src/cpl_query/_query/where_query.py +++ b/src/cpl_query/_query/where_query.py @@ -1,13 +1,17 @@ +from collections import Callable + from cpl_query.extension.iterable_abc import IterableABC -def where_query(_list: IterableABC, _func: str) -> IterableABC: +def where_query(_list: IterableABC, _func: Callable) -> IterableABC: result = IterableABC() for element in _list: element_type = type(element).__name__ - if element_type in _func: - func = _func.replace(element_type, 'element') - if eval(func): - result.append(element) + if _func(element): + result.append(element) + # if element_type in _func: + # func = _func.replace(element_type, 'element') + # if eval(func): + # result.append(element) return result diff --git a/src/cpl_query/extension/iterable.py b/src/cpl_query/extension/iterable.py index 50926fa6..d5844620 100644 --- a/src/cpl_query/extension/iterable.py +++ b/src/cpl_query/extension/iterable.py @@ -16,7 +16,7 @@ class Iterable(IterableABC): def __init__(self): IterableABC.__init__(self) - def any(self, func: str) -> bool: + def any(self, func: Callable) -> bool: return any_query(self, func) def first(self) -> any: @@ -44,7 +44,7 @@ class Iterable(IterableABC): def single_or_default(self) -> Optional[any]: return single_or_default_query(self) - def where(self, func: str) -> IterableABC: + def where(self, func: Callable) -> IterableABC: res = where_query(self, func) res.__class__ = Iterable return res diff --git a/src/cpl_query/extension/iterable_abc.py b/src/cpl_query/extension/iterable_abc.py index e4bf8f80..63d5d97b 100644 --- a/src/cpl_query/extension/iterable_abc.py +++ b/src/cpl_query/extension/iterable_abc.py @@ -9,7 +9,7 @@ class IterableABC(ABC, list): list.__init__(self) @abstractmethod - def any(self, func: str) -> bool: pass + def any(self, func: Callable) -> bool: pass @abstractmethod def first(self) -> any: pass @@ -33,4 +33,4 @@ class IterableABC(ABC, list): def single_or_default(self) -> Optional[any]: pass @abstractmethod - def where(self, func: str) -> 'IterableABC': pass + def where(self, func: Callable) -> 'IterableABC': pass diff --git a/src/cpl_query/tests/query_test.py b/src/cpl_query/tests/query_test.py index 4c9ca6dc..2d0e8784 100644 --- a/src/cpl_query/tests/query_test.py +++ b/src/cpl_query/tests/query_test.py @@ -41,8 +41,8 @@ class QueryTest(unittest.TestCase): if user.address.nr == 10: results.append(user) - res = self._tests.any(f'User.address.nr == 10') - n_res = self._tests.any(f'User.address.nr == 100') + res = self._tests.any(lambda u: u.address.nr == 10) + n_res = self._tests.any(lambda u: u.address.nr == 100) self.assertTrue(res) self.assertFalse(n_res) @@ -53,8 +53,8 @@ class QueryTest(unittest.TestCase): if user.address.nr == 10: results.append(user) - res = self._tests.where(f'User.address.nr == 10') - s_res = self._tests.where(f'User.address.nr == 10').first() + res = self._tests.where(lambda u: u.address.nr == 10) + s_res = self._tests.where(lambda u: u.address.nr == 10).first() self.assertEqual(len(res), len(results)) self.assertIsNotNone(s_res) @@ -65,9 +65,9 @@ class QueryTest(unittest.TestCase): if user.address.nr == 10: results.append(user) - res = self._tests.where(f'User.address.nr == 10') - s_res = self._tests.where(f'User.address.nr == 10').first_or_default() - sn_res = self._tests.where(f'User.address.nr == 11').first_or_default() + res = self._tests.where(lambda u: u.address.nr == 10) + s_res = self._tests.where(lambda u: u.address.nr == 10).first_or_default() + sn_res = self._tests.where(lambda u: u.address.nr == 11).first_or_default() self.assertEqual(len(res), len(results)) self.assertIsNotNone(s_res) @@ -77,8 +77,6 @@ class QueryTest(unittest.TestCase): users = [] self._tests.for_each( lambda user: ( - # Console.write_line(f'User: {user.name} | '), - # Console.write(f'Address: {user.address.street}'), users.append(user) ) ) @@ -114,7 +112,8 @@ class QueryTest(unittest.TestCase): self.assertEqual(res, s_res) def test_then_by_descending(self): - res = self._tests.order_by_descending(lambda user: user.address.street[0]).then_by_descending(lambda user: user.address.nr) + res = self._tests.order_by_descending(lambda user: user.address.street[0]).then_by_descending( + lambda user: user.address.nr) s_res = self._tests s_res.sort(key=lambda user: (user.address.street[0], user.address.nr), reverse=True) @@ -122,16 +121,16 @@ class QueryTest(unittest.TestCase): self.assertEqual(res, s_res) def test_single(self): - res = self._tests.where(f'User.address.nr == {self._t_user.address.nr}') - s_res = self._tests.where(f'User.address.nr == {self._t_user.address.nr}').single() + res = self._tests.where(lambda u: u.address.nr == self._t_user.address.nr) + s_res = self._tests.where(lambda u: u.address.nr == self._t_user.address.nr).single() self.assertEqual(len(res), 1) self.assertEqual(self._t_user, s_res) def test_single_or_default(self): - res = self._tests.where(f'User.address.nr == {self._t_user.address.nr}') - s_res = self._tests.where(f'User.address.nr == {self._t_user.address.nr}').single_or_default() - sn_res = self._tests.where(f'User.address.nr == {self._t_user.address.nr + 1}').single_or_default() + res = self._tests.where(lambda u: u.address.nr == self._t_user.address.nr) + s_res = self._tests.where(lambda u: u.address.nr == self._t_user.address.nr).single_or_default() + sn_res = self._tests.where(lambda u: u.address.nr == self._t_user.address.nr + 1).single_or_default() self.assertEqual(len(res), 1) self.assertEqual(self._t_user, s_res) @@ -143,5 +142,6 @@ class QueryTest(unittest.TestCase): if user.address.nr == 5: results.append(user) - res = self._tests.where('User.address.nr == 5') + res = self._tests.where(lambda u: u.address.nr == 5) + # res = self._tests.where('User.address.nr == 5') self.assertEqual(len(results), len(res))