# Created December 10, 2025 13:56
# Save igalshilman/0b5bab17b38307706614c5b6c5280dab to your computer and use it in GitHub Desktop.
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that reveals hidden
# Unicode characters. Learn more about bidirectional Unicode characters.
| R = typing.TypeVar('R', bound=typing.Callable[..., typing.Awaitable[typing.Any]]) | |
| P = typing.ParamSpec('P' | |
def handler(func):
    """Tag *func* as a handler and return the very same function object.

    The tag is the attribute ``_is_handler = True``; ``VirtualObjectMetaClass``
    looks for it when deciding which methods to register.
    """
    setattr(func, "_is_handler", True)
    return func
class Box:
    """Mutable holder with a single ``value`` slot.

    ``value`` is unset until assigned; reading it earlier raises AttributeError.
    """

    __slots__ = ("value",)
def make_handler_wrapper(the_self: Box, func):
    """Adapt method *func* into a plain coroutine function for Restate.

    Restate calls the returned function with the context as the first
    positional argument followed by the handler input. Restate does not know
    *func* is a method, so the wrapper drops that first positional argument.
    It substitutes the instance currently stored in ``the_self.value`` as
    ``self`` and awaits ``func(instance, *remaining_args, **kwargs)``.

    ``functools.wraps`` copies *func*'s metadata onto the wrapper.
    """

    @wraps(func)
    async def handler_wrapper(*args, **kwargs):
        # Fetch the registered instance first (raises AttributeError if none yet).
        instance = the_self.value
        # args[0] is the context supplied by Restate; the method reads the
        # context some other way, so it is not forwarded.
        return await func(instance, *args[1:], **kwargs)

    return handler_wrapper
class VirtualObjectMetaClass(type):
    """Metaclass that exposes a plain class as a ``restate.VirtualObject``.

    For every class it creates, the metaclass:

    * builds ``restate.VirtualObject(<name>)`` and stores it as the class
      attribute ``vo``. ``<name>`` is the ``__name__`` assigned in the class
      body if present, otherwise the class name.
    * wraps every method tagged by ``@handler`` via ``make_handler_wrapper``
      and registers it with ``vo.handler()``.
    * installs an ``__init__`` that records the new instance in a per-class
      ``Box``, so the wrapped handlers can find ``self``.

    NOTE(review): the ``Box`` holds a single instance per class. Constructing a
    second instance redirects every handler of that class to the newest one.
    """

    def __new__(cls, name, bases, attrs, **kwargs):
        vo_name = attrs.get('__name__', name)
        vo = restate.VirtualObject(vo_name)
        attrs['vo'] = vo
        the_self = Box()

        # Wrap all methods marked with @handler. Iterate over a snapshot because
        # entries of ``attrs`` are rebound inside the loop.
        for key, value in list(attrs.items()):
            if callable(value) and getattr(value, '_is_handler', False):
                wrapped = make_handler_wrapper(the_self, value)
                attrs[key] = vo.handler()(wrapped)

        # Constructor that publishes the instance into the Box before running
        # the user's initialisation, so handlers can retrieve it later.
        original_init = attrs.get('__init__')

        def __init__(self, *args, **kwargs):
            the_self.value = self
            if original_init is not None:
                original_init(self, *args, **kwargs)
            else:
                # The class body defines no __init__: continue along the MRO
                # (a base class's __init__, or object.__init__) rather than
                # silently skipping base-class initialisation.
                super(new_cls, self).__init__(*args, **kwargs)

        attrs['__init__'] = __init__
        # Class keyword arguments are forwarded so __init_subclass__ still works.
        new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
        return new_cls
class VirtualObjectBase(metaclass=VirtualObjectMetaClass):
    """Base class for Restate virtual objects written as ordinary classes.

    Subclasses tag methods with ``@handler``; ``VirtualObjectMetaClass`` then
    registers them on the class's own ``restate.VirtualObject`` (``cls.vo``).
    """

    @classmethod
    def underlying_vo(cls):
        """Return the ``restate.VirtualObject`` stored on the class as ``vo``."""
        # The metaclass stores it as a plain class attribute. Class objects are
        # not subscriptable, so the former ``cls['vo']`` always raised TypeError.
        return typing.cast(restate.VirtualObject, cls.vo)

    @property
    def context(self):
        """Return ``current_context()`` typed as a ``restate.ObjectContext``.

        Presumably this is the context of the handler invocation in progress;
        confirm against ``current_context``.
        """
        return typing.cast(restate.ObjectContext, current_context())
class ClaimAgent(VirtualObjectBase):
    """Virtual object that forwards a chat message to ``Runner.run_async``."""

    __name__ = "ClaimAgent"

    @handler
    async def run(self, message: ChatMessage) -> str | None:
        """Send ``message.message`` as a user turn and return the reply text.

        The object key becomes the runner's ``user_id`` and
        ``message.session_id`` its ``session_id``. The handler returns the
        text of the first part of the last final-response event whose first
        part has non-empty text. If no such event is produced, it returns
        ``None``.
        """
        runner = Runner(app=app, session_service=session_service)
        user_id = self.context.key()
        prompt = Content(role="user", parts=[Part.from_text(text=message.message)])

        reply = None
        async for event in runner.run_async(
            user_id=user_id,
            session_id=message.session_id,
            new_message=prompt,
        ):
            if not event.is_final_response():
                continue
            parts = event.content.parts if event.content else None
            if parts and parts[0].text:
                # Later final responses overwrite earlier ones.
                reply = parts[0].text
        return reply
class OneMore(VirtualObjectBase):
    """Minimal virtual object that echoes its input with a ``Pong:`` prefix."""

    __name__ = "OneMore"

    @handler
    async def ping(self, message: str) -> str:
        """Return *message* prefixed with ``"Pong: "``."""
        reply = f"Pong: {message}"
        return reply
# Instantiate each virtual object once. Construction runs the __init__ that the
# metaclass installed, which stores the instance in that class's Box so its
# handlers can reach it. ClaimAgent is built before OneMore.
agent_service, one_more_service = ClaimAgent(), OneMore()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.