diff --git a/src/cpl_query/_extension/iterable.py b/src/cpl_query/_extension/iterable.py index 2db2c507..14f681a8 100644 --- a/src/cpl_query/_extension/iterable.py +++ b/src/cpl_query/_extension/iterable.py @@ -1,9 +1,10 @@ -from typing import Optional, Callable +from typing import Optional, Callable, Union from cpl_query._extension.ordered_iterable import OrderedIterable from cpl_query.extension.ordered_iterable_abc import OrderedIterableABC from .._query.all_query import all_query from .._query.any_query import any_query +from .._query.avg_query import avg_query from .._query.first_query import first_or_default_query, first_query from .._query.for_each_query import for_each_query from .._query.order_by import order_by_query, order_by_descending_query @@ -23,6 +24,9 @@ class Iterable(IterableABC): def all(self, func: Callable) -> bool: return all_query(self, func) + def average(self, t: type, func: Callable) -> Union[int, float, complex]: + return avg_query(self, t, func) + def first(self) -> any: return first_query(self) diff --git a/src/cpl_query/_query/avg_query.py b/src/cpl_query/_query/avg_query.py new file mode 100644 index 00000000..bd884893 --- /dev/null +++ b/src/cpl_query/_query/avg_query.py @@ -0,0 +1,22 @@ +from typing import Callable, Union + +from cpl_query.exceptions import InvalidTypeException, WrongTypeException +from cpl_query.extension.iterable_abc import IterableABC + + +def avg_query(_list: IterableABC, _t: type, _func: Callable) -> Union[int, float, complex]: + average = 0 + count = len(_list) + + if _t != int and _t != float and _t != complex: + raise InvalidTypeException() + + for element in _list: + value = _func(element) + if type(value) != _t: + raise WrongTypeException() + + average += value + + return average / count + diff --git a/src/cpl_query/exceptions.py b/src/cpl_query/exceptions.py new file mode 100644 index 00000000..1bf256e9 --- /dev/null +++ b/src/cpl_query/exceptions.py @@ -0,0 +1,6 @@ +class InvalidTypeException(Exception): + pass + + +class WrongTypeException(Exception): + pass diff --git a/src/cpl_query/extension/iterable_abc.py b/src/cpl_query/extension/iterable_abc.py index b4ce047d..85aaf734 100644 --- a/src/cpl_query/extension/iterable_abc.py +++ b/src/cpl_query/extension/iterable_abc.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Callable +from typing import Optional, Callable, Union class IterableABC(ABC, list): @@ -14,6 +14,9 @@ class IterableABC(ABC, list): @abstractmethod def all(self, func: Callable) -> bool: pass + @abstractmethod + def average(self, t: type, func: Callable) -> Union[int, float, complex]: pass + @abstractmethod def first(self) -> any: pass diff --git a/src/cpl_query/tests/query_test.py b/src/cpl_query/tests/query_test.py index 5020183f..0af4dbaa 100644 --- a/src/cpl_query/tests/query_test.py +++ b/src/cpl_query/tests/query_test.py @@ -3,6 +3,7 @@ import unittest from random import randint from cpl.utils import String +from cpl_query.exceptions import InvalidTypeException, WrongTypeException from cpl_query.extension.list import List from cpl_query.tests.models import User, Address @@ -59,6 +60,28 @@ class QueryTest(unittest.TestCase): self.assertTrue(res) self.assertFalse(n_res) + def test_avg(self): + avg = 0 + for user in self._tests: + avg += user.address.nr + + avg = avg / len(self._tests) + res = self._tests.average(int, lambda u: u.address.nr) + + self.assertEqual(res, avg) + + def test_avg_invalid(self): + def _(): + res = self._tests.average(str, lambda u: u.address.nr) + + self.assertRaises(InvalidTypeException, _) + + def test_avg_wrong(self): + def _(): + res = self._tests.average(int, lambda u: u.address.street) + + self.assertRaises(WrongTypeException, _) + def test_first(self): results = [] for user in self._tests: