diff --git a/src/cpl_reactive_extensions/observable.py b/src/cpl_reactive_extensions/observable.py index 88cefa3b..b3534576 100644 --- a/src/cpl_reactive_extensions/observable.py +++ b/src/cpl_reactive_extensions/observable.py @@ -1,18 +1,56 @@ -from typing import Callable +from typing import Callable, Union from cpl_reactive_extensions.observer import Observer class Observable: - def __init__(self, callback: Callable): + def __init__(self, callback: Callable = None): self._callback = callback - self._subscriptions: list[Callable] = [] - def _run_subscriptions(self): - for callback in self._subscriptions: - callback() + self._observers: list[Observer] = [] - def subscribe(self, observer: Observer): + @staticmethod + def from_list(values: list): + i = 0 + + def callback(x: Observer): + nonlocal i + if i == len(values): + i = 0 + x.complete() + else: + x.next(values[i]) + i += 1 + + if not x.closed: + callback(x) + + observable = Observable(callback) + return observable + + def subscribe( + self, observer_or_next: Union[Callable, Observer], on_error: Callable = None, on_complete: Callable = None + ) -> Observer: + if isinstance(observer_or_next, Callable): + observer = Observer(observer_or_next, on_error, on_complete) + elif isinstance(observer_or_next, Observable): + observer = observer_or_next + else: + observer = observer_or_next + + if self._callback is None: + self._observers.append(observer) + return observer + + if len(observer._observers) > 0: + for observer in observer._observers: + self._call(observer) + else: + self._call(observer) + + return observer + + def _call(self, observer: Observer): try: self._callback(observer) except Exception as e: diff --git a/src/cpl_reactive_extensions/observer.py b/src/cpl_reactive_extensions/observer.py index 895b6a66..a9fe0f4e 100644 --- a/src/cpl_reactive_extensions/observer.py +++ b/src/cpl_reactive_extensions/observer.py @@ -6,8 +6,14 @@ from cpl_core.type import T class Observer: def __init__(self, on_next: Callable, on_error: Callable = None, on_complete: Callable = None): self._on_next = on_next - self._on_error = on_error if on_error is not None else lambda err: err - self._on_complete = on_complete if on_complete is not None else lambda x: x + self._on_error = on_error + self._on_complete = on_complete + + self._closed = False + + @property + def closed(self) -> bool: + return self._closed def next(self, value: T): self._on_next(value) diff --git a/src/cpl_reactive_extensions/subject.py b/src/cpl_reactive_extensions/subject.py index 679f8ac1..0b0e25bb 100644 --- a/src/cpl_reactive_extensions/subject.py +++ b/src/cpl_reactive_extensions/subject.py @@ -3,15 +3,18 @@ from cpl_reactive_extensions.observable import Observable class Subject(Observable): - def __init__(self): + def __init__(self, _t: type): Observable.__init__(self) + self._t = _t self._value: T = None @property def value(self) -> T: return self._value - def emit(self, value: T): + def next(self, value: T): + if not isinstance(value, self._t): + raise TypeError(f"Expected {self._t.__name__} not {type(value).__name__}") + self._value = value - self._subscriptions() diff --git a/unittests/unittests_reactive_extenstions/reactive_test_case.py b/unittests/unittests_reactive_extenstions/reactive_test_case.py index e0ee498f..fc30f812 100644 --- a/unittests/unittests_reactive_extenstions/reactive_test_case.py +++ b/unittests/unittests_reactive_extenstions/reactive_test_case.py @@ -1,18 +1,28 @@ +import traceback import unittest from threading import Timer +from cpl_core.console import Console from cpl_reactive_extensions.observable import Observable from cpl_reactive_extensions.observer import Observer +from cpl_reactive_extensions.subject import Subject class ReactiveTestCase(unittest.TestCase): def setUp(self): - pass + self._error = False + self._completed = False + + def _on_error(self, ex: Exception): + tb = traceback.format_exc() + Console.error(f"Somthing went wrong: {ex}", tb) + self._error = True + + def _on_complete(self): + self._completed = True def test_observer(self): called = 0 - has_error = False - completed = False test_x = 1 def callback(observer: Observer): @@ -38,34 +48,55 @@ class ReactiveTestCase(unittest.TestCase): called += 1 self.assertEqual(test_x, x) - def on_err(): - nonlocal has_error - has_error = True - - def on_complete(): - nonlocal completed - completed = True - self.assertEqual(called, 0) - self.assertFalse(has_error) - self.assertFalse(completed) + self.assertFalse(self._error) + self.assertFalse(self._completed) observable.subscribe( - Observer( - on_next, - on_err, - on_complete, - ) + on_next, + self._on_error, + self._on_complete, ) self.assertEqual(called, 3) - self.assertFalse(has_error) - self.assertFalse(completed) + self.assertFalse(self._error) + self.assertFalse(self._completed) def complete(): self.assertEqual(called, 4) - self.assertFalse(has_error) - self.assertTrue(completed) + self.assertFalse(self._error) + self.assertTrue(self._completed) Timer(1.0, complete).start() + def test_observable_from(self): + expected_x = 1 + + def _next(x): + nonlocal expected_x + self.assertEqual(expected_x, x) + expected_x += 1 + + observable = Observable.from_list([1, 2, 3, 4]) + observable.subscribe( + _next, + self._on_error, + ) + self.assertFalse(self._error) + def test_subject(self): - pass + expected_x = 1 + + def _next(x): + nonlocal expected_x + self.assertEqual(expected_x, x) + expected_x += 1 + if expected_x == 4: + expected_x = 1 + + subject = Subject(int) + subject.subscribe(_next, self._on_error, self._on_complete) + subject.subscribe(_next, self._on_error, self._on_complete) + + observable = Observable.from_list([1, 2, 3]) + observable.subscribe(subject, self._on_error, self._on_complete) + + self.assertFalse(self._error)