diff --git a/src/cpl_query/_extension/iterable.py b/src/cpl_query/_extension/iterable.py index 14f681a8..28f9bcfc 100644 --- a/src/cpl_query/_extension/iterable.py +++ b/src/cpl_query/_extension/iterable.py @@ -5,6 +5,7 @@ from cpl_query.extension.ordered_iterable_abc import OrderedIterableABC from .._query.all_query import all_query from .._query.any_query import any_query from .._query.avg_query import avg_query +from .._query.contains_query import contains_query from .._query.first_query import first_or_default_query, first_query from .._query.for_each_query import for_each_query from .._query.order_by import order_by_query, order_by_descending_query @@ -27,6 +28,9 @@ class Iterable(IterableABC): def average(self, t: type, func: Callable) -> Union[int, float, complex]: return avg_query(self, t, func) + def contains(self, value: object) -> bool: + return contains_query(self, value) + def first(self) -> any: return first_query(self) diff --git a/src/cpl_query/_query/contains_query.py b/src/cpl_query/_query/contains_query.py new file mode 100644 index 00000000..e235c65d --- /dev/null +++ b/src/cpl_query/_query/contains_query.py @@ -0,0 +1,5 @@ +from cpl_query.extension.iterable_abc import IterableABC + + +def contains_query(_list: IterableABC, value: object) -> bool: + return value in _list diff --git a/src/cpl_query/extension/iterable_abc.py b/src/cpl_query/extension/iterable_abc.py index 85aaf734..ba80b979 100644 --- a/src/cpl_query/extension/iterable_abc.py +++ b/src/cpl_query/extension/iterable_abc.py @@ -17,6 +17,9 @@ class IterableABC(ABC, list): @abstractmethod def average(self, t: type, func: Callable) -> Union[int, float, complex]: pass + @abstractmethod + def contains(self, value: object) -> bool: pass + @abstractmethod def first(self) -> any: pass diff --git a/src/cpl_query/tests/query_test.py b/src/cpl_query/tests/query_test.py index d26f2f88..b836bb6c 100644 --- a/src/cpl_query/tests/query_test.py +++ b/src/cpl_query/tests/query_test.py @@ -79,6 +79,10 @@ class QueryTest(unittest.TestCase): self.assertRaises(InvalidTypeException, invalid) self.assertRaises(WrongTypeException, wrong) + def test_contains(self): + self.assertTrue(self._tests.contains(self._t_user)) + self.assertFalse(self._tests.contains(User("Test", None))) + def test_first(self): results = [] for user in self._tests: