diff --git a/src/cpl_query/_extension/iterable.py b/src/cpl_query/_extension/iterable.py index 28f9bcfc..f925851e 100644 --- a/src/cpl_query/_extension/iterable.py +++ b/src/cpl_query/_extension/iterable.py @@ -6,6 +6,7 @@ from .._query.all_query import all_query from .._query.any_query import any_query from .._query.avg_query import avg_query from .._query.contains_query import contains_query +from .._query.count_query import count_query from .._query.first_query import first_or_default_query, first_query from .._query.for_each_query import for_each_query from .._query.order_by import order_by_query, order_by_descending_query @@ -31,6 +32,9 @@ class Iterable(IterableABC): def contains(self, value: object) -> bool: return contains_query(self, value) + def count(self, func: Callable = None) -> int: + return count_query(self, func) + def first(self) -> any: return first_query(self) diff --git a/src/cpl_query/_query/count_query.py b/src/cpl_query/_query/count_query.py new file mode 100644 index 00000000..246feee1 --- /dev/null +++ b/src/cpl_query/_query/count_query.py @@ -0,0 +1,15 @@ +from collections import Callable + +from cpl_query._query.where_query import where_query +from cpl_query.exceptions import ArgumentNoneException, ExceptionArgument +from cpl_query.extension.iterable_abc import IterableABC + + +def count_query(_list: IterableABC, _func: Callable = None) -> int: + if _list is None: + raise ArgumentNoneException(ExceptionArgument.list) + + if _func is None: + return len(_list) + + return len(where_query(_list, _func)) diff --git a/src/cpl_query/extension/iterable_abc.py b/src/cpl_query/extension/iterable_abc.py index ba80b979..6a85d604 100644 --- a/src/cpl_query/extension/iterable_abc.py +++ b/src/cpl_query/extension/iterable_abc.py @@ -20,6 +20,9 @@ class IterableABC(ABC, list): @abstractmethod def contains(self, value: object) -> bool: pass + @abstractmethod + def count(self, func: Callable) -> int: pass + @abstractmethod def first(self) -> any: pass diff --git a/src/cpl_query/tests/query_test.py b/src/cpl_query/tests/query_test.py index b836bb6c..cb7a42a4 100644 --- a/src/cpl_query/tests/query_test.py +++ b/src/cpl_query/tests/query_test.py @@ -83,6 +83,10 @@ class QueryTest(unittest.TestCase): self.assertTrue(self._tests.contains(self._t_user)) self.assertFalse(self._tests.contains(User("Test", None))) + def test_count(self): + self.assertEqual(len(self._tests), self._tests.count()) + self.assertEqual(1, self._tests.count(lambda u: u == self._t_user)) + def test_first(self): results = [] for user in self._tests: