[WIP] operator implement

This commit is contained in:
Sven Heidemann 2023-04-16 03:06:55 +02:00
parent 79a6c1db8f
commit 82f23f237c
9 changed files with 172 additions and 14 deletions

View File

@ -1,12 +1,8 @@
from abc import ABC
from typing import Any
from cpl_reactive_extensions.subscriber import Subscriber
class Operator(ABC):
def __init__(self):
ABC.__init__(self)
class Operator:
def call(self, subscriber: Subscriber, source: Any):
pass

View File

@ -1,5 +1,8 @@
from __future__ import annotations
from typing import Callable, Any, Optional
from cpl_core.type import T
from cpl_reactive_extensions.abc.operator import Operator
from cpl_reactive_extensions.abc.subscribable import Subscribable
from cpl_reactive_extensions.subscriber import Observer, Subscriber
@ -35,6 +38,12 @@ class Observable(Subscribable):
observable = Observable(callback)
return observable
def lift(self, operator: Operator) -> Observable:
observable = Observable()
observable._source = self
observable._operator = operator
return observable
@staticmethod
def _is_observer(value: Any) -> bool:
return isinstance(value, Observer)
@ -80,3 +89,26 @@ class Observable(Subscribable):
self._subscribe(observer)
except Exception as e:
observer.error(e)
def pipe(self, *args) -> Observable:
# observables = []
# for arg in args:
# observable = arg(self)
# observables.append(observable)
return self._pipe_from_array(args)
def _pipe_from_array(self, args):
if len(args) == 0:
return lambda x: x
if len(args) == 1:
return args[0]
def piped(input: T):
return Observable._reduce(lambda prev, fn: fn(prev), input)
return piped
@staticmethod
def _reduce(func: Callable, input: T):
return func(input)

View File

@ -0,0 +1,53 @@
from typing import Callable
from cpl_core.type import T
from cpl_reactive_extensions import Subscriber
class OperatorSubscriber(Subscriber):
def __init__(
self,
destination: Subscriber,
on_next: Callable = None,
on_error: Callable = None,
on_complete: Callable = None,
on_finalize: Callable = None,
should_unsubscribe: Callable = None,
):
Subscriber.__init__(self)
self._on_finalize = on_finalize
self._should_unsubscribe = should_unsubscribe
def on_next_wrapper(self: OperatorSubscriber, value: T):
try:
on_next(value)
except Exception as e:
destination.error(e)
self._on_next = on_next_wrapper if on_next is not None else super()._on_next
def on_error_wrapper(self: OperatorSubscriber, value: T):
try:
on_error(value)
except Exception as e:
destination.error(e)
finally:
self.unsubscribe()
self._on_error = on_error_wrapper if on_error is not None else super()._on_error
def on_complete_wrapper(self: OperatorSubscriber, value: T):
try:
on_complete(value)
except Exception as e:
destination.error(e)
finally:
self.unsubscribe()
self._on_complete = on_complete_wrapper if on_complete is not None else super()._on_complete
def unsubscribe(self):
if self._should_unsubscribe and not self._should_unsubscribe():
return
super().unsubscribe()
not self.closed and self._on_finalize is not None and self._on_finalize()

View File

@ -0,0 +1,26 @@
from cpl_core.type import T
from cpl_reactive_extensions import Subscriber, Observable
from cpl_reactive_extensions.operator_subscriber import OperatorSubscriber
from cpl_reactive_extensions.utils import operate
def take(count: int):
if count <= 0:
return Observable()
def init(source: Observable, subscriber: Subscriber):
seen = 0
def sub(value: T):
nonlocal seen
if seen + 1 <= count:
seen += 1
subscriber.next(value)
if count <= seen:
subscriber.complete()
source.subscribe(OperatorSubscriber(subscriber, sub))
return operate(init)

View File

@ -10,6 +10,7 @@ class Subscriber(Subscription, Observer):
def __init__(
self, on_next_or_observer: ObserverOrCallable, on_error: Callable = None, on_complete: Callable = None
):
self.is_stopped = False
Subscription.__init__(self)
if isinstance(on_next_or_observer, Observer):
self._on_next = on_next_or_observer.next
@ -21,21 +22,21 @@ class Subscriber(Subscription, Observer):
self._on_complete = on_complete
def next(self, value: T):
if self._closed:
if self.is_stopped:
raise Exception("Observer is closed")
self._on_next(value)
def error(self, ex: Exception):
if self._on_error is None:
if self.is_stopped:
return
self._on_error(ex)
def complete(self):
self._closed = True
if self._on_complete is None:
if self.is_stopped:
return
self.is_stopped = True
self._on_complete()
def unsubscribe(self):
@ -43,3 +44,6 @@ class Subscriber(Subscription, Observer):
return
super().unsubscribe()
self._on_next = None
self._on_error = None
self._on_complete = None

View File

@ -0,0 +1,23 @@
from typing import Callable
from cpl_reactive_extensions import Observable, Subscriber
from cpl_reactive_extensions.abc import Operator
def operate(init: Callable[[Observable, Subscriber], Operator]):
def observable(source: Observable):
def create(self: Subscriber, lifted_source: Observable):
try:
return init(lifted_source, self)
except Exception as e:
self.error(e)
operator = Operator()
operator.call = create
if "lift" not in dir(source):
raise TypeError("Unable to lift unknown Observable type")
return source.lift(operator)
return observable

View File

@ -0,0 +1,27 @@
import traceback
import unittest
from cpl_core.console import Console
from cpl_reactive_extensions.interval import Interval
from cpl_reactive_extensions.operators.take import take
class ObservableOperatorTestCase(unittest.TestCase):
def setUp(self):
self._error = False
self._completed = False
def _on_error(self, ex: Exception):
tb = traceback.format_exc()
Console.error(f"Got error from observable: {ex}", tb)
self._error = True
def _on_complete(self):
self._completed = True
def test_take_two(self):
def sub(x):
Console.write_line(x)
observable = Interval(1.0)
sub = observable.pipe(take(2)).subscribe(sub)

View File

@ -1,10 +1,6 @@
import unittest
from unittests_query.enumerable_query_test_case import EnumerableQueryTestCase
from unittests_query.enumerable_test_case import EnumerableTestCase
from unittests_query.iterable_query_test_case import IterableQueryTestCase
from unittests_query.iterable_test_case import IterableTestCase
from unittests_query.sequence_test_case import SequenceTestCase
from unittests_reactive_extenstions.observable_operator import ObservableOperatorTestCase
from unittests_reactive_extenstions.reactive_test_case import ReactiveTestCase
@ -14,6 +10,7 @@ class ReactiveTestSuite(unittest.TestSuite):
loader = unittest.TestLoader()
self.addTests(loader.loadTestsFromTestCase(ReactiveTestCase))
self.addTests(loader.loadTestsFromTestCase(ObservableOperatorTestCase))
def run(self, *args):
super().run(*args)