diff --git a/src/cpl_query/enumerable/enumerable.py b/src/cpl_query/enumerable/enumerable.py index 10a528ce..6291cd5d 100644 --- a/src/cpl_query/enumerable/enumerable.py +++ b/src/cpl_query/enumerable/enumerable.py @@ -206,21 +206,17 @@ class Enumerable(EnumerableABC): if _func is None: _func = _default_lambda - result = Enumerable() - result.extend(_func(_o) for _o in self) - return result + _l = [_func(_o) for _o in self] + return Enumerable(self._type if len(_l) < 1 else type(_l[0]), _l) def select_many(self, _func: Callable = None) -> EnumerableABC: if _func is None: _func = _default_lambda - result = Enumerable() - # The line below is pain. I don't understand anything of it... + # The line below is pain. I don't understand anything of the list comprehension... # written on 09.11.2022 by Sven Heidemann - elements = [_a for _o in self for _a in _func(_o)] - - result.extend(elements) - return result + _l = [_a for _o in self for _a in _func(_o)] + return Enumerable(self._type if len(_l) < 1 else type(_l[0]), _l) def single(self: EnumerableABC) -> any: if self is None: @@ -237,9 +233,9 @@ class Enumerable(EnumerableABC): if self is None: raise ArgumentNoneException(ExceptionArgument.list) - if len(self) > 1: + if self.count() > 1: raise IndexError('Found more than one element') - elif len(self) == 0: + elif self.count() == 0: return None return self.element_at(0) diff --git a/src/cpl_query/iterable/iterable.py b/src/cpl_query/iterable/iterable.py index 87f876de..154630cf 100644 --- a/src/cpl_query/iterable/iterable.py +++ b/src/cpl_query/iterable/iterable.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Union +from typing import Callable, Optional, Union, Iterable as IterableType from cpl_query._helper import is_number from cpl_query.exceptions import ArgumentNoneException, ExceptionArgument, InvalidTypeException, IndexOutOfRangeException @@ -12,7 +12,7 @@ def _default_lambda(x: object): class Iterable(IterableABC): - def __init__(self, t: type = None, values: list = None): + def __init__(self, t: type = None, values: IterableType = None): IterableABC.__init__(self, t, values) def all(self, _func: Callable = None) -> bool: diff --git a/src/cpl_query/iterable/iterable_abc.py b/src/cpl_query/iterable/iterable_abc.py index 7b32ea90..b2b6ccc2 100644 --- a/src/cpl_query/iterable/iterable_abc.py +++ b/src/cpl_query/iterable/iterable_abc.py @@ -11,7 +11,7 @@ class IterableABC(SequenceABC, QueryableABC): """ @abstractmethod - def __init__(self, t: type = None, values: list = None): + def __init__(self, t: type = None, values: Iterable = None): SequenceABC.__init__(self, t, values) def __getitem__(self, n) -> object: diff --git a/unittests/unittests_query/enumerable_query_test_case.py b/unittests/unittests_query/enumerable_query_test_case.py index 2f659abf..d95275e0 100644 --- a/unittests/unittests_query/enumerable_query_test_case.py +++ b/unittests/unittests_query/enumerable_query_test_case.py @@ -156,6 +156,8 @@ class EnumerableQueryTestCase(unittest.TestCase): self.assertEqual(len(res), len(results)) self.assertEqual(res.element_at(0), s_res) + self.assertEqual(res.element_at(0), res.first()) + self.assertEqual(res.first(), res.first()) def test_first_or_default(self): results = [] diff --git a/unittests/unittests_query/iterable_query_test_case.py b/unittests/unittests_query/iterable_query_test_case.py index a390e31c..c3618e17 100644 --- a/unittests/unittests_query/iterable_query_test_case.py +++ b/unittests/unittests_query/iterable_query_test_case.py @@ -151,6 +151,8 @@ class IterableQueryTestCase(unittest.TestCase): self.assertEqual(len(res), len(results)) self.assertEqual(res[0], s_res) + self.assertEqual(res[0], res.first()) + self.assertEqual(res.first(), res.first()) def test_first_or_default(self): results = [] diff --git a/unittests/unittests_query/performance_test_case.py b/unittests/unittests_query/performance_test_case.py new file mode 100644 index 00000000..1be52d01 --- /dev/null +++ b/unittests/unittests_query/performance_test_case.py @@ -0,0 +1,45 @@ +import sys +import timeit +import unittest + +from cpl_query.enumerable import Enumerable +from cpl_query.extension.list import List +from cpl_query.iterable import Iterable + +VALUES = 1000 +COUNT = 100 + + +class PerformanceTestCase(unittest.TestCase): + + def setUp(self): + i = 0 + self.values = [] + while i < VALUES: + self.values.append(i) + i += 1 + + # def test_range(self): + # default = timeit.timeit(lambda: list(self.values), number=COUNT) + # enumerable = timeit.timeit(lambda: Enumerable(int, self.values), number=COUNT) + # iterable = timeit.timeit(lambda: Iterable(int, self.values), number=COUNT) + # + # print(f'd: {default}') + # print(f'e: {enumerable}') + # print(f'i: {iterable}') + # + # self.assertLess(default, enumerable) + # self.assertLess(default, iterable) + + def test_where_single(self): + print(Enumerable(int, self.values).where(lambda x: x == COUNT).single_or_default()) + # default = timeit.timeit(lambda: [x for x in list(self.values) if x == 50], number=COUNT) + # enumerable = timeit.timeit(lambda: Enumerable(int, self.values).where(lambda x: x == 50).single(), number=COUNT) + # iterable = timeit.timeit(lambda: Iterable(int, self.values).where(lambda x: x == 50).single(), number=COUNT) + # + # print(f'd: {default}') + # print(f'e: {enumerable}') + # print(f'i: {iterable}') + # + # self.assertLess(default, enumerable) + # self.assertLess(default, iterable)