diff --git a/src/cpl_reactive_extensions/abc/operator.py b/src/cpl_reactive_extensions/abc/operator.py index f82425b4..8ac9ee04 100644 --- a/src/cpl_reactive_extensions/abc/operator.py +++ b/src/cpl_reactive_extensions/abc/operator.py @@ -1,12 +1,8 @@ -from abc import ABC from typing import Any from cpl_reactive_extensions.subscriber import Subscriber -class Operator(ABC): - def __init__(self): - ABC.__init__(self) - +class Operator: def call(self, subscriber: Subscriber, source: Any): pass diff --git a/src/cpl_reactive_extensions/observable.py b/src/cpl_reactive_extensions/observable.py index 57619f0d..0bab5258 100644 --- a/src/cpl_reactive_extensions/observable.py +++ b/src/cpl_reactive_extensions/observable.py @@ -1,5 +1,8 @@ +from __future__ import annotations + from typing import Callable, Any, Optional +from cpl_core.type import T from cpl_reactive_extensions.abc.operator import Operator from cpl_reactive_extensions.abc.subscribable import Subscribable from cpl_reactive_extensions.subscriber import Observer, Subscriber @@ -35,6 +38,12 @@ class Observable(Subscribable): observable = Observable(callback) return observable + def lift(self, operator: Operator) -> Observable: + observable = Observable() + observable._source = self + observable._operator = operator + return observable + @staticmethod def _is_observer(value: Any) -> bool: return isinstance(value, Observer) @@ -80,3 +89,26 @@ class Observable(Subscribable): self._subscribe(observer) except Exception as e: observer.error(e) + + def pipe(self, *args) -> Observable: + # observables = [] + # for arg in args: + # observable = arg(self) + # observables.append(observable) + return self._pipe_from_array(args) + + def _pipe_from_array(self, args): + if len(args) == 0: + return lambda x: x + + if len(args) == 1: + return args[0] + + def piped(input: T): + return Observable._reduce(lambda prev, fn: fn(prev), input) + + return piped + + @staticmethod + def _reduce(func: Callable, input: T): + return func(input) diff --git a/src/cpl_reactive_extensions/operator_subscriber.py b/src/cpl_reactive_extensions/operator_subscriber.py new file mode 100644 index 00000000..28e3998e --- /dev/null +++ b/src/cpl_reactive_extensions/operator_subscriber.py @@ -0,0 +1,53 @@ +from typing import Callable + +from cpl_core.type import T +from cpl_reactive_extensions import Subscriber + + +class OperatorSubscriber(Subscriber): + def __init__( + self, + destination: Subscriber, + on_next: Callable = None, + on_error: Callable = None, + on_complete: Callable = None, + on_finalize: Callable = None, + should_unsubscribe: Callable = None, + ): + Subscriber.__init__(self) + self._on_finalize = on_finalize + self._should_unsubscribe = should_unsubscribe + + def on_next_wrapper(self: OperatorSubscriber, value: T): + try: + on_next(value) + except Exception as e: + destination.error(e) + + self._on_next = on_next_wrapper if on_next is not None else super()._on_next + + def on_error_wrapper(self: OperatorSubscriber, value: T): + try: + on_error(value) + except Exception as e: + destination.error(e) + finally: + self.unsubscribe() + + self._on_error = on_error_wrapper if on_error is not None else super()._on_error + + def on_complete_wrapper(self: OperatorSubscriber, value: T): + try: + on_complete(value) + except Exception as e: + destination.error(e) + finally: + self.unsubscribe() + + self._on_complete = on_complete_wrapper if on_complete is not None else super()._on_complete + + def unsubscribe(self): + if self._should_unsubscribe and not self._should_unsubscribe(): + return + super().unsubscribe() + not self.closed and self._on_finalize is not None and self._on_finalize() diff --git a/src/cpl_reactive_extensions/operators/__init__.py b/src/cpl_reactive_extensions/operators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cpl_reactive_extensions/operators/take.py b/src/cpl_reactive_extensions/operators/take.py new file mode 100644 index 00000000..b9c64af0 --- /dev/null +++ b/src/cpl_reactive_extensions/operators/take.py @@ -0,0 +1,26 @@ +from cpl_core.type import T +from cpl_reactive_extensions import Subscriber, Observable +from cpl_reactive_extensions.operator_subscriber import OperatorSubscriber +from cpl_reactive_extensions.utils import operate + + +def take(count: int): + if count <= 0: + return Observable() + + def init(source: Observable, subscriber: Subscriber): + seen = 0 + + def sub(value: T): + nonlocal seen + + if seen + 1 <= count: + seen += 1 + subscriber.next(value) + + if count <= seen: + subscriber.complete() + + source.subscribe(OperatorSubscriber(subscriber, sub)) + + return operate(init) diff --git a/src/cpl_reactive_extensions/subscriber.py b/src/cpl_reactive_extensions/subscriber.py index ae12d78a..5c4959ba 100644 --- a/src/cpl_reactive_extensions/subscriber.py +++ b/src/cpl_reactive_extensions/subscriber.py @@ -10,6 +10,7 @@ class Subscriber(Subscription, Observer): def __init__( self, on_next_or_observer: ObserverOrCallable, on_error: Callable = None, on_complete: Callable = None ): + self.is_stopped = False Subscription.__init__(self) if isinstance(on_next_or_observer, Observer): self._on_next = on_next_or_observer.next @@ -21,21 +22,21 @@ class Subscriber(Subscription, Observer): self._on_complete = on_complete def next(self, value: T): - if self._closed: + if self.is_stopped: raise Exception("Observer is closed") self._on_next(value) def error(self, ex: Exception): - if self._on_error is None: + if self.is_stopped: return self._on_error(ex) def complete(self): - self._closed = True - if self._on_complete is None: + if self.is_stopped: return + self.is_stopped = True self._on_complete() def unsubscribe(self): @@ -43,3 +44,6 @@ class Subscriber(Subscription, Observer): return super().unsubscribe() + self._on_next = None + self._on_error = None + self._on_complete = None diff --git a/src/cpl_reactive_extensions/utils.py b/src/cpl_reactive_extensions/utils.py new file mode 100644 index 00000000..1fd7da18 --- /dev/null +++ b/src/cpl_reactive_extensions/utils.py @@ -0,0 +1,23 @@ +from typing import Callable + +from cpl_reactive_extensions import Observable, Subscriber +from cpl_reactive_extensions.abc import Operator + + +def operate(init: Callable[[Observable, Subscriber], Operator]): + def observable(source: Observable): + def create(self: Subscriber, lifted_source: Observable): + try: + return init(lifted_source, self) + except Exception as e: + self.error(e) + + operator = Operator() + operator.call = create + + if "lift" not in dir(source): + raise TypeError("Unable to lift unknown Observable type") + + return source.lift(operator) + + return observable diff --git a/unittests/unittests_reactive_extenstions/observable_operator.py b/unittests/unittests_reactive_extenstions/observable_operator.py new file mode 100644 index 00000000..f31b539a --- /dev/null +++ b/unittests/unittests_reactive_extenstions/observable_operator.py @@ -0,0 +1,27 @@ +import traceback +import unittest + +from cpl_core.console import Console +from cpl_reactive_extensions.interval import Interval +from cpl_reactive_extensions.operators.take import take + + +class ObservableOperatorTestCase(unittest.TestCase): + def setUp(self): + self._error = False + self._completed = False + + def _on_error(self, ex: Exception): + tb = traceback.format_exc() + Console.error(f"Got error from observable: {ex}", tb) + self._error = True + + def _on_complete(self): + self._completed = True + + def test_take_two(self): + def sub(x): + Console.write_line(x) + + observable = Interval(1.0) + sub = observable.pipe(take(2)).subscribe(sub) diff --git a/unittests/unittests_reactive_extenstions/reactive_test_suite.py b/unittests/unittests_reactive_extenstions/reactive_test_suite.py index 050700d3..e95e8cbf 100644 --- a/unittests/unittests_reactive_extenstions/reactive_test_suite.py +++ b/unittests/unittests_reactive_extenstions/reactive_test_suite.py @@ -1,10 +1,6 @@ import unittest -from unittests_query.enumerable_query_test_case import EnumerableQueryTestCase -from unittests_query.enumerable_test_case import EnumerableTestCase -from unittests_query.iterable_query_test_case import IterableQueryTestCase -from unittests_query.iterable_test_case import IterableTestCase -from unittests_query.sequence_test_case import SequenceTestCase +from unittests_reactive_extenstions.observable_operator import ObservableOperatorTestCase from unittests_reactive_extenstions.reactive_test_case import ReactiveTestCase @@ -14,6 +10,7 @@ class ReactiveTestSuite(unittest.TestSuite): loader = unittest.TestLoader() self.addTests(loader.loadTestsFromTestCase(ReactiveTestCase)) + self.addTests(loader.loadTestsFromTestCase(ObservableOperatorTestCase)) def run(self, *args): super().run(*args)