Last active
July 12, 2025 10:00
-
-
Save SF-300/feede04ec35ea12fbf4cd56573c47e3c to your computer and use it in GitHub Desktop.
Temporal.io cross-workflow pub-sub implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import contextlib | |
| import functools | |
| import json | |
| import typing as t | |
| import hashlib | |
| from types import MethodType | |
| import pydantic | |
| import pydantic_core | |
| from temporalio import workflow as wf | |
| from temporalio.exceptions import ApplicationError | |
| from temporalio.service import RPCError, RPCStatusCode | |
# Imports inside this block bypass the Temporal workflow sandbox: they are
# treated as pass-through (loaded once from the host environment rather than
# re-imported/validated per workflow run).
with wf.unsafe.imports_passed_through():
    import yaql
    import pydantic
    from yaql.language.expressions import Statement as CompiledQuery
    from pydantic.dataclasses import dataclass

# FIXME: Limit available functions to ensure non-determinism is not introducable via queries.
# Single shared YAQL engine used to compile all subscription queries.
_yaql_engine = yaql.factory.YaqlFactory().create()

# PEP 695 aliases: a YAQL query source string, and a Temporal workflow id.
type Query = str
type WorkflowId = str
@dataclass(frozen=True)
class SenderIdV1:
    """Identity of a concrete workflow run (workflow id + run id)."""

    wf_id: WorkflowId
    run_id: str
@dataclass(frozen=True)
class StateVersionIdV1:
    """Version marker for an emitter's state snapshot (defaults to 0)."""

    number: int
@dataclass(frozen=True)
class SubscribedEventV1:
    """Signal payload: *subscriber* wants state matching *state_query*
    delivered to its signal handler named *handler*.

    ``type_`` is a fixed UUID literal used as the pydantic discriminator for
    SubscriptionManagementEventV1 — do not change it.
    """

    subscriber: SenderIdV1
    handler: str
    state_query: Query
    state_version: StateVersionIdV1 = StateVersionIdV1(number=0)
    type_: t.Literal["8dcd0faf-78e1-41eb-92af-0d9b46331a33"] = (
        "8dcd0faf-78e1-41eb-92af-0d9b46331a33"
    )
@dataclass(frozen=True)
class UnsubscribeEventV1:
    """Signal payload: *subscriber* wants its subscription removed.

    ``type_`` is a fixed UUID discriminator literal — do not change it.
    """

    subscriber: SenderIdV1
    type_: t.Literal["2aaeef9e-b5ea-40df-bb09-4bb973fa61b2"] = (
        "2aaeef9e-b5ea-40df-bb09-4bb973fa61b2"
    )
# Tagged union of subscription-management signal payloads; pydantic picks the
# concrete type by the `type_` UUID discriminator field.
type SubscriptionManagementEventV1 = t.Annotated[
    SubscribedEventV1 | UnsubscribeEventV1,
    pydantic.Discriminator("type_"),
]
class _CompiledQueryPydanticAnnotation:
    """Pydantic annotation making yaql's CompiledQuery (de)serializable.

    Validation: a string is compiled via compile_yaql_query; an existing
    CompiledQuery instance is accepted as-is (Python path only).
    Serialization: the compiled query is written back as its source
    expression string, so the round trip is lossless.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: t.Any,
        _handler: pydantic.GetCoreSchemaHandler,
    ) -> pydantic_core.core_schema.CoreSchema:
        def validate_from_str(value: str) -> CompiledQuery:
            # Delegates to the module-level (cached) compiler.
            result = compile_yaql_query(value)
            return result

        # str -> CompiledQuery: validate as str first, then compile.
        from_str_schema = pydantic_core.core_schema.chain_schema(
            [
                pydantic_core.core_schema.str_schema(),
                pydantic_core.core_schema.no_info_plain_validator_function(
                    validate_from_str
                ),
            ]
        )
        return pydantic_core.core_schema.json_or_python_schema(
            json_schema=from_str_schema,
            python_schema=pydantic_core.core_schema.union_schema(
                [
                    # check if it's an instance first before doing any further work
                    pydantic_core.core_schema.is_instance_schema(CompiledQuery),
                    from_str_schema,
                ]
            ),
            serialization=pydantic_core.core_schema.plain_serializer_function_ser_schema(
                # Serialize back to the original query source string.
                lambda instance: instance.expression
            ),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls,
        _core_schema: pydantic_core.core_schema.CoreSchema,
        handler: pydantic.GetJsonSchemaHandler,
    ) -> pydantic.json_schema.JsonSchemaValue:
        # In JSON Schema the field is simply a string.
        return handler(pydantic_core.core_schema.str_schema())
# NOTE: deliberately NOT frozen — StateUpdatesEmitter mutates
# `last_sent_state_hash` in place after each successful delivery.
@dataclass
class SubscriptionV1:
    """Emitter-side record of one subscriber's query and delivery state."""

    handler: str
    query: t.Annotated[CompiledQuery, _CompiledQueryPydanticAnnotation]
    version: StateVersionIdV1
    last_sent_state_hash: bytes | None = None
# Active subscriptions keyed by subscriber workflow id.
type SubscriptionsV1 = t.Mapping[WorkflowId, SubscriptionV1]
# Any JSON-serializable container (object-like mapping or array-like sequence).
type JsonCollection = t.Mapping | t.Sequence
@dataclass(frozen=True)
class StateUpdateNotificationV1:
    """Signal payload delivered to subscribers when queried state changes.

    ``type_`` is a fixed UUID discriminator literal — do not change it.
    """

    sender: SenderIdV1
    state: JsonCollection
    state_version: StateVersionIdV1
    type_: t.Literal["412e3c98-731f-4818-8a2f-2395f42272d0"] = (
        "412e3c98-731f-4818-8a2f-2395f42272d0"
    )
class _CreateNotification(t.Protocol):
    """Factory signature for building a notification from sender, version
    and (query-filtered) state."""

    def __call__(
        self,
        sender: SenderIdV1,
        state_version: StateVersionIdV1,
        state: JsonCollection,
    ) -> StateUpdateNotificationV1: ...
def _create_default_notification(
    sender: SenderIdV1,
    state_version: StateVersionIdV1,
    state: JsonCollection,
) -> StateUpdateNotificationV1:
    """Default notification factory: wrap the arguments verbatim."""
    notification = StateUpdateNotificationV1(
        sender=sender,
        state=state,
        state_version=state_version,
    )
    return notification
class StateUpdatesEmitter:
    """Emitter-side registry of subscriptions that pushes query-filtered
    state-update notifications to subscriber workflows via signals.

    Intended to run inside a Temporal workflow: all fan-out is sequential
    (no asyncio.gather/TaskGroup) so replay stays deterministic.
    """

    def __init__(
        self,
        get_state: t.Callable[[StateVersionIdV1], t.Awaitable[JsonCollection]],
        subscriptions: SubscriptionsV1 | None = None,
    ):
        # Async provider of the state snapshot for a given version.
        self._get_state = get_state
        self._subscriptions = dict[WorkflowId, SubscriptionV1]()
        if subscriptions:
            self._subscriptions.update(subscriptions)
        # Captured once; identifies this workflow run as the sender.
        self._sender_info = wf.info()

    async def process(self, event: SubscribedEventV1 | UnsubscribeEventV1) -> None:
        """Handle one subscription-management event (subscribe/unsubscribe)."""
        if isinstance(event, SubscribedEventV1):
            # NOTE: If we're re-subscribing, we overwrite the existing subscription
            subscription = self._subscriptions[event.subscriber.wf_id] = SubscriptionV1(
                handler=event.handler,
                query=compile_yaql_query(event.state_query),
                version=event.state_version,
            )
            # Push the current state immediately so the new subscriber does
            # not have to wait for the next fire().
            state = await self._get_state(event.state_version)
            await self._send_notification(
                _create_default_notification,  # TODO: Make parameterizable
                event.subscriber.wf_id,
                subscription,
                state,
            )
        elif isinstance(event, UnsubscribeEventV1):
            self._subscriptions.pop(event.subscriber.wf_id, None)

    async def fire(
        self, create_notification: _CreateNotification = _create_default_notification
    ) -> None:
        """Notify every subscriber whose queried state changed.

        Raises:
            ExceptionGroup: if delivery failed for one or more subscribers
                (remaining subscribers are still attempted).
        """
        # FIX: memoize the *resolved* state per version, not the awaitable.
        # The previous implementation cached the coroutine returned by
        # self._get_state(...) and awaited the cached entry once per
        # subscriber; a coroutine can only be awaited once, so two
        # subscribers sharing a version raised
        # "RuntimeError: cannot reuse already awaited coroutine".
        version_to_state = dict[StateVersionIdV1, JsonCollection]()
        errors: list[Exception] = []
        # NOTE: We have to use a loop, and not something like asyncio.gather or asyncio.TaskGroup,
        # because they don't guarantee the order of coroutines execution, and we can't afford
        # non-determinism in Temporal workflows.
        for subscriber_wid, subscription in tuple(self._subscriptions.items()):
            try:
                if subscription.version not in version_to_state:
                    version_to_state[subscription.version] = await self._get_state(
                        subscription.version
                    )
                await self._send_notification(
                    create_notification,
                    subscriber_wid,
                    subscription,
                    version_to_state[subscription.version],
                )
            except Exception as e:
                # Collect and continue so one bad subscriber doesn't starve
                # the rest; re-raised as a group below.
                errors.append(e)
        if errors:
            raise ExceptionGroup(
                "Failed to send notifications to some subscribers",
                errors,
            )

    async def _send_notification(
        self,
        create_notification: _CreateNotification,
        subscriber_wid: WorkflowId,
        subscription: SubscriptionV1,
        state: JsonCollection,
    ) -> None:
        """Evaluate the subscription's query against *state* and signal the
        subscriber if the (hashed) result changed since the last delivery."""
        try:
            state = subscription.query.evaluate(data=state)  # type: ignore
            if state is None:
                # The query produced nothing for this subscriber.
                return
            state_hash = deterministic_json_hash(state)
            if state_hash == subscription.last_sent_state_hash:
                # No change in state, skip sending notification
                return
            notification = create_notification(
                SenderIdV1(
                    wf_id=self._sender_info.workflow_id,
                    run_id=self._sender_info.run_id,
                ),
                state_version=subscription.version,
                state=state,
            )
            subscriber_handle = wf.get_external_workflow_handle(
                workflow_id=subscriber_wid,
            )
            try:
                await subscriber_handle.signal(
                    subscription.handler,
                    notification,
                )
            except RPCError as e:
                # FIXME: Log
                if e.status != RPCStatusCode.NOT_FOUND:
                    raise
                # The subscriber workflow is not running, remove the subscription
                self._subscriptions.pop(subscriber_wid, None)
            else:
                # Only record delivery after a successful signal, so a
                # failed send is retried on the next fire().
                subscription.last_sent_state_hash = state_hash
        except Exception as e:
            e.add_note(
                f"Failed to send notification to subscriber '{subscriber_wid}' "
                f"with handler '{subscription.handler}' and query '{subscription.query.expression}'"
            )
            raise e
@dataclass(frozen=True)
class Subscribe:
    """Callable that signals an emitter workflow to register the current
    workflow as a subscriber for state matching *query*."""

    _subscription_management_signal_handler: str

    async def __call__(
        self,
        emitter_wid: WorkflowId,
        handler: str | t.Callable[[t.Any], t.Any],
        query: str = "$",
        *,
        emitter_rid: str | None = None,
    ) -> None:
        emitter_handle = wf.get_external_workflow_handle(
            workflow_id=emitter_wid,
            run_id=emitter_rid,
        )
        # Normalize a bound method, then any plain callable, down to its name.
        if isinstance(handler, MethodType):
            handler = handler.__func__
        if callable(handler):
            handler = handler.__name__
        if not isinstance(handler, str):
            raise TypeError(
                "Handler must be a name of a signal handler method or a reference to it"
            )
        info = wf.info()
        event = SubscribedEventV1(
            subscriber=SenderIdV1(
                wf_id=info.workflow_id,
                run_id=info.run_id,
            ),
            handler=handler,
            state_query=ensure_yaql_query(query),
        )
        await emitter_handle.signal(
            self._subscription_management_signal_handler,
            event,
        )
@dataclass(frozen=True)
class Unsubscribe:
    """Callable that signals an emitter workflow to drop the current
    workflow's subscription."""

    _subscription_management_signal_handler: str

    async def __call__(
        self,
        emitter_wid: WorkflowId,
        *,
        emitter_rid: str | None = None,
    ) -> None:
        emitter_handle = wf.get_external_workflow_handle(
            workflow_id=emitter_wid,
            run_id=emitter_rid,
        )
        info = wf.info()
        event = UnsubscribeEventV1(
            subscriber=SenderIdV1(
                wf_id=info.workflow_id,
                run_id=info.run_id,
            ),
        )
        await emitter_handle.signal(
            self._subscription_management_signal_handler,
            event,
        )
@dataclass(frozen=True)
class Subscribed:
    """Async context manager: subscribe on enter, always unsubscribe on exit."""

    _subscribe: Subscribe
    _unsubscribe: Unsubscribe

    @contextlib.asynccontextmanager
    async def __call__(
        self,
        emitter_wid: WorkflowId,
        handler: str | t.Callable[[t.Any], t.Any],
        # FIX: default was "*", which is not a YAQL identity expression and
        # disagreed with Subscribe.__call__'s default of "$" (the full
        # state); with "*" the subscribe call's query compilation would fail.
        query: str = "$",
        *,
        emitter_rid: str | None = None,
    ) -> t.AsyncIterator[None]:
        await self._subscribe(
            emitter_wid,
            handler,
            query,
            emitter_rid=emitter_rid,
        )
        try:
            yield
        finally:
            # Unsubscribe even if the body raised or was cancelled.
            await self._unsubscribe(
                emitter_wid,
                emitter_rid=emitter_rid,
            )
class SubscriptionFuncs(t.NamedTuple):
    """The three subscription entry points produced by
    create_subscription_funcs(), bound to one management signal handler."""

    subscribe: Subscribe
    unsubscribe: Unsubscribe
    subscribed: Subscribed
class SubscriptionSignalsHandlerMethodDecl(t.Protocol):
    """Structural type of an unbound workflow method that handles
    subscription-management signal events (only its ``__name__`` is used)."""

    __name__: str

    def __call__(
        _,  # type: ignore
        self,
        event: SubscriptionManagementEventV1,
    ) -> t.Any: ...
def create_subscription_funcs(
    handler: SubscriptionSignalsHandlerMethodDecl | str,
) -> SubscriptionFuncs:
    """Build the subscribe/unsubscribe/subscribed trio bound to *handler*,
    given either the handler method itself or its name."""
    handler_name = handler.__name__ if callable(handler) else handler
    sub = Subscribe(_subscription_management_signal_handler=handler_name)
    unsub = Unsubscribe(_subscription_management_signal_handler=handler_name)
    ctx = Subscribed(_subscribe=sub, _unsubscribe=unsub)
    return SubscriptionFuncs(subscribe=sub, unsubscribe=unsub, subscribed=ctx)
def ensure_yaql_query(query: str) -> Query:
    """Validate that *query* compiles as YAQL; return the source unchanged.

    The compiled form is discarded here (it stays in compile_yaql_query's
    cache); only the validated source string is passed around.
    """
    compile_yaql_query(query)  # raises if the query is invalid
    return t.cast(Query, query)
# functools.cache == lru_cache(maxsize=None): compile each distinct query once.
@functools.cache
def compile_yaql_query(query: str) -> CompiledQuery:
    """Compile *query* with the shared YAQL engine, memoizing the result."""
    return _yaql_engine(query)
def deterministic_json_hash(obj) -> bytes:
    """
    Computes a deterministic SHA256 hash of a JSON-serializable object.
    Ensures consistent hashing by sorting dictionary keys and using compact
    separators. This is crucial for Temporal's determinism requirements.

    Raises:
        ApplicationError: if *obj* is not JSON-serializable (chained from
            the underlying TypeError).
    """
    try:
        json_string = json.dumps(
            obj,
            sort_keys=True,
            separators=(",", ":"),
            ensure_ascii=False,
        )
    except TypeError as e:
        # FIX: chain the cause (`from e`) so the original TypeError's
        # traceback survives instead of being replaced by the new error.
        raise ApplicationError(
            f"Object not JSON serializable for deterministic hashing: {e}. "
            "Ensure state data is JSON-serializable."
        ) from e
    return hashlib.sha256(json_string.encode("utf-8")).digest()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment