Changed func type from str to Callable

This commit is contained in:
Sven Heidemann 2021-07-26 15:32:28 +02:00
parent 0cae3428b9
commit b7be439381
5 changed files with 32 additions and 26 deletions

View File

@ -1,7 +1,9 @@
from collections import Callable
from cpl_query._query.where_query import where_query from cpl_query._query.where_query import where_query
from cpl_query.extension.iterable_abc import IterableABC from cpl_query.extension.iterable_abc import IterableABC
def any_query(_list: IterableABC, _func: str) -> bool: def any_query(_list: IterableABC, _func: Callable) -> bool:
result = where_query(_list, _func) result = where_query(_list, _func)
return len(result) > 0 return len(result) > 0

View File

@ -1,13 +1,17 @@
from collections import Callable
from cpl_query.extension.iterable_abc import IterableABC from cpl_query.extension.iterable_abc import IterableABC
def where_query(_list: IterableABC, _func: str) -> IterableABC: def where_query(_list: IterableABC, _func: Callable) -> IterableABC:
result = IterableABC() result = IterableABC()
for element in _list: for element in _list:
element_type = type(element).__name__ element_type = type(element).__name__
if element_type in _func: if _func(element):
func = _func.replace(element_type, 'element') result.append(element)
if eval(func): # if element_type in _func:
result.append(element) # func = _func.replace(element_type, 'element')
# if eval(func):
# result.append(element)
return result return result

View File

@ -16,7 +16,7 @@ class Iterable(IterableABC):
def __init__(self): def __init__(self):
IterableABC.__init__(self) IterableABC.__init__(self)
def any(self, func: str) -> bool: def any(self, func: Callable) -> bool:
return any_query(self, func) return any_query(self, func)
def first(self) -> any: def first(self) -> any:
@ -44,7 +44,7 @@ class Iterable(IterableABC):
def single_or_default(self) -> Optional[any]: def single_or_default(self) -> Optional[any]:
return single_or_default_query(self) return single_or_default_query(self)
def where(self, func: str) -> IterableABC: def where(self, func: Callable) -> IterableABC:
res = where_query(self, func) res = where_query(self, func)
res.__class__ = Iterable res.__class__ = Iterable
return res return res

View File

@ -9,7 +9,7 @@ class IterableABC(ABC, list):
list.__init__(self) list.__init__(self)
@abstractmethod @abstractmethod
def any(self, func: str) -> bool: pass def any(self, func: Callable) -> bool: pass
@abstractmethod @abstractmethod
def first(self) -> any: pass def first(self) -> any: pass
@ -33,4 +33,4 @@ class IterableABC(ABC, list):
def single_or_default(self) -> Optional[any]: pass def single_or_default(self) -> Optional[any]: pass
@abstractmethod @abstractmethod
def where(self, func: str) -> 'IterableABC': pass def where(self, func: Callable) -> 'IterableABC': pass

View File

@ -41,8 +41,8 @@ class QueryTest(unittest.TestCase):
if user.address.nr == 10: if user.address.nr == 10:
results.append(user) results.append(user)
res = self._tests.any(f'User.address.nr == 10') res = self._tests.any(lambda u: u.address.nr == 10)
n_res = self._tests.any(f'User.address.nr == 100') n_res = self._tests.any(lambda u: u.address.nr == 100)
self.assertTrue(res) self.assertTrue(res)
self.assertFalse(n_res) self.assertFalse(n_res)
@ -53,8 +53,8 @@ class QueryTest(unittest.TestCase):
if user.address.nr == 10: if user.address.nr == 10:
results.append(user) results.append(user)
res = self._tests.where(f'User.address.nr == 10') res = self._tests.where(lambda u: u.address.nr == 10)
s_res = self._tests.where(f'User.address.nr == 10').first() s_res = self._tests.where(lambda u: u.address.nr == 10).first()
self.assertEqual(len(res), len(results)) self.assertEqual(len(res), len(results))
self.assertIsNotNone(s_res) self.assertIsNotNone(s_res)
@ -65,9 +65,9 @@ class QueryTest(unittest.TestCase):
if user.address.nr == 10: if user.address.nr == 10:
results.append(user) results.append(user)
res = self._tests.where(f'User.address.nr == 10') res = self._tests.where(lambda u: u.address.nr == 10)
s_res = self._tests.where(f'User.address.nr == 10').first_or_default() s_res = self._tests.where(lambda u: u.address.nr == 10).first_or_default()
sn_res = self._tests.where(f'User.address.nr == 11').first_or_default() sn_res = self._tests.where(lambda u: u.address.nr == 11).first_or_default()
self.assertEqual(len(res), len(results)) self.assertEqual(len(res), len(results))
self.assertIsNotNone(s_res) self.assertIsNotNone(s_res)
@ -77,8 +77,6 @@ class QueryTest(unittest.TestCase):
users = [] users = []
self._tests.for_each( self._tests.for_each(
lambda user: ( lambda user: (
# Console.write_line(f'User: {user.name} | '),
# Console.write(f'Address: {user.address.street}'),
users.append(user) users.append(user)
) )
) )
@ -114,7 +112,8 @@ class QueryTest(unittest.TestCase):
self.assertEqual(res, s_res) self.assertEqual(res, s_res)
def test_then_by_descending(self): def test_then_by_descending(self):
res = self._tests.order_by_descending(lambda user: user.address.street[0]).then_by_descending(lambda user: user.address.nr) res = self._tests.order_by_descending(lambda user: user.address.street[0]).then_by_descending(
lambda user: user.address.nr)
s_res = self._tests s_res = self._tests
s_res.sort(key=lambda user: (user.address.street[0], user.address.nr), reverse=True) s_res.sort(key=lambda user: (user.address.street[0], user.address.nr), reverse=True)
@ -122,16 +121,16 @@ class QueryTest(unittest.TestCase):
self.assertEqual(res, s_res) self.assertEqual(res, s_res)
def test_single(self): def test_single(self):
res = self._tests.where(f'User.address.nr == {self._t_user.address.nr}') res = self._tests.where(lambda u: u.address.nr == self._t_user.address.nr)
s_res = self._tests.where(f'User.address.nr == {self._t_user.address.nr}').single() s_res = self._tests.where(lambda u: u.address.nr == self._t_user.address.nr).single()
self.assertEqual(len(res), 1) self.assertEqual(len(res), 1)
self.assertEqual(self._t_user, s_res) self.assertEqual(self._t_user, s_res)
def test_single_or_default(self): def test_single_or_default(self):
res = self._tests.where(f'User.address.nr == {self._t_user.address.nr}') res = self._tests.where(lambda u: u.address.nr == self._t_user.address.nr)
s_res = self._tests.where(f'User.address.nr == {self._t_user.address.nr}').single_or_default() s_res = self._tests.where(lambda u: u.address.nr == self._t_user.address.nr).single_or_default()
sn_res = self._tests.where(f'User.address.nr == {self._t_user.address.nr + 1}').single_or_default() sn_res = self._tests.where(lambda u: u.address.nr == self._t_user.address.nr + 1).single_or_default()
self.assertEqual(len(res), 1) self.assertEqual(len(res), 1)
self.assertEqual(self._t_user, s_res) self.assertEqual(self._t_user, s_res)
@ -143,5 +142,6 @@ class QueryTest(unittest.TestCase):
if user.address.nr == 5: if user.address.nr == 5:
results.append(user) results.append(user)
res = self._tests.where('User.address.nr == 5') res = self._tests.where(lambda u: u.address.nr == 5)
# res = self._tests.where('User.address.nr == 5')
self.assertEqual(len(results), len(res)) self.assertEqual(len(results), len(res))