diff --git a/src/cpl_query/_query/avg.py b/src/cpl_query/_query/avg.py index ff299a3d..09fbbcc4 100644 --- a/src/cpl_query/_query/avg.py +++ b/src/cpl_query/_query/avg.py @@ -9,6 +9,9 @@ def avg_query(_list: IterableABC, _func: Callable) -> Union[int, float, complex] if _list is None: raise ArgumentNoneException(ExceptionArgument.list) + if _func is None and not is_number(_list.type): + raise InvalidTypeException() + average = 0 count = len(_list) @@ -19,9 +22,6 @@ def avg_query(_list: IterableABC, _func: Callable) -> Union[int, float, complex] else: value = element - if _func is None and type(element) != _list.type or not is_number(type(value)): - raise WrongTypeException() - average += value return average / count diff --git a/src/cpl_query/_query/max_min.py b/src/cpl_query/_query/max_min.py index a6cae896..d9c2c6a5 100644 --- a/src/cpl_query/_query/max_min.py +++ b/src/cpl_query/_query/max_min.py @@ -10,6 +10,9 @@ def max_query(_list: IterableABC, _func: Callable) -> Union[int, float, complex] if _list is None: raise ArgumentNoneException(ExceptionArgument.list) + if _func is None and not is_number(_list.type): + raise InvalidTypeException() + max_value = 0 for element in _list: if _func is not None: @@ -17,9 +20,6 @@ def max_query(_list: IterableABC, _func: Callable) -> Union[int, float, complex] else: value = element - if _func is None and type(value) != _list.type or not is_number(type(value)): - raise WrongTypeException() - if value > max_value: max_value = value @@ -30,6 +30,9 @@ def min_query(_list: IterableABC, _func: Callable) -> Union[int, float, complex] if _list is None: raise ArgumentNoneException(ExceptionArgument.list) + if _func is None and not is_number(_list.type): + raise InvalidTypeException() + min_value = 0 is_first = True for element in _list: @@ -38,9 +41,6 @@ def min_query(_list: IterableABC, _func: Callable) -> Union[int, float, complex] else: value = element - if _func is None and type(value) != _list.type or not is_number(type(value)): - raise WrongTypeException() - if is_first: min_value = value is_first = False diff --git a/src/cpl_query/tests/query_test.py b/src/cpl_query/tests/query_test.py index 7426bd5f..61cd1de9 100644 --- a/src/cpl_query/tests/query_test.py +++ b/src/cpl_query/tests/query_test.py @@ -70,10 +70,11 @@ class QueryTest(unittest.TestCase): self.assertEqual(avg, res) - def wrong(): - e_res = self._tests.average(lambda u: u.address.street) + def invalid(): + tests = List(str, ['hello', 'world']) + e_res = tests.average() - self.assertRaises(WrongTypeException, wrong) + self.assertRaises(InvalidTypeException, invalid) tests = List(int, list(range(0, 100))) self.assertEqual(sum(tests) / len(tests), tests.average()) @@ -174,10 +175,11 @@ class QueryTest(unittest.TestCase): tests = List(values=list(range(0, 100))) self.assertEqual(99, tests.max()) - def wrong(): - e_res = List(str, list([str(v) for v in range(0, 100)])).max() + def invalid(): + tests = List(str, ['hello', 'world']) + e_res = tests.average() - self.assertRaises(WrongTypeException, wrong) + self.assertRaises(InvalidTypeException, invalid) def test_min(self): res = self._tests.min(lambda u: u.address.nr) @@ -186,10 +188,11 @@ class QueryTest(unittest.TestCase): tests = List(values=list(range(0, 100))) self.assertEqual(0, tests.min()) - def wrong(): - e_res = List(str, list([str(v) for v in range(0, 100)])).min() + def invalid(): + tests = List(str, ['hello', 'world']) + e_res = tests.average() - self.assertRaises(WrongTypeException, wrong) + self.assertRaises(InvalidTypeException, invalid) def test_order_by(self): res = self._tests.order_by(lambda user: user.address.street)