Skip to content

Instantly share code, notes, and snippets.

@SF-300
Last active July 12, 2025 10:00
Show Gist options
  • Select an option

  • Save SF-300/feede04ec35ea12fbf4cd56573c47e3c to your computer and use it in GitHub Desktop.

Select an option

Save SF-300/feede04ec35ea12fbf4cd56573c47e3c to your computer and use it in GitHub Desktop.
Temporal.io cross-workflow pub-sub implementation
import contextlib
import functools
import json
import typing as t
import hashlib
from types import MethodType
import pydantic
import pydantic_core
from temporalio import workflow as wf
from temporalio.exceptions import ApplicationError
from temporalio.service import RPCError, RPCStatusCode
# These imports bypass the Temporal workflow sandbox's import restrictions;
# they are shared with the outside interpreter instead of re-imported per run.
with wf.unsafe.imports_passed_through():
    import yaql
    import pydantic
    from yaql.language.expressions import Statement as CompiledQuery
    from pydantic.dataclasses import dataclass

# FIXME: Limit available functions to ensure non-determinism is not introducable via queries.
# Single module-level YAQL engine used for compiling all subscription queries.
_yaql_engine = yaql.factory.YaqlFactory().create()

# Readability aliases: a Query is YAQL source text, a WorkflowId is a plain string.
type Query = str
type WorkflowId = str
@dataclass(frozen=True)
class SenderIdV1:
    """Identity of a workflow execution taking part in the pub-sub exchange."""

    wf_id: WorkflowId  # workflow id
    run_id: str  # run id of that workflow execution
@dataclass(frozen=True)
class StateVersionIdV1:
    """Version marker for the emitter's published state (semantics owned by the emitter)."""

    number: int
@dataclass(frozen=True)
class SubscribedEventV1:
    """Signal payload: a workflow asks to (re)subscribe to an emitter's state updates."""

    subscriber: SenderIdV1  # identity of the subscribing workflow
    handler: str  # name of the subscriber's signal handler that receives notifications
    state_query: Query  # YAQL query projecting the emitter's state before delivery
    state_version: StateVersionIdV1 = StateVersionIdV1(number=0)
    # Pydantic discriminator value distinguishing this event within the union.
    type_: t.Literal["8dcd0faf-78e1-41eb-92af-0d9b46331a33"] = (
        "8dcd0faf-78e1-41eb-92af-0d9b46331a33"
    )
@dataclass(frozen=True)
class UnsubscribeEventV1:
    """Signal payload: a workflow asks to drop its subscription."""

    subscriber: SenderIdV1  # identity of the unsubscribing workflow
    # Pydantic discriminator value distinguishing this event within the union.
    type_: t.Literal["2aaeef9e-b5ea-40df-bb09-4bb973fa61b2"] = (
        "2aaeef9e-b5ea-40df-bb09-4bb973fa61b2"
    )
# Discriminated union of all subscription-management signal payloads,
# dispatched by pydantic on the `type_` field.
type SubscriptionManagementEventV1 = t.Annotated[
    SubscribedEventV1 | UnsubscribeEventV1,
    pydantic.Discriminator("type_"),
]
class _CompiledQueryPydanticAnnotation:
    """Pydantic annotation adapting YAQL `CompiledQuery` objects.

    Validation: a JSON value must be a string and is compiled into a
    `CompiledQuery`; a Python value may already be a `CompiledQuery`
    instance (accepted as-is) or a string (compiled).
    Serialization: emits the query's source text (`.expression`).
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: t.Any,
        _handler: pydantic.GetCoreSchemaHandler,
    ) -> pydantic_core.core_schema.CoreSchema:
        def validate_from_str(value: str) -> CompiledQuery:
            # Compilation is memoized in compile_yaql_query.
            result = compile_yaql_query(value)
            return result

        # str -> CompiledQuery pipeline used for the JSON side (and as the
        # fallback on the Python side).
        from_str_schema = pydantic_core.core_schema.chain_schema(
            [
                pydantic_core.core_schema.str_schema(),
                pydantic_core.core_schema.no_info_plain_validator_function(
                    validate_from_str
                ),
            ]
        )

        return pydantic_core.core_schema.json_or_python_schema(
            json_schema=from_str_schema,
            python_schema=pydantic_core.core_schema.union_schema(
                [
                    # check if it's an instance first before doing any further work
                    pydantic_core.core_schema.is_instance_schema(CompiledQuery),
                    from_str_schema,
                ]
            ),
            serialization=pydantic_core.core_schema.plain_serializer_function_ser_schema(
                lambda instance: instance.expression
            ),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls,
        _core_schema: pydantic_core.core_schema.CoreSchema,
        handler: pydantic.GetJsonSchemaHandler,
    ) -> pydantic.json_schema.JsonSchemaValue:
        # In JSON Schema a compiled query is represented as a plain string.
        return handler(pydantic_core.core_schema.str_schema())
@dataclass
class SubscriptionV1:
    """Emitter-side record of a single subscriber (mutable: tracks the last sent hash)."""

    handler: str  # subscriber's signal handler name notifications are delivered to
    query: t.Annotated[CompiledQuery, _CompiledQueryPydanticAnnotation]  # state projection
    version: StateVersionIdV1  # state version this subscription tracks
    last_sent_state_hash: bytes | None = None  # used to skip unchanged notifications
# Active subscriptions, keyed by the subscriber's workflow id.
type SubscriptionsV1 = t.Mapping[WorkflowId, SubscriptionV1]
# A JSON-shaped container (mapping or sequence) that can be hashed and signalled.
type JsonCollection = t.Mapping | t.Sequence
@dataclass(frozen=True)
class StateUpdateNotificationV1:
    """Signal payload delivered to a subscriber with a (query-projected) state snapshot."""

    sender: SenderIdV1  # identity of the emitting workflow execution
    state: JsonCollection  # state after applying the subscriber's query
    state_version: StateVersionIdV1  # version the snapshot corresponds to
    # Pydantic discriminator value for this notification type.
    type_: t.Literal["412e3c98-731f-4818-8a2f-2395f42272d0"] = (
        "412e3c98-731f-4818-8a2f-2395f42272d0"
    )
class _CreateNotification(t.Protocol):
    """Factory signature for building the notification sent to a subscriber."""

    def __call__(
        self,
        sender: SenderIdV1,
        state_version: StateVersionIdV1,
        state: JsonCollection,
    ) -> StateUpdateNotificationV1: ...
def _create_default_notification(
    sender: SenderIdV1,
    state_version: StateVersionIdV1,
    state: JsonCollection,
) -> StateUpdateNotificationV1:
    """Build a plain v1 notification carrying the given state snapshot."""
    notification = StateUpdateNotificationV1(
        state=state,
        state_version=state_version,
        sender=sender,
    )
    return notification
class StateUpdatesEmitter:
    """Publisher side of the cross-workflow pub-sub.

    Keeps a registry of subscribers and, on `fire`, signals each subscriber
    whose query-projected state changed since the last successful delivery.
    """

    def __init__(
        self,
        get_state: t.Callable[[StateVersionIdV1], t.Awaitable[JsonCollection]],
        subscriptions: SubscriptionsV1 | None = None,
    ):
        # Async callback resolving a state snapshot for a given version.
        self._get_state = get_state
        self._subscriptions = dict[WorkflowId, SubscriptionV1]()
        if subscriptions:
            self._subscriptions.update(subscriptions)
        # Captured once; identifies this workflow execution as the sender.
        self._sender_info = wf.info()

    async def process(self, event: SubscribedEventV1 | UnsubscribeEventV1) -> None:
        """Handle a subscription-management event.

        A subscribe event registers (or overwrites) the subscription and
        immediately pushes the current state to the new subscriber; an
        unsubscribe event drops the subscription if present.
        """
        if isinstance(event, SubscribedEventV1):
            # NOTE: If we're re-subscribing, we overwrite the existing subscription
            subscription = self._subscriptions[event.subscriber.wf_id] = SubscriptionV1(
                handler=event.handler,
                query=compile_yaql_query(event.state_query),
                version=event.state_version,
            )
            state = await self._get_state(event.state_version)
            await self._send_notification(
                _create_default_notification,  # TODO: Make parameterizable
                event.subscriber.wf_id,
                subscription,
                state,
            )
        elif isinstance(event, UnsubscribeEventV1):
            self._subscriptions.pop(event.subscriber.wf_id, None)

    async def fire(
        self, create_notification: _CreateNotification = _create_default_notification
    ) -> None:
        """Notify every subscriber whose projected state changed.

        Raises:
            ExceptionGroup: if delivery failed for one or more subscribers;
                all subscribers are attempted before raising.
        """
        # BUG FIX: the previous implementation cached the *coroutine* returned
        # by self._get_state and awaited it once per subscriber.  A coroutine
        # object can only be awaited once, so two subscriptions sharing a
        # version crashed with "cannot reuse already awaited coroutine".
        # Cache the awaited *result* per version instead.
        version_to_state = dict[StateVersionIdV1, JsonCollection]()
        errors: list[Exception] = []
        # NOTE: We have to use a loop, and not something like asyncio.gather or asyncio.TaskGroup,
        # because they don't guarantee the order of coroutines execution, and we can't afford
        # non-determinism in Temporal workflows.
        for subscriber_wid, subscription in tuple(self._subscriptions.items()):
            try:
                if subscription.version not in version_to_state:
                    version_to_state[subscription.version] = await self._get_state(
                        subscription.version
                    )
                await self._send_notification(
                    create_notification,
                    subscriber_wid,
                    subscription,
                    version_to_state[subscription.version],
                )
            except Exception as e:
                errors.append(e)
        if errors:
            raise ExceptionGroup(
                "Failed to send notifications to some subscribers",
                errors,
            )

    async def _send_notification(
        self,
        create_notification: _CreateNotification,
        subscriber_wid: WorkflowId,
        subscription: SubscriptionV1,
        state: JsonCollection,
    ) -> None:
        """Project `state` through the subscription query and signal the subscriber.

        Skips delivery when the query yields None, or when the projected
        state's hash equals the last successfully sent one.  A NOT_FOUND RPC
        error (subscriber no longer running) removes the subscription instead
        of failing.
        """
        try:
            state = subscription.query.evaluate(data=state)  # type: ignore
            if state is None:
                # Query filtered the update out entirely; nothing to send.
                return
            state_hash = deterministic_json_hash(state)
            if state_hash == subscription.last_sent_state_hash:
                # No change in state, skip sending notification
                return
            notification = create_notification(
                SenderIdV1(
                    wf_id=self._sender_info.workflow_id,
                    run_id=self._sender_info.run_id,
                ),
                state_version=subscription.version,
                state=state,
            )
            subscriber_handle = wf.get_external_workflow_handle(
                workflow_id=subscriber_wid,
            )
            try:
                await subscriber_handle.signal(
                    subscription.handler,
                    notification,
                )
            except RPCError as e:
                # FIXME: Log
                if e.status != RPCStatusCode.NOT_FOUND:
                    raise
                # The subscriber workflow is not running, remove the subscription
                self._subscriptions.pop(subscriber_wid, None)
            else:
                # Record the hash only after a confirmed successful send.
                subscription.last_sent_state_hash = state_hash
        except Exception as e:
            e.add_note(
                f"Failed to send notification to subscriber '{subscriber_wid}' "
                f"with handler '{subscription.handler}' and query '{subscription.query.expression}'"
            )
            raise e
@dataclass(frozen=True)
class Subscribe:
    """Callable that signals an emitter workflow to register this workflow as a subscriber."""

    # Name of the emitter's subscription-management signal handler.
    _subscription_management_signal_handler: str

    async def __call__(
        self,
        emitter_wid: WorkflowId,
        handler: str | t.Callable[[t.Any], t.Any],
        query: str = "$",
        *,
        emitter_rid: str | None = None,
    ) -> None:
        emitter_handle = wf.get_external_workflow_handle(
            workflow_id=emitter_wid,
            run_id=emitter_rid,
        )

        # Normalize the handler reference down to a plain signal-handler name:
        # bound method -> underlying function -> its __name__.
        if isinstance(handler, MethodType):
            handler = handler.__func__
        if callable(handler):
            handler = handler.__name__
        if not isinstance(handler, str):
            raise TypeError(
                "Handler must be a name of a signal handler method or a reference to it"
            )

        me = wf.info()
        event = SubscribedEventV1(
            subscriber=SenderIdV1(
                wf_id=me.workflow_id,
                run_id=me.run_id,
            ),
            handler=handler,
            state_query=ensure_yaql_query(query),
        )
        await emitter_handle.signal(
            self._subscription_management_signal_handler,
            event,
        )
@dataclass(frozen=True)
class Unsubscribe:
    """Callable that signals an emitter workflow to drop this workflow's subscription."""

    # Name of the emitter's subscription-management signal handler.
    _subscription_management_signal_handler: str

    async def __call__(
        self,
        emitter_wid: WorkflowId,
        *,
        emitter_rid: str | None = None,
    ) -> None:
        emitter_handle = wf.get_external_workflow_handle(
            emitter_wid,
            run_id=emitter_rid,
        )
        me = wf.info()
        event = UnsubscribeEventV1(
            subscriber=SenderIdV1(
                wf_id=me.workflow_id,
                run_id=me.run_id,
            ),
        )
        await emitter_handle.signal(
            self._subscription_management_signal_handler,
            event,
        )
@dataclass(frozen=True)
class Subscribed:
    """Async context manager combining Subscribe and Unsubscribe.

    Subscribes to the emitter on entry and always unsubscribes on exit,
    even if the context body raises.
    """

    _subscribe: Subscribe
    _unsubscribe: Unsubscribe

    @contextlib.asynccontextmanager
    async def __call__(
        self,
        emitter_wid: WorkflowId,
        handler: str | t.Callable[[t.Any], t.Any],
        # BUG FIX: the default was "*", inconsistent with Subscribe.__call__'s
        # "$" (the YAQL identity expression) default.
        query: str = "$",
        *,
        emitter_rid: str | None = None,
    ) -> t.AsyncIterator[None]:
        await self._subscribe(
            emitter_wid,
            handler,
            query,
            emitter_rid=emitter_rid,
        )
        try:
            yield
        finally:
            await self._unsubscribe(
                emitter_wid,
                emitter_rid=emitter_rid,
            )
class SubscriptionFuncs(t.NamedTuple):
    """Bundle of the three subscriber-side helpers bound to one signal handler."""

    subscribe: Subscribe
    unsubscribe: Unsubscribe
    subscribed: Subscribed
class SubscriptionSignalsHandlerMethodDecl(t.Protocol):
    """Structural type of an unbound workflow method handling subscription events.

    The first parameter `_` stands in for the protocol's own instance so that
    `self` can describe the workflow instance of the declared method.
    """

    __name__: str

    def __call__(
        _,  # type: ignore
        self,
        event: SubscriptionManagementEventV1,
    ) -> t.Any: ...
def create_subscription_funcs(
    handler: SubscriptionSignalsHandlerMethodDecl | str,
) -> SubscriptionFuncs:
    """Create subscribe/unsubscribe/subscribed helpers bound to one signal handler.

    Accepts either the handler method itself or its name.
    """
    handler_name = handler.__name__ if callable(handler) else handler
    subscribe = Subscribe(_subscription_management_signal_handler=handler_name)
    unsubscribe = Unsubscribe(_subscription_management_signal_handler=handler_name)
    subscribed = Subscribed(
        _subscribe=subscribe,
        _unsubscribe=unsubscribe,
    )
    return SubscriptionFuncs(
        subscribe=subscribe,
        unsubscribe=unsubscribe,
        subscribed=subscribed,
    )
def ensure_yaql_query(query: str) -> Query:
    """Validate that *query* compiles as YAQL and return it unchanged.

    The compiled form is discarded here; compilation only serves as an
    eager syntax check (it is memoized, so later use pays nothing extra).
    """
    compile_yaql_query(query)
    return t.cast(Query, query)
@functools.lru_cache(maxsize=None)
def compile_yaql_query(query: str) -> CompiledQuery:
    """Compile a YAQL source string with the module engine; results are memoized."""
    return _yaql_engine(query)
def deterministic_json_hash(obj: t.Any) -> bytes:
    """
    Computes a deterministic SHA256 hash of a JSON-serializable object.

    Ensures consistent hashing by sorting dictionary keys and using compact separators.
    This is crucial for Temporal's determinism requirements.

    Args:
        obj: Any JSON-serializable value (mapping, sequence, or scalar).

    Returns:
        The 32-byte SHA-256 digest of the canonical JSON encoding.

    Raises:
        ApplicationError: if *obj* is not JSON-serializable.
    """
    try:
        json_string = json.dumps(
            obj,
            sort_keys=True,
            separators=(",", ":"),
            ensure_ascii=False,
        )
    except TypeError as e:
        # BUG FIX: chain the original TypeError (`from e`) so the underlying
        # cause is preserved instead of only appearing interpolated in text.
        raise ApplicationError(
            f"Object not JSON serializable for deterministic hashing: {e}. "
            "Ensure state data is JSON-serializable."
        ) from e
    return hashlib.sha256(json_string.encode("utf-8")).digest()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment