Improved typing in query

This commit is contained in:
Sven Heidemann 2023-04-04 14:42:43 +02:00
parent 01309e3124
commit 5d6e7677de
3 changed files with 12 additions and 3 deletions

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, Callable, Union, Iterable, Self from typing import Optional, Callable, Union, Iterable, Self, Any
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from cpl_query._helper import is_number from cpl_query._helper import is_number
@ -378,7 +378,10 @@ class QueryableABC(Sequence):
if _func is None: if _func is None:
_func = _default_lambda _func = _default_lambda
return type(self)(object, [_func(_o) for _o in self]) _l = [_func(_o) for _o in self]
_t = type(_l[0]) if len(_l) > 0 else Any
return type(self)(_t, _l)
def select_many(self, _func: Callable) -> Self: def select_many(self, _func: Callable) -> Self:
r"""Flattens resulting lists to one r"""Flattens resulting lists to one

View File

@ -25,7 +25,7 @@ class Sequence(ABC):
return self.to_list().__len__() return self.to_list().__len__()
@classmethod @classmethod
def __class_getitem__(cls, _t: type): def __class_getitem__(cls, _t: type) -> type:
return _t return _t
def __repr__(self): def __repr__(self):

View File

@ -270,6 +270,12 @@ class IterableQueryTestCase(unittest.TestCase):
self.assertEqual(res.to_list(), l_res) self.assertEqual(res.to_list(), l_res)
def test_select(self): def test_select(self):
def test(_l: List) -> List[int]:
return _l.select(lambda user: user.address.nr)
self.assertEqual(List[User], self._tests.type)
self.assertEqual(List[int], test(self._tests).type)
range_list = List(int, range(0, 100)) range_list = List(int, range(0, 100))
selected_range = range_list.select(lambda x: x + 1) selected_range = range_list.select(lambda x: x + 1)