Created
February 25, 2025 02:17
-
-
Save darinkishore/2153f5245e1c60948a99b10c9816d605 to your computer and use it in GitHub Desktop.
Reasoning Model DSPy adapter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import ast | |
| import enum | |
| import inspect | |
| import json | |
| import re | |
| from typing import Any, Literal | |
| from dspy.adapters.base import Adapter | |
| from dspy.adapters.image_utils import try_expand_image_tags | |
| from dspy.adapters.utils import get_annotation_name, serialize_for_json | |
| from dspy.signatures.signature import Signature | |
| from dspy.signatures.utils import get_dspy_field_type | |
| from pydantic import TypeAdapter | |
| from pydantic.fields import FieldInfo | |
def parse_value(raw_str: str, annotation: Any) -> Any:
    """
    Convert the raw text of a <fieldName>...</fieldName> block into `annotation`.

    The text is stripped first. For anything other than `str` or an Enum, the
    text is decoded as JSON, then as a Python literal, and finally left as a
    plain string; the result is validated with pydantic. If the decoded value
    is rejected, the stripped text itself is validated as a fallback, so that
    e.g. ``1`` can still satisfy ``Literal["1", "2"]``.

    Args:
        raw_str: Text captured between the field's opening and closing tags.
        annotation: Target type taken from the signature's output field.

    Returns:
        The stripped text when `annotation` is `str`, the matching member when
        `annotation` is an Enum subclass, otherwise the pydantic-validated value.

    Raises:
        ValueError: If no enum member matches, or if pydantic rejects both the
            decoded value and the raw text (pydantic's ValidationError is a
            ValueError subclass). The error for the decoded value is the one
            that propagates.
    """
    text = raw_str.strip()

    # Plain strings need no decoding or validation.
    if annotation is str:
        return text

    # Enums are matched by member name or value rather than decoded.
    if inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
        return match_enum_value(annotation, text)

    try:
        parsed = json.loads(text)
    except (json.JSONDecodeError, TypeError):
        # Not valid JSON: try a Python literal, else keep the text as-is.
        # literal_eval is documented to raise any of these on malformed input.
        try:
            parsed = ast.literal_eval(text)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            parsed = text

    adapter = TypeAdapter(annotation)
    try:
        return adapter.validate_python(parsed)
    except ValueError as decoded_error:
        # Nothing new to try when decoding left the text unchanged.
        if isinstance(parsed, str) and parsed == text:
            raise
        # Decoding may have changed the type (e.g. "1" -> 1); the raw text can
        # still be what the annotation expects.
        try:
            return adapter.validate_python(text)
        except ValueError:
            raise decoded_error from None
def match_enum_value(enum_cls: type[enum.Enum], raw_value: str) -> enum.Enum:
    """
    Find the member of `enum_cls` that corresponds to `raw_value`.

    Names are checked first, then values. The name lookup goes through
    ``__members__`` so that alias names (a second name bound to an existing
    value) resolve to their canonical member; iterating the enum class would
    skip aliases.

    Args:
        enum_cls: The Enum class to search.
        raw_value: Text to match, compared exactly (case-sensitive).

    Returns:
        The member whose name equals `raw_value`, or failing that the first
        member whose ``str(value)`` equals `raw_value`.

    Raises:
        ValueError: If neither a name nor a value matches.
    """
    # Try by name (including aliases).
    by_name = enum_cls.__members__.get(raw_value)
    if by_name is not None:
        return by_name

    # Try by stringified value.
    for member in enum_cls:
        if str(member.value) == raw_value:
            return member

    raise ValueError(f"Could not match '{raw_value}' to any member of {enum_cls.__name__}.")
def enumerate_fields(fields: dict) -> str:
    """
    Render `fields` as a numbered listing, one field per line.

    Each line reads ``N. `name` (type)`` and is followed by ``: description``
    when the field has a non-empty ``description`` or, failing that, a
    non-empty ``desc`` entry in its ``json_schema_extra``.

    Args:
        fields: Mapping of field name to field info object.

    Returns:
        The newline-joined listing, or an empty string for an empty mapping.
    """
    lines = []
    for position, (name, info) in enumerate(fields.items(), start=1):
        entry = f'{position}. `{name}` ({get_annotation_name(info.annotation)})'
        description = info.description or info.json_schema_extra.get('desc', '')
        if description:
            entry += f': {description}'
        lines.append(entry)
    return '\n'.join(lines).strip()
def field_metadata(field_name: str, field_info: FieldInfo) -> str:
    """
    Describe the type constraint of a field as a short parenthesised hint.

    Args:
        field_name: Name of the field (not used when building the hint).
        field_info: Field info whose ``annotation`` decides the hint.

    Returns:
        An empty string for `str` fields and for types with no known
        constraint; otherwise a hint that starts with a space, such as
        `` (must be True or False)``.
    """
    annotation = field_info.annotation

    if annotation is str:
        return ''
    if annotation is bool:
        return ' (must be True or False)'
    if annotation in (int, float):
        return f' (must be a single {annotation.__name__} value)'

    if inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
        member_names = '; '.join(annotation.__members__)
        return f' (must be one of: {member_names})'

    if hasattr(annotation, '__origin__') and annotation.__origin__ is Literal:
        allowed = '; '.join([str(x) for x in annotation.__args__])
        return f' (must exactly match one of: {allowed})'

    # JSON-formatted fields, or field infos that expose a `schema` attribute.
    if field_info.json_schema_extra.get('format') == 'json' or hasattr(field_info, 'schema'):
        hint = ' (must follow a valid JSON structure)'
        if hasattr(field_info, 'schema') and field_info.schema:
            hint += f' Must follow schema: {field_info.schema}'
        return hint

    return ''
def create_system_message(signature: Signature) -> str:
    """
    Build the system prompt for `signature`.

    The prompt consists of, in order and separated by blank lines:
      1. the task instructions wrapped in <task> tags,
      2. a numbered list of input fields,
      3. a numbered list of output fields,
      4. an <input> and an <output> template showing one tag per field, where
         non-string output fields carry a note about the expected value, and
      5. a closing reminder about the response format.

    Args:
        signature: The signature whose instructions and fields are described.

    Returns:
        The assembled system prompt with surrounding whitespace stripped.
    """

    def _type_key_first(node):
        # Recursively rebuild dicts with keys sorted alphabetically, except
        # that a 'type' key always comes first, for readability.
        if isinstance(node, dict):
            ordered = sorted(node.items(), key=lambda item: (item[0] != 'type', item[0]))
            return {key: _type_key_first(value) for key, value in ordered}
        if isinstance(node, list):
            return [_type_key_first(entry) for entry in node]
        return node

    def _schema_json(type_):
        # JSON schema of `type_` as pretty-printed text.
        schema = _type_key_first(TypeAdapter(type_).json_schema())
        return json.dumps(schema, indent=2, ensure_ascii=False)

    def _placeholder(field_name, field_info):
        # Text placed between a field's tags in the template: the `{name}`
        # placeholder, plus a type note for non-string output fields.
        field_type = field_info.annotation
        slot = f'{{{field_name}}}'

        if get_dspy_field_type(field_info) == 'input' or field_type is str:
            return slot
        if field_type is bool:
            return f'{slot} # note: the value you produce must be True or False'
        if field_type in (int, float):
            return f'{slot} # note: the value you produce must be a single {field_type.__name__} value'
        if inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
            member_names = '; '.join(field_type.__members__)
            return f'{slot} # note: the value you produce must be one of: {member_names}'
        if hasattr(field_type, '__origin__') and field_type.__origin__ is Literal:
            allowed = '; '.join([str(x) for x in field_type.__args__])
            return (
                f'{slot} # note: the value you produce must exactly match '
                f'(no extra characters) one of: {allowed}'
            )

        # Any other type: spell out its full JSON schema.
        return (
            f'\n{slot}\n# Note: The value you produce '
            'must be parseable according to the following JSON schema:\n'
            + _schema_json(field_type)
        )

    def _xml_template(fields, tag_type):
        # One <name>...</name> entry per field, wrapped in <tag_type> tags.
        entries = [
            f'<{name}>\n{_placeholder(name, info)}\n</{name}>' for name, info in fields.items()
        ]
        return '\n'.join([f'<{tag_type}>', *entries, f'</{tag_type}>'])

    # 1) Task description
    task_section = (
        'Your objective is to carefully fulfill the following task:\n'
        f'<task>\n{signature.instructions}\n</task>'
    )

    # 2) Input and output field descriptions
    inputs_section = 'In order to do so, you will be provided with:\n' + enumerate_fields(
        signature.input_fields
    )
    outputs_section = (
        'Based on these, you must produce the following outputs to complete the task:\n'
        + enumerate_fields(signature.output_fields)
    )

    # 3) Structure template
    sections = [
        task_section,
        inputs_section,
        outputs_section,
        'All interactions will be structured as follows:',
        _xml_template(signature.input_fields, 'input'),
        _xml_template(signature.output_fields, 'output'),
    ]

    # 4) Format reminder; the typing reminder is only added when at least one
    # output field is not a plain string.
    reminder_lines = [
        'When responding, make sure to:',
        '1. Perform the task faithfully based on the content within the <task> tags.',
        '2. Produce all required output fields within <output> tags',
        '3. Format each field with its own <field_name>...</field_name> tags',
    ]
    if any(info.annotation is not str for info in signature.output_fields.values()):
        reminder_lines.append('4. Ensure each field value matches its expected type')
    sections.append('\n'.join(reminder_lines))

    return '\n\n'.join(sections).strip()
def format_input_xml(signature: Signature, data: dict[str, Any]) -> str:
    """
    Render the <input> ... </input> block for the signature's input fields.

    Every input field of the signature gets its own <name>...</name> entry.
    Fields absent from `data` are filled with the text
    'Not provided in this example.'.

    Args:
        signature: Signature whose input fields define the tags and their order.
        data: Mapping of field name to value.

    Returns:
        The complete <input> block as a string.
    """
    entries = []
    for name in signature.input_fields:
        # No parsing or schema check here; the value is only formatted.
        # NOTE(review): the value is interpolated with str(); if
        # serialize_for_json returns a dict or list this shows its Python
        # repr rather than JSON text -- confirm that is intended.
        value = serialize_for_json(data.get(name, 'Not provided in this example.'))
        # Each entry ends with a newline, leaving a blank line between fields.
        entries.append(f'<{name}>\n{value}\n</{name}>\n')

    body = '\n'.join(entries)
    return f'<input>\n\n{body}\n</input>'
def format_output_xml(signature: Signature, data: dict[str, Any]) -> str:
    """
    Render the <output> ... </output> block for the signature's output fields.

    Every output field of the signature gets its own <name>...</name> entry.
    Fields absent from `data` are filled with the text
    'Not provided in this example.'.

    Args:
        signature: Signature whose output fields define the tags and their order.
        data: Mapping of field name to value.

    Returns:
        The complete <output> block as a string.
    """
    entries = []
    for name in signature.output_fields:
        # NOTE(review): the value is interpolated with str(); if
        # serialize_for_json returns a dict or list this shows its Python
        # repr rather than JSON text -- confirm that is intended.
        value = serialize_for_json(data.get(name, 'Not provided in this example.'))
        entries.append(f'<{name}>\n{value}\n</{name}>')

    body = '\n'.join(entries)
    return f'<output>\n{body}\n</output>'
def format_turn(
    signature: Signature, values: dict[str, Any], role: str, incomplete: bool = False
) -> dict[str, str]:
    """
    Build one chat message ("turn") whose content is an XML-tagged block.

    Args:
        signature: The DSPy signature whose fields define the block.
        values: Mapping of field name to the value to place in the block.
        role: 'user' renders the <input> block; any other value (expected to
            be 'assistant') renders the <output> block.
        incomplete: If True, the turn is presented as a partial example: a
            notice is placed first and the block is labelled 'Example input:'
            or 'Example output:'. If False, the block is emitted bare, and a
            user turn additionally ends with a request to produce <output>
            when the signature has output fields.

    Returns:
        A dict with two string entries: ``role`` (as passed in) and
        ``content`` (the sections joined by blank lines).
    """
    is_user = role == 'user'
    if is_user:
        block = format_input_xml(signature, values)
    else:
        block = format_output_xml(signature, values)

    if incomplete:
        label = 'Example input' if is_user else 'Example output'
        sections = [
            'This is an example of the task, though some input or output fields are not supplied.',
            f'{label}:\n{block}',
        ]
    else:
        sections = [block]
        # Only a real (non-demo) user message asks the model for the outputs.
        if is_user and signature.output_fields:
            sections.append('Based on the above <input>, please produce the correct <output>.')

    return {'role': role, 'content': '\n\n'.join(sections)}
def parse_output_xml(signature: Signature, completion: str) -> dict[str, Any]:
    """
    Extract every output field of `signature` from `completion`.

    For each output field the first <fieldName>...</fieldName> occurrence is
    located (the content may span multiple lines) and its inner text is passed
    to parse_value together with the field's annotation.

    Args:
        signature: Signature whose output fields must be present.
        completion: The raw completion text to search.

    Returns:
        Mapping of every output field name to its parsed value.

    Raises:
        ValueError: If a field's tag pair is missing from `completion`; errors
            raised by parse_value propagate unchanged.
    """
    results = {}
    for field_name, field_info in signature.output_fields.items():
        # Escape the name so it is always matched literally inside the pattern.
        tag = re.escape(field_name)
        # DOTALL lets the non-greedy group span multiple lines.
        match = re.search(rf'<{tag}>(.*?)</{tag}>', completion, re.DOTALL)
        if match is None:
            raise ValueError(f'Missing <{field_name}>...</{field_name}> block in completion.')
        results[field_name] = parse_value(match.group(1), field_info.annotation)

    # Reaching this point means every field was found and parsed, so the
    # result already covers exactly the signature's output fields.
    return results
# TODO: fix type display so that reddit_mental_health.pipeline.core.models.CriterionMatch becomes CriterionMatch
class SemanticAdapter(Adapter):
    """
    A DSPy Adapter that formats input/output using an XML/tag-based scheme:
        <task>...</task>
        <input>...</input>
        <output>...</output>

    `format` produces a system message, a user/assistant pair per demo, and a
    final user message that asks the model for the <output> block. `parse`
    extracts the output fields from the model's reply.
    """

    def format(self, signature: Signature, demos: list[dict], inputs: dict[str, Any]):
        """
        Build the message list sent to the model.

        The list contains, in order:
          1. a system message with the task instructions and field descriptions,
          2. for every demo a user message (<input>) followed by an assistant
             message (<output>); demos missing at least one output field come
             first and are marked as incomplete examples,
          3. a final user message built from `inputs`.

        Args:
            signature: Signature describing the task and its fields.
            demos: Example mappings of field name to value.
            inputs: Input values for the actual request.

        Returns:
            The messages after being passed through try_expand_image_tags.
        """
        messages = [{'role': 'system', 'content': create_system_message(signature)}]

        # Single pass: a demo is complete when it has every output field.
        incomplete_demos, complete_demos = [], []
        for demo in demos:
            has_all_outputs = all(name in demo for name in signature.output_fields)
            (complete_demos if has_all_outputs else incomplete_demos).append(demo)

        # Incomplete demos are emitted before complete ones.
        for demo_group, is_incomplete in ((incomplete_demos, True), (complete_demos, False)):
            for demo in demo_group:
                messages.append(
                    format_turn(signature, demo, role='user', incomplete=is_incomplete)
                )
                messages.append(
                    format_turn(signature, demo, role='assistant', incomplete=is_incomplete)
                )

        # Finally, the real user input.
        messages.append(format_turn(signature, inputs, role='user', incomplete=False))

        # Expand any image tags contained in the messages.
        return try_expand_image_tags(messages)

    def parse(self, signature: Signature, completion: str) -> dict[str, Any]:
        """
        Extract the signature's output fields from the model's completion.

        The XML-tag parser is tried first. If it raises for any reason (for
        example a missing tag or a value that fails validation), the completion
        is handed to JSONAdapter's parser instead.

        Args:
            signature: Signature whose output fields are expected.
            completion: The model's reply text.

        Returns:
            Mapping of output field name to parsed value.
        """
        try:
            return parse_output_xml(signature, completion)
        except Exception:
            # Deliberate best-effort fallback; imported here, only when needed.
            from dspy.adapters.json_adapter import JSONAdapter

            return JSONAdapter().parse(signature, completion)

    def format_finetune_data(
        self,
        signature: Signature,
        demos: list[dict],
        inputs: dict[str, Any],
        outputs: dict[str, Any],
    ):
        """
        Build one fine-tuning example.

        The messages produced by `format` are followed by an assistant message
        that contains `outputs` rendered as the <output> block.

        Args:
            signature: Signature describing the task and its fields.
            demos: Example mappings of field name to value.
            inputs: Input values for the training example.
            outputs: Ground-truth output values for the training example.

        Returns:
            A dict with a single ``messages`` entry holding the message list.
        """
        messages = self.format(signature, demos, inputs)
        messages.append(format_turn(signature, outputs, role='assistant', incomplete=False))
        return {'messages': messages}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment