diff --git a/src/cpl_query/extension/list.py b/src/cpl_query/extension/list.py index b216b18f..7496d43c 100644 --- a/src/cpl_query/extension/list.py +++ b/src/cpl_query/extension/list.py @@ -10,3 +10,9 @@ class List(Iterable): if values is not None: self.extend(values) + + def append(self, __object: object) -> None: + if self._type is not None and type(__object) != self._type and not isinstance(type(__object), self._type): + raise Exception(f'Unexpected type: {type(__object)}') + + super().append(__object) diff --git a/src/cpl_query/tests/iterable_test.py b/src/cpl_query/tests/iterable_test.py new file mode 100644 index 00000000..9b12b28e --- /dev/null +++ b/src/cpl_query/tests/iterable_test.py @@ -0,0 +1,28 @@ +import unittest + +from cpl_query.extension.list import List + + +class IterableTest(unittest.TestCase): + + def setUp(self) -> None: + self._list = List(int) + + def _clear(self): + self._list.clear() + self.assertEqual(self._list, []) + + def test_append(self): + self._list.append(1) + self._list.append(2) + self._list.append(3) + + self.assertEqual(self._list, [1, 2, 3]) + self._clear() + + def test_append_wrong_type(self): + self._list.append(1) + self._list.append(2) + + self.assertRaises(Exception, lambda v: self._list.append(v), '3') + self._clear() diff --git a/src/cpl_query/tests/query_test.py b/src/cpl_query/tests/query_test.py index 2d0e8784..9135c434 100644 --- a/src/cpl_query/tests/query_test.py +++ b/src/cpl_query/tests/query_test.py @@ -143,5 +143,4 @@ class QueryTest(unittest.TestCase): results.append(user) res = self._tests.where(lambda u: u.address.nr == 5) - # res = self._tests.where('User.address.nr == 5') self.assertEqual(len(results), len(res)) diff --git a/src/cpl_query/tests/tester.py b/src/cpl_query/tests/tester.py index 98b10ebb..084071c3 100644 --- a/src/cpl_query/tests/tester.py +++ b/src/cpl_query/tests/tester.py @@ -1,5 +1,6 @@ import unittest +from cpl_query.tests.iterable_test import IterableTest from cpl_query.tests.query_test import QueryTest @@ -11,6 +12,7 @@ class Tester: def create(self): loader = unittest.TestLoader() self._suite.addTests(loader.loadTestsFromTestCase(QueryTest)) + self._suite.addTests(loader.loadTestsFromTestCase(IterableTest)) def start(self): runner = unittest.TextTestRunner()