Last active
July 27, 2025 23:09
-
-
Save SF-300/4d86a2bdd4440b7d153317b71c9424e7 to your computer and use it in GitHub Desktop.
Minimal asyncio-based externally-driven actor system
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio as aio
import contextlib
import contextvars
import dataclasses
import itertools
import logging
import typing as t
from contextvars import ContextVar
from dataclasses import dataclass
from enum import IntEnum
from uuid import uuid4

from lazy_object_proxy import Proxy  # type: ignore
| _logger = logging.getLogger(__name__) | |
class MsgPriority(IntEnum):
    """Priority of a mailbox entry.

    Entries are stored in a priority queue, so a numerically smaller
    member is retrieved before a larger one.
    """

    NORMAL = 2  # priority given to messages enqueued on receive
    HIGH = 1  # priority given to messages replayed from the stash
    URGENT = 0
| class _Message[M](t.NamedTuple): | |
| priority: MsgPriority | |
| # NOTE: PriorityQueue does not maintain insertion order itself, so we have to enforce it manually. | |
| # https://stackoverflow.com/a/47969819/3344105 | |
| msg_idx: int | |
| msg: M | |
| type Mailbox[M] = aio.PriorityQueue[_Message[M]] | |
| @dataclass(frozen=True) | |
| class Request[Response](t.Awaitable[Response]): | |
| __response: aio.Future[Response] = dataclasses.field( | |
| default_factory=aio.Future, | |
| init=False, | |
| repr=False, | |
| hash=False, | |
| compare=False, | |
| ) | |
| def set_result(self, result: Response) -> None: | |
| self.__response.set_result(result) | |
| def set_exception(self, exception: Exception) -> None: | |
| self.__response.set_exception(exception) | |
| def cancel(self) -> None: | |
| self.__response.cancel() | |
| def __await__(self) -> t.Generator[t.Any, None, Response]: | |
| return self.__response.__await__() | |
| type ActorId = str | |
| @dataclass(frozen=True, kw_only=True) | |
| class ActorState[M]: | |
| actor_stash: tuple[M, ...] = tuple() | |
# Distinct static type for the ``ActorSystem`` handed to actor constructors
# (``ActorSystem.spawn`` casts itself to it); at runtime it is an identity function.
InitActorSystem = t.NewType("InitActorSystem", "ActorSystem")
| class ActorLike[M](t.Protocol): | |
| def tell(self, *msgs: M) -> None: ... | |
| @t.overload | |
| def ask(self, msg: M) -> t.Awaitable[t.Any]: ... | |
| @t.overload | |
| def ask[Response](self, msg: Request[Response]) -> t.Awaitable[Response]: ... | |
| if t.TYPE_CHECKING: | |
| class ActorRef[M](ActorLike[M]): | |
| def __init__(self, actor_id: ActorId, actor_system: "ActorSystem") -> None: ... | |
| def tell(self, *msgs) -> t.Any: ... | |
| def ask(self, msg) -> t.Any: ... | |
| else: | |
| class ActorRef(Proxy): | |
| def __init__(self, actor_id: ActorId, actor_system: "ActorSystem") -> None: | |
| super().__init__(lambda: actor_system[actor_id]) | |
| class LiveStash[M](t.MutableSequence[M]): | |
| def __init__( | |
| self, | |
| check_is_processing: t.Callable[[], bool], | |
| mailbox: Mailbox[M], | |
| buffer: t.Sequence[M], | |
| msg_idx_iter: t.Iterator[int], | |
| ) -> None: | |
| self._check_is_processing = check_is_processing | |
| self._mailbox = mailbox | |
| self._buffer = list(buffer) | |
| self._msg_idx_iter = msg_idx_iter | |
| def __getitem__(self, index): | |
| return self._buffer[index] | |
| def __len__(self): | |
| return len(self._buffer) | |
| def __setitem__(self, index, value): | |
| if not self._check_is_processing(): | |
| raise RuntimeError("Cannot modify LiveStash while actor is not processing") | |
| self._buffer[index] = value | |
| def __delitem__(self, index): | |
| if not self._check_is_processing(): | |
| raise RuntimeError("Cannot modify LiveStash while actor is not processing") | |
| del self._buffer[index] | |
| def insert(self, index, value): | |
| if not self._check_is_processing(): | |
| raise RuntimeError("Cannot modify LiveStash while actor is not processing") | |
| self._buffer.insert(index, value) | |
| def rindex(self, value: M) -> int: | |
| self.reverse() | |
| i = self.index(value) | |
| self.reverse() | |
| return len(self) - i - 1 | |
| def unstash_all(self) -> None: | |
| if not self._check_is_processing(): | |
| raise RuntimeError("Cannot modify LiveStash while actor is not processing") | |
| while self._buffer: | |
| msg = self._buffer.pop(0) | |
| self._mailbox.put_nowait(_Message(MsgPriority.HIGH, next(self._msg_idx_iter), msg)) | |
| def __call__(self, msg: M) -> None: | |
| self.append(msg) | |
| class Actor[S: ActorState, M](ActorLike[M]): | |
| def __init__( | |
| self, | |
| actor_system: InitActorSystem, | |
| state: S, | |
| actor_id: ActorId | None = None, | |
| ) -> None: | |
| self.__actor_system = actor_system | |
| self.__actor_id = actor_id if actor_id is not None else uuid4().hex | |
| self.__state = state | |
| self.__processing = None | |
| self.__mailbox = mailbox = aio.PriorityQueue() | |
| self.__msg_idx_iter = msg_idx_iter = itertools.count() | |
| def check_is_processing() -> bool: | |
| if self.__processing is None: | |
| return False | |
| if self.__processing.done(): | |
| return False | |
| return True | |
| self.__stash = LiveStash[M](check_is_processing, mailbox, state.actor_stash, msg_idx_iter) | |
| def tell(self, *msgs: M) -> None: | |
| if len(msgs) == 0: | |
| return | |
| if any(isinstance(msg, Request) for msg in msgs): | |
| raise ValueError("Cannot send Request messages using tell()") | |
| self.__actor_system.send(self.__actor_system.current_actor_id, self.__actor_id, *msgs) | |
| def ask(self, msg): | |
| result = self.__actor_system.send( | |
| self.__actor_system.current_actor_id, | |
| self.__actor_id, | |
| msg, | |
| ) | |
| assert isinstance(result, t.Awaitable) | |
| return result | |
| @property | |
| def actor_id(self) -> ActorId: | |
| return self.__actor_id | |
| @property | |
| def state(self) -> S: | |
| return self.__state | |
| @property | |
| def _stash(self) -> LiveStash[M]: | |
| return self.__stash | |
| @property | |
| def _actor_system(self) -> InitActorSystem: | |
| return self.__actor_system | |
| async def _step(self, state: S, *msgs: M) -> S: | |
| return state | |
| async def __process(self) -> None: | |
| state = self.__state | |
| try: | |
| while not self.__mailbox.empty(): | |
| msgs = [] | |
| while not self.__mailbox.empty(): | |
| *_, msg = await self.__mailbox.get() | |
| msgs.append(msg) | |
| state = await self._step(state, *msgs) | |
| for _ in msgs: | |
| self.__mailbox.task_done() | |
| if any(isinstance(msg, Request) for msg in self.__stash): | |
| raise ValueError("Requests cannot be persisted in stash") | |
| self.__state = dataclasses.replace(state, actor_stash=tuple(self.__stash)) | |
| finally: | |
| self.__processing = None | |
| def __receive( | |
| self, | |
| create_task: t.Callable[[t.Coroutine], aio.Future[t.Any]], | |
| *msgs: M, | |
| ) -> aio.Future[t.Any]: | |
| for msg in msgs: | |
| self.__mailbox.put_nowait( | |
| _Message( | |
| priority=MsgPriority.NORMAL, | |
| msg_idx=next(self.__msg_idx_iter), | |
| msg=msg, | |
| ) | |
| ) | |
| if self.__processing is None: | |
| self.__processing = create_task(self.__process()) | |
| return self.__processing | |
# Mangled attribute name under which ``Actor.__receive`` is reachable from
# outside the class body.
_actor_receive_func_name = "_" + Actor.__name__ + "__receive"
# Import-time sanity check that the mangled name really exists on ``Actor``.
assert hasattr(Actor, _actor_receive_func_name)
class _Scheduled(t.NamedTuple):
    """One queued delivery: the messages together with their receiving actor."""

    receiver_id: ActorId
    receiver: Actor
    msgs: t.Sequence[t.Any]
| class ActorSystem: | |
| _current_actor_id: ContextVar[ActorId | None] = ContextVar("current_actor", default=None) | |
| _max_batch_msgs: int = 42 | |
| def __init__(self, logger=_logger) -> None: | |
| self.__scheduled = aio.Queue[_Scheduled]() | |
| self.__logger = logger | |
| self.__actors = dict[ActorId, Actor]() | |
| self.__pendings_ops = set[aio.Future]() | |
| self.__running = False | |
| def __getitem__(self, actor_id: ActorId) -> Actor: | |
| return self.__actors[actor_id] | |
| def __len__(self) -> int: | |
| return len(self.__actors) | |
| @property | |
| def current_actor_id(self) -> ActorId | None: | |
| return self._current_actor_id.get() | |
| @property | |
| def is_running(self) -> bool: | |
| return self.__running | |
| def spawn[**P, R: Actor]( | |
| self, | |
| create_actor: t.Callable[t.Concatenate[InitActorSystem, P], R], | |
| *args: P.args, | |
| **kwargs: P.kwargs, | |
| ) -> R: | |
| actor = create_actor(t.cast(InitActorSystem, self), *args, **kwargs) | |
| actor_id = actor.actor_id | |
| self.__actors[actor_id] = actor | |
| return t.cast(R, ActorRef(actor_id, self)) # type: ignore | |
| def send( | |
| self, | |
| sender_id: ActorId | None, | |
| receiver_id: ActorId, | |
| *msgs: t.Any, | |
| ) -> t.Awaitable[t.Any] | None: | |
| actor = self.__actors[receiver_id] | |
| for msg in msgs: | |
| self.__logger.debug(f"{sender_id} -> {receiver_id}: {msg!r}") | |
| if len(msgs) == 0: | |
| return None | |
| elif len(msgs) > 1 and any(isinstance(msg, Request) for msg in msgs): | |
| raise ValueError("Cannot send multiple Request messages at once") | |
| self.__scheduled.put_nowait( | |
| _Scheduled( | |
| receiver_id=receiver_id, | |
| receiver=actor, | |
| msgs=msgs, | |
| ) | |
| ) | |
| if not (len(msgs) == 1 and isinstance(result := msgs[0], Request)): | |
| return None | |
| return result | |
| async def step(self) -> None: | |
| if self.__running: | |
| raise RuntimeError("Actor system is already running") | |
| self.__running = True | |
| try: | |
| async with aio.TaskGroup() as tg: | |
| async def wait_exhausted(): | |
| while True: | |
| if not self.__scheduled.empty(): | |
| became_empty = tg.create_task(self.__scheduled.join()) | |
| self.__pendings_ops.add(became_empty) | |
| if len(self.__pendings_ops) == 0 and self.__scheduled.empty(): | |
| break | |
| done, _ = await aio.wait( | |
| self.__pendings_ops, | |
| return_when=aio.FIRST_COMPLETED, | |
| ) | |
| self.__pendings_ops -= done | |
| async def dispatch(): | |
| def do_step(receiver_id: ActorId, receiver: Actor, msgs) -> aio.Future: | |
| self._current_actor_id.set(receiver_id) | |
| receive_func = getattr(receiver, _actor_receive_func_name) | |
| f = receive_func(tg.create_task, *msgs) | |
| return f | |
| while True: | |
| receiver_id, receiver, msgs = await self.__scheduled.get() | |
| ctx = contextvars.copy_context() | |
| msgs, remaining = msgs[: self._max_batch_msgs], msgs[self._max_batch_msgs :] | |
| f = ctx.run(do_step, receiver_id, receiver, msgs) | |
| self.__pendings_ops.add(f) | |
| self.__scheduled.task_done() | |
| if not remaining: | |
| continue | |
| self.__scheduled.put_nowait( | |
| _Scheduled( | |
| receiver_id=receiver_id, | |
| receiver=receiver, | |
| msgs=remaining, | |
| ) | |
| ) | |
| # NOTE: Wrap with TaskGroup' tasks to ensure that exceptions are correctly propagated. | |
| async with _running(tg.create_task(dispatch())): | |
| await tg.create_task(wait_exhausted()) | |
| except BaseException as e: | |
| raise e | |
| finally: | |
| self.__running = False | |
| @contextlib.asynccontextmanager | |
| async def _running(task: t.Awaitable) -> t.AsyncIterator[None]: | |
| fut = aio.ensure_future(task) | |
| try: | |
| yield | |
| finally: | |
| if not fut.cancel(): | |
| return | |
| with contextlib.suppress(aio.CancelledError): | |
| await fut |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio as aio | |
| import dataclasses | |
| import pytest | |
| import typing as t | |
| from dataclasses import dataclass | |
| from medfriends_shared.misc.async_utils.actor import Actor, ActorSystem, Request, ActorState, InitActorSystem | |
@pytest.mark.asyncio
async def test_actor_state_and_simple_message():
    """Test actor initialization and simple message handling."""

    @dataclass(frozen=True)
    class CounterState(ActorState):
        value: int = 0

    class CounterActor(Actor):
        async def _step(self, state: CounterState, *events: int) -> CounterState:
            return dataclasses.replace(state, value=state.value + sum(events))

    system = ActorSystem()
    actor = system.spawn(CounterActor, CounterState())

    # Nothing has been processed yet.
    assert actor.state.value == 0

    # A single message is only applied once the system is stepped.
    actor.tell(1)
    await system.step()
    assert actor.state.value == 1

    # Several messages sent in one call are all applied.
    actor.tell(2, 3)
    await system.step()
    assert actor.state.value == 6
@pytest.mark.asyncio
async def test_multiple_actors():
    """Test multiple actors working together in the system."""

    @dataclass(frozen=True)
    class ProducerState(ActorState):
        items: list[int] = dataclasses.field(default_factory=list)

    class ProducerActor(Actor):
        async def _step(self, state: ProducerState, *events: t.Any) -> ProducerState:
            # Collect every integer event; anything else is ignored.
            new_items = list(state.items)
            new_items.extend(event for event in events if isinstance(event, int))
            return dataclasses.replace(state, items=new_items)

    @dataclass(frozen=True)
    class ConsumerState(ActorState):
        total: int = 0

    class ConsumerActor(Actor):
        async def _step(self, state: ConsumerState, *events: int) -> ConsumerState:
            # Sum up all incoming numbers.
            return dataclasses.replace(state, total=state.total + sum(events))

    system = ActorSystem()
    producer = system.spawn(ProducerActor, ProducerState())
    consumer = system.spawn(ConsumerActor, ConsumerState())

    producer.tell(1, 2, 3)
    await system.step()
    assert producer.state.items == [1, 2, 3]

    # Forward the producer's items to the consumer one message at a time.
    for num in producer.state.items:
        consumer.tell(num)
    await system.step()
    assert consumer.state.total == 6
@pytest.mark.asyncio
async def test_actor_request_response():
    """Test request-response pattern with actors."""

    @dataclass(frozen=True)
    class EchoState(ActorState):
        # Requests registered by id that still await their message.
        pending_requests: dict[str, Request[str]] = dataclasses.field(default_factory=dict)

    class EchoActor(Actor):
        async def _step(
            self, state: EchoState, *events: t.Union[tuple[str, str], Request[str], tuple[str, Request[str]]]
        ) -> EchoState:
            new_pending = dict(state.pending_requests)
            for event in events:
                if isinstance(event, Request):
                    # Bare request: always answered with "hello".
                    event.set_result("hello")
                elif isinstance(event, tuple) and len(event) == 2:
                    request_id, payload = event
                    if isinstance(payload, Request):
                        # (id, request): remember the request under its id.
                        new_pending[request_id] = payload
                    elif request_id in new_pending:
                        # (id, message): answer and forget the matching request.
                        new_pending.pop(request_id).set_result(f"Echo: {payload}")
            return dataclasses.replace(state, pending_requests=new_pending)

    system = ActorSystem()
    actor = system.spawn(EchoActor, EchoState())

    # Direct request through ask().
    request = Request[str]()
    actor.ask(request)
    await system.step()
    result = await request
    assert result == "hello"

    # Request registered by id first, answered by a later message.
    request = Request[str]()
    actor.tell(("req1", request))
    await system.step()
    actor.tell(("req1", "test message"))
    await system.step()
    result = await request
    assert result == "Echo: test message"
@pytest.mark.asyncio
async def test_nested_actor_request_response():
    """Test request-response pattern between actors within step functions."""

    @dataclass(frozen=True)
    class ResponderState(ActorState):
        pass

    class ResponderActor(Actor):
        async def _step(
            self, state: ResponderState, *events: t.Union[Request[str], tuple[str, Request[str]]]
        ) -> ResponderState:
            for event in events:
                if isinstance(event, Request):
                    # Bare request.
                    event.set_result("Simple response")
                elif isinstance(event, tuple) and len(event) == 2 and isinstance(event[1], Request):
                    # (message, request) pair.
                    message, req = event
                    req.set_result(f"Response to: {message}")
            return state

    @dataclass(frozen=True)
    class RequesterState(ActorState):
        responses: list[str] = dataclasses.field(default_factory=list)
        nested_call_complete: bool = False

    class RequesterActor(Actor):
        def __init__(self, actor_system: InitActorSystem, state: RequesterState, responder: Actor):
            super().__init__(actor_system, state)
            self.responder = responder

        async def _step(self, state: RequesterState, *events: t.Union[str, bool]) -> RequesterState:
            new_responses = list(state.responses)
            new_nested_complete = state.nested_call_complete
            for event in events:
                if isinstance(event, str):
                    if event == "make_simple_request":
                        # ask() and wait for the reply inside the step.
                        request = Request[str]()
                        response = await self.responder.ask(request)
                        new_responses.append(response)
                    else:
                        # tell() a (message, request) pair and wait for the reply.
                        request = Request[str]()
                        self.responder.tell((event, request))
                        response = await request
                        new_responses.append(response)
                elif isinstance(event, bool) and event:
                    # Two dependent round-trips within one step.
                    request1 = Request[str]()
                    response1 = await self.responder.ask(request1)
                    request2 = Request[str]()
                    self.responder.tell(("nested", request2))
                    response2 = await request2
                    new_responses.append(response1)
                    new_responses.append(response2)
                    new_nested_complete = True
            return dataclasses.replace(state, responses=new_responses, nested_call_complete=new_nested_complete)

    system = ActorSystem()
    responder = system.spawn(ResponderActor, ResponderState())
    requester = system.spawn(RequesterActor, RequesterState(), responder)

    requester.tell("make_simple_request")
    await system.step()
    assert requester.state.responses == ["Simple response"]

    requester.tell("hello")
    await system.step()
    assert requester.state.responses == ["Simple response", "Response to: hello"]

    requester.tell(True)  # Trigger nested requests
    await system.step()
    assert requester.state.nested_call_complete
    assert len(requester.state.responses) == 4
    assert requester.state.responses[2:] == ["Simple response", "Response to: nested"]
@pytest.mark.asyncio
async def test_chain_of_actors_request_response():
    """Test a chain of actors passing requests through multiple layers."""

    @dataclass(frozen=True)
    class FinalResponderState(ActorState):
        pass

    class FinalResponderActor(Actor):
        async def _step(
            self, state: FinalResponderState, *events: Request[str]
        ) -> FinalResponderState:
            for event in events:
                if isinstance(event, Request):
                    event.set_result("Final response")
            return state

    @dataclass(frozen=True)
    class MiddleActorState(ActorState):
        pass

    class MiddleActor(Actor):
        def __init__(self, actor_system: InitActorSystem, state: MiddleActorState, final_actor: Actor):
            super().__init__(actor_system, state)
            self.final_actor = final_actor

        async def _step(self, state: MiddleActorState, *events: Request[str]) -> MiddleActorState:
            for event in events:
                if isinstance(event, Request):
                    # Ask the final actor, then answer the original request
                    # with a decorated copy of its reply.
                    forward_request = Request[str]()
                    final_response = await self.final_actor.ask(forward_request)
                    event.set_result(f"Middle processed: {final_response}")
            return state

    @dataclass(frozen=True)
    class InitiatorState(ActorState):
        response: str = ""

    class InitiatorActor(Actor):
        def __init__(self, actor_system: InitActorSystem, state: InitiatorState, middle_actor: Actor):
            super().__init__(actor_system, state)
            self.middle_actor = middle_actor

        async def _step(self, state: InitiatorState, *events: str) -> InitiatorState:
            new_response = state.response
            for event in events:
                if event == "start_chain":
                    request = Request[str]()
                    new_response = await self.middle_actor.ask(request)
            return dataclasses.replace(state, response=new_response)

    system = ActorSystem()
    # Spawned back to front so each actor can be handed its downstream peer.
    final_actor = system.spawn(FinalResponderActor, FinalResponderState())
    middle_actor = system.spawn(MiddleActor, MiddleActorState(), final_actor)
    initiator = system.spawn(InitiatorActor, InitiatorState(), middle_actor)

    initiator.tell("start_chain")
    await system.step()
    assert initiator.state.response == "Middle processed: Final response"
@pytest.mark.asyncio
async def test_high_throughput():
    """Test actor system performance with high message throughput."""

    @dataclass(frozen=True)
    class CounterState(ActorState):
        value: int = 0

    class CounterActor(Actor):
        async def _step(self, state: CounterState, *events: int) -> CounterState:
            return dataclasses.replace(state, value=state.value + sum(events))

    system = ActorSystem()
    actor = system.spawn(CounterActor, CounterState())

    # Queue many individual deliveries before stepping once.
    message_count = 1000
    for _ in range(message_count):
        actor.tell(1)

    await system.step()
    assert actor.state.value == message_count
@pytest.mark.asyncio
async def test_recursive_actor_step_execution():
    """Test system's ability to handle actors that schedule other actors during their step."""

    @dataclass(frozen=True)
    class ForwardingState(ActorState):
        target: Actor | None = None
        events: list[int] = dataclasses.field(default_factory=list)

    class ForwardingActor(Actor):
        async def _step(self, state: ForwardingState, *events: int) -> ForwardingState:
            # Record what was received ...
            new_events = [*state.events, *events]
            # ... and forward each value, doubled, to the target if there is one.
            if state.target is not None:
                for event in events:
                    state.target.tell(event * 2)
            return dataclasses.replace(state, events=new_events)

    system = ActorSystem()
    # The downstream actor is spawned first so the upstream one can reference it.
    actor2 = system.spawn(ForwardingActor, ForwardingState())
    actor1 = system.spawn(ForwardingActor, ForwardingState(target=actor2))

    actor1.tell(1, 2)
    await system.step()

    # actor1 keeps the original values, actor2 receives the doubled ones.
    assert actor1.state.events == [1, 2]
    assert actor2.state.events == [2, 4]
@pytest.mark.asyncio
async def test_request_with_broken_response():
    """Test handling of requests where the responder breaks the protocol."""

    @dataclass(frozen=True)
    class BrokenResponderState(ActorState):
        pass

    class BrokenResponderActor(Actor):
        async def _step(self, state: BrokenResponderState, *events: Request[str]) -> BrokenResponderState:
            # Deliberately never completes the request.
            return state

    @dataclass(frozen=True)
    class RequesterState(ActorState):
        error: Exception | None = None

    class RequesterActor(Actor):
        def __init__(self, actor_system: InitActorSystem, state: RequesterState, responder: Actor):
            super().__init__(actor_system, state)
            self.responder = responder

        async def _step(self, state: RequesterState, *events: str) -> RequesterState:
            new_error = state.error
            if "request" in events:
                request = Request[str]()
                try:
                    # Bounded wait so an unanswered request cannot hang the test.
                    await aio.wait_for(self.responder.ask(request), 0.1)
                except Exception as e:
                    new_error = e
            return dataclasses.replace(state, error=new_error)

    system = ActorSystem()
    responder = system.spawn(BrokenResponderActor, BrokenResponderState())
    requester = system.spawn(RequesterActor, RequesterState(), responder)

    requester.tell("request")
    await system.step()

    # The unanswered request must surface as a timeout.
    assert requester.state.error is not None
    assert isinstance(requester.state.error, aio.TimeoutError)
@pytest.mark.asyncio
async def test_actor_step_exception_handling():
    """A failing ``_step`` surfaces from ``system.step()`` and leaves the state as it was."""

    @dataclasses.dataclass(frozen=True)
    class FailingActorState(ActorState):
        value: int = 0

    class FailingActor(Actor):
        async def _step(self, state: FailingActorState, *events: int) -> FailingActorState:
            for event in events:
                if event < 0:
                    raise ValueError("Negative values not allowed")
            total = state.value + sum(events)
            return dataclasses.replace(state, value=total)

    system = ActorSystem()
    actor = system.spawn(FailingActor, FailingActorState())

    # Valid batch: applied normally.
    actor.tell(1, 2, 3)
    await system.step()
    assert actor.state.value == 6

    # Invalid batch: the error comes out of step() wrapped in a group.
    actor.tell(-1)
    expected = pytest.RaisesGroup(
        pytest.RaisesExc(ValueError, match="Negative values not allowed"),
    )
    with expected:
        await system.step()

    # The failed batch did not change the committed state.
    assert actor.state.value == 6
| @pytest.mark.asyncio | |
| async def test_stash_unstash_happy_path_preserves_order_and_empties_stash(): | |
| """Events stashed while paused are replayed in-order on resume; stash is emptied.""" | |
| @dataclass(frozen=True) | |
| class S(ActorState): | |
| paused: bool = False | |
| processed: list[int] = dataclasses.field(default_factory=list) | |
| stashed_count: int = 0 | |
| class StashingActor(Actor): | |
| def __init__(self, system: InitActorSystem, state: S): | |
| super().__init__(system, state) | |
| async def _step(self, state: S, *events: t.Union[str, int]) -> S: | |
| new_paused = state.paused | |
| new_processed = list(state.processed) | |
| new_stashed_count = state.stashed_count | |
| for ev in events: | |
| if ev == "pause": | |
| new_paused = True | |
| elif ev == "resume": | |
| new_paused = False | |
| self._stash.unstash_all() | |
| elif isinstance(ev, int): | |
| if new_paused: | |
| self._stash(ev) | |
| new_stashed_count += 1 | |
| else: | |
| new_processed.append(ev) | |
| return dataclasses.replace(state, paused=new_paused, processed=new_processed, stashed_count=new_stashed_count) | |
| system = ActorSystem() | |
| actor = system.spawn(StashingActor, S()) | |
| actor.tell("pause") | |
| await system.step() | |
| assert actor.state.paused is True | |
| actor.tell(1, 2, 3) | |
| await system.step() | |
| assert actor.state.processed == [] | |
| assert actor.state.stashed_count == 3 # Verify 3 items were stashed | |
| actor.tell("resume") | |
| await system.step() | |
| assert actor.state.paused is False | |
| assert actor.state.processed == [1, 2, 3] # Verify stashed items were processed in order | |
@pytest.mark.asyncio
async def test_multiple_batches_stashed_then_single_resume_replays_all_in_order():
    """Stash content accumulated over several paused steps is replayed FIFO by one resume."""

    @dataclass(frozen=True)
    class S(ActorState):
        paused: bool = False
        processed: list[int] = dataclasses.field(default_factory=list)
        stashed_count: int = 0

    class StashingActor(Actor):
        def __init__(self, system: InitActorSystem, state: S):
            super().__init__(system, state)

        async def _step(self, state: S, *events: t.Union[str, int]) -> S:
            is_paused = state.paused
            handled = [*state.processed]
            stash_total = state.stashed_count
            for ev in events:
                if ev == "pause":
                    is_paused = True
                    continue
                if ev == "resume":
                    is_paused = False
                    self._stash.unstash_all()
                    continue
                if not isinstance(ev, int):
                    continue
                if is_paused:
                    self._stash(ev)
                    stash_total += 1
                else:
                    handled.append(ev)
            return dataclasses.replace(state, paused=is_paused, processed=handled, stashed_count=stash_total)

    system = ActorSystem()
    actor = system.spawn(StashingActor, S())

    actor.tell("pause")
    await system.step()

    # First paused round.
    actor.tell(1, 2)
    await system.step()
    assert actor.state.stashed_count == 2
    assert actor.state.processed == []

    # Second paused round adds to the same stash.
    actor.tell(3, 4)
    await system.step()
    assert actor.state.stashed_count == 4
    assert actor.state.processed == []

    # A single resume replays both rounds in FIFO order.
    actor.tell("resume")
    await system.step()
    assert actor.state.processed == [1, 2, 3, 4]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment