From 5d6e7677dee6107c765f839e0dbeaebb10c55074 Mon Sep 17 00:00:00 2001 From: Sven Heidemann Date: Tue, 4 Apr 2023 14:42:43 +0200 Subject: [PATCH] Improved typing in query --- src/cpl_query/base/queryable_abc.py | 7 +++++-- src/cpl_query/base/sequence.py | 2 +- unittests/unittests_query/iterable_query_test_case.py | 6 ++++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/cpl_query/base/queryable_abc.py b/src/cpl_query/base/queryable_abc.py index 578c3c53..95cadadd 100644 --- a/src/cpl_query/base/queryable_abc.py +++ b/src/cpl_query/base/queryable_abc.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional, Callable, Union, Iterable, Self +from typing import Optional, Callable, Union, Iterable, Self, Any from typing import TYPE_CHECKING from cpl_query._helper import is_number @@ -378,7 +378,10 @@ class QueryableABC(Sequence): if _func is None: _func = _default_lambda - return type(self)(object, [_func(_o) for _o in self]) + _l = [_func(_o) for _o in self] + _t = type(_l[0]) if len(_l) > 0 else Any + + return type(self)(_t, _l) def select_many(self, _func: Callable) -> Self: r"""Flattens resulting lists to one diff --git a/src/cpl_query/base/sequence.py b/src/cpl_query/base/sequence.py index 95aaaa0a..22ea0f34 100644 --- a/src/cpl_query/base/sequence.py +++ b/src/cpl_query/base/sequence.py @@ -25,7 +25,7 @@ class Sequence(ABC): return self.to_list().__len__() @classmethod - def __class_getitem__(cls, _t: type): + def __class_getitem__(cls, _t: type) -> type: return _t def __repr__(self): diff --git a/unittests/unittests_query/iterable_query_test_case.py b/unittests/unittests_query/iterable_query_test_case.py index 2b44c25e..3592d4f1 100644 --- a/unittests/unittests_query/iterable_query_test_case.py +++ b/unittests/unittests_query/iterable_query_test_case.py @@ -270,6 +270,12 @@ class IterableQueryTestCase(unittest.TestCase): self.assertEqual(res.to_list(), l_res) def test_select(self): + def test(_l: List) -> List[int]: + return _l.select(lambda user: user.address.nr) + + self.assertEqual(List[User], self._tests.type) + self.assertEqual(List[int], test(self._tests).type) + range_list = List(int, range(0, 100)) selected_range = range_list.select(lambda x: x + 1)