diff --git a/src/cpl_reactive_extensions/observable.py b/src/cpl_reactive_extensions/observable.py index b3534576..809c1ad7 100644 --- a/src/cpl_reactive_extensions/observable.py +++ b/src/cpl_reactive_extensions/observable.py @@ -1,4 +1,4 @@ -from typing import Callable, Union +from typing import Callable, Union, Optional from cpl_reactive_extensions.observer import Observer @@ -31,10 +31,12 @@ class Observable: def subscribe( self, observer_or_next: Union[Callable, Observer], on_error: Callable = None, on_complete: Callable = None ) -> Observer: + observable: Optional[Observable] = None + if isinstance(observer_or_next, Observable): + observable = observer_or_next + if isinstance(observer_or_next, Callable): observer = Observer(observer_or_next, on_error, on_complete) - elif isinstance(observer_or_next, Observable): - observer = observer_or_next else: observer = observer_or_next @@ -42,8 +44,8 @@ class Observable: self._observers.append(observer) return observer - if len(observer._observers) > 0: - for observer in observer._observers: + if observable is not None and len(observable._observers) > 0: + for observer in observable._observers: self._call(observer) else: self._call(observer) diff --git a/src/cpl_reactive_extensions/observer.py b/src/cpl_reactive_extensions/observer.py index a9fe0f4e..0c9dee24 100644 --- a/src/cpl_reactive_extensions/observer.py +++ b/src/cpl_reactive_extensions/observer.py @@ -16,6 +16,9 @@ class Observer: return self._closed def next(self, value: T): + if self._closed: + raise Exception("Observer is closed") + self._on_next(value) def error(self, ex: Exception): @@ -24,6 +27,7 @@ class Observer: self._on_error(ex) def complete(self): + self._closed = True if self._on_complete is None: return diff --git a/unittests/unittests_reactive_extenstions/reactive_test_case.py b/unittests/unittests_reactive_extenstions/reactive_test_case.py index fc30f812..7b590394 100644 --- a/unittests/unittests_reactive_extenstions/reactive_test_case.py +++ b/unittests/unittests_reactive_extenstions/reactive_test_case.py @@ -1,3 +1,4 @@ +import time import traceback import unittest from threading import Timer @@ -15,7 +16,7 @@ class ReactiveTestCase(unittest.TestCase): def _on_error(self, ex: Exception): tb = traceback.format_exc() - Console.error(f"Somthing went wrong: {ex}", tb) + Console.error(f"Got error from observable: {ex}", tb) self._error = True def _on_complete(self): @@ -67,6 +68,19 @@ class ReactiveTestCase(unittest.TestCase): Timer(1.0, complete).start() + time.sleep(2) + + def _test_complete(x: Observer): + x.next(1) + x.next(2) + x.complete() + x.next(3) + + observable2 = Observable(_test_complete) + + observable2.subscribe(lambda x: x, self._on_error) + self.assertTrue(self._error) + def test_observable_from(self): expected_x = 1 @@ -93,10 +107,10 @@ class ReactiveTestCase(unittest.TestCase): expected_x = 1 subject = Subject(int) - subject.subscribe(_next, self._on_error, self._on_complete) - subject.subscribe(_next, self._on_error, self._on_complete) + subject.subscribe(_next, self._on_error) + subject.subscribe(_next, self._on_error) observable = Observable.from_list([1, 2, 3]) - observable.subscribe(subject, self._on_error, self._on_complete) + observable.subscribe(subject, self._on_error) self.assertFalse(self._error)