Skip to content

Instantly share code, notes, and snippets.

@darinkishore
Created February 25, 2025 02:17
Show Gist options
  • Select an option

  • Save darinkishore/2153f5245e1c60948a99b10c9816d605 to your computer and use it in GitHub Desktop.

Select an option

Save darinkishore/2153f5245e1c60948a99b10c9816d605 to your computer and use it in GitHub Desktop.
Reasoning Model DSPy adapter
import ast
import enum
import inspect
import json
import re
from typing import Any, Literal
from dspy.adapters.base import Adapter
from dspy.adapters.image_utils import try_expand_image_tags
from dspy.adapters.utils import get_annotation_name, serialize_for_json
from dspy.signatures.signature import Signature
from dspy.signatures.utils import get_dspy_field_type
from pydantic import TypeAdapter
from pydantic.fields import FieldInfo
def parse_value(raw_str: str, annotation: Any) -> Any:
    """
    Convert the raw text of one ``<fieldName>...</fieldName>`` block into the
    Python type named by the signature's ``annotation``.

    Args:
        raw_str: Text captured between the opening and closing field tags.
        annotation: The output field's annotation (e.g. ``str``, ``int``,
            ``list[str]``, an ``Enum`` subclass, a pydantic model).

    Returns:
        The stripped text for ``str`` fields, the matching member for ``Enum``
        fields, and otherwise the value validated by pydantic's ``TypeAdapter``.

    Raises:
        ValueError: If the text cannot be coerced to ``annotation`` (pydantic's
            ``ValidationError`` is a ``ValueError`` subclass).
    """
    text = raw_str.strip()

    # Plain strings need no parsing at all.
    if annotation is str:
        return text

    # Enums are matched by member name or by the string form of the member value.
    if inspect.isclass(annotation) and issubclass(annotation, enum.Enum):
        return match_enum_value(annotation, text)

    # Try JSON first, then Python literal syntax (handles True/None/single quotes),
    # and finally fall back to the raw text.
    try:
        parsed = json.loads(text)
    except (ValueError, RecursionError):
        try:
            parsed = ast.literal_eval(text)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            # literal_eval raises TypeError for e.g. unhashable dict keys ("{[1]: 2}")
            # and MemoryError/RecursionError for pathologically nested input.
            parsed = text

    adapter = TypeAdapter(annotation)
    try:
        return adapter.validate_python(parsed)
    except ValueError:
        if parsed is text:
            raise
        # Literal parsing can pick the wrong type for string-like fields
        # (e.g. "1" -> 1 for Literal["1", "2"], or "123" -> 123 for str | None);
        # retry against the untouched text before giving up.
        return adapter.validate_python(text)
def match_enum_value(enum_cls: type[enum.Enum], raw_value: str) -> enum.Enum:
    """
    Find the member of ``enum_cls`` that ``raw_value`` refers to.

    The text is first looked up as a member name (alias names included), then
    compared against ``str(member.value)`` for each canonical member.

    Args:
        enum_cls: The Enum class to search.
        raw_value: The (already stripped) text produced by the model.

    Returns:
        The matching enum member.

    Raises:
        ValueError: If no member name or value matches ``raw_value``.
    """
    # __members__ also maps alias names (which iterating the class skips) to
    # their canonical member, and gives an O(1) lookup.
    by_name = enum_cls.__members__.get(raw_value)
    if by_name is not None:
        return by_name
    for member in enum_cls:
        if str(member.value) == raw_value:
            return member
    raise ValueError(f"Could not match '{raw_value}' to any member of {enum_cls.__name__}.")
def enumerate_fields(fields: dict) -> str:
    """
    Render ``fields`` as a 1-based numbered list, similar to ChatAdapter's
    enumerate_fields function. Each line reads ``N. `name` (type): description``;
    the description part is omitted when the field has none.

    Args:
        fields: Mapping of field name to pydantic ``FieldInfo``.

    Returns:
        The lines joined by newlines (an empty string for an empty mapping).
    """
    lines = []
    for position, (name, info) in enumerate(fields.items(), start=1):
        line = f'{position}. `{name}` ({get_annotation_name(info.annotation)})'
        # json_schema_extra is None (or a callable) on plain pydantic fields; only a
        # dict extra can carry a DSPy-style 'desc' entry.
        extra = info.json_schema_extra if isinstance(info.json_schema_extra, dict) else {}
        desc = info.description or extra.get('desc', '')
        if desc:
            line += f': {desc}'
        lines.append(line)
    return '\n'.join(lines).strip()
def field_metadata(field_name: str, field_info: FieldInfo) -> str:
    """
    Describe the type constraint a field's value must satisfy, similar to
    ChatAdapter's field_metadata function.

    Args:
        field_name: Name of the field (not used in the text; kept for
            signature compatibility).
        field_info: The field's pydantic ``FieldInfo``.

    Returns:
        A short parenthesised constraint such as `` (must be True or False)``,
        or an empty string when there is nothing to add (e.g. ``str`` fields).
    """
    field_type = field_info.annotation
    if field_type is str:
        return ''
    if field_type is bool:
        return ' (must be True or False)'
    if field_type in (int, float):
        return f' (must be a single {field_type.__name__} value)'
    if inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
        return f' (must be one of: {"; ".join(field_type.__members__)})'
    if getattr(field_type, '__origin__', None) is Literal:
        return f' (must exactly match one of: {"; ".join([str(x) for x in field_type.__args__])})'
    # json_schema_extra is None (or a callable) on plain pydantic fields.
    extra = field_info.json_schema_extra if isinstance(field_info.json_schema_extra, dict) else {}
    if extra.get('format') == 'json' or hasattr(field_info, 'schema'):
        constraint = ' (must follow a valid JSON structure)'
        schema = getattr(field_info, 'schema', None)
        if schema:
            constraint += f' Must follow schema: {schema}'
        return constraint
    return ''
def _move_type_to_front(node: Any) -> Any:
    """Recursively reorder JSON-schema dicts so the 'type' key comes first.

    Remaining keys follow in alphabetical order (the same ordering ChatAdapter
    uses). Lists are walked element by element; scalars are returned unchanged.
    """
    if isinstance(node, dict):
        return {
            key: _move_type_to_front(value)
            for key, value in sorted(node.items(), key=lambda item: (item[0] != 'type', item[0]))
        }
    if isinstance(node, list):
        return [_move_type_to_front(item) for item in node]
    return node


def _prepare_schema(type_: Any) -> str:
    """Return pydantic's JSON schema for ``type_``, reordered and pretty-printed."""
    schema = _move_type_to_front(TypeAdapter(type_).json_schema())
    return json.dumps(schema, indent=2, ensure_ascii=False)


def _template_placeholder(field_name: str, field_info: FieldInfo) -> str:
    """Return the ``{field_name}`` placeholder shown inside a field's template
    tags. Output fields that are not plain ``str`` get a note describing the
    value the model must produce."""
    field_type = field_info.annotation
    if get_dspy_field_type(field_info) == 'input' or field_type is str:
        return f'{{{field_name}}}'
    if field_type is bool:
        return f'{{{field_name}}} # note: the value you produce must be True or False'
    if field_type in (int, float):
        return f'{{{field_name}}} # note: the value you produce must be a single {field_type.__name__} value'
    if inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
        return f'{{{field_name}}} # note: the value you produce must be one of: {"; ".join(field_type.__members__)}'
    if getattr(field_type, '__origin__', None) is Literal:
        return f'{{{field_name}}} # note: the value you produce must exactly match (no extra characters) one of: {"; ".join([str(x) for x in field_type.__args__])}'
    # Complex types: include the full JSON schema.
    try:
        schema_json = _prepare_schema(field_type)
    except TypeError:
        # pydantic raises PydanticSchemaGenerationError / PydanticInvalidForJsonSchema
        # (TypeError subclasses) for types it cannot describe; fall back to the
        # type's name rather than failing to build the whole prompt.
        return f'{{{field_name}}} # note: the value you produce must be a valid {get_annotation_name(field_type)}'
    return (
        f'\n{{{field_name}}}\n# Note: The value you produce '
        f'must be parseable according to the following JSON schema:\n{schema_json}'
    )


def _format_xml_template(fields: dict, tag_type: str) -> str:
    """Render an ``<tag_type>`` block containing one placeholder element per field."""
    lines = [f'<{tag_type}>']
    for name, field_info in fields.items():
        lines.append(f'<{name}>\n{_template_placeholder(name, field_info)}\n</{name}>')
    lines.append(f'</{tag_type}>')
    return '\n'.join(lines)


def create_system_message(signature: Signature) -> str:
    """
    Build the system prompt for ``signature``, mirroring the structure of
    ChatAdapter's prepare_instructions function.

    The prompt contains, separated by blank lines:
      1. the task instructions wrapped in ``<task>`` tags,
      2. numbered descriptions of the input fields and of the output fields,
      3. ``<input>`` / ``<output>`` templates with one placeholder per field
         (non-string outputs carry a type note or their JSON schema),
      4. a formatting reminder, with an extra "match the expected type" item
         when any output field is not ``str``.

    Args:
        signature: The DSPy signature being prompted.

    Returns:
        The system message text.
    """
    parts = [
        'Your objective is to carefully fulfill the following task:\n'
        f'<task>\n{signature.instructions}\n</task>',
        'In order to do so, you will be provided with:\n' + enumerate_fields(signature.input_fields),
        'Based on these, you must produce the following outputs to complete the task:\n'
        + enumerate_fields(signature.output_fields),
        'All interactions will be structured as follows:',
        _format_xml_template(signature.input_fields, 'input'),
        _format_xml_template(signature.output_fields, 'output'),
    ]

    format_reminder = (
        'When responding, make sure to:\n'
        '1. Perform the task faithfully based on the content within the <task> tags.\n'
        '2. Produce all required output fields within <output> tags\n'
        '3. Format each field with its own <field_name>...</field_name> tags'
    )
    if any(info.annotation is not str for info in signature.output_fields.values()):
        format_reminder += '\n4. Ensure each field value matches its expected type'
    parts.append(format_reminder)

    return '\n\n'.join(parts).strip()
def _field_text(value: Any) -> str:
    """Render one field value as the text placed between its XML tags.

    ``serialize_for_json`` returns JSON-compatible Python objects, not strings.
    Dicts and lists are therefore dumped as indented JSON (as ChatAdapter's
    format_field_value does) so demos show JSON rather than Python repr
    (single quotes, True/None). Every other value is rendered with ``str``.
    """
    jsonable = serialize_for_json(value)
    if isinstance(jsonable, (dict, list)):
        return json.dumps(jsonable, indent=2, ensure_ascii=False)
    return str(jsonable)


def format_input_xml(signature: Signature, data: dict[str, Any]) -> str:
    """
    Produce the ``<input> ... </input>`` block for the signature's input fields.

    Each field becomes ``<name>\\nvalue\\n</name>``, separated by blank lines.
    Fields absent from ``data`` are shown as "Not provided in this example.".
    Values are not validated against the field types here; they are only formatted.
    """
    lines = []
    for field_name in signature.input_fields:
        text_value = _field_text(data.get(field_name, 'Not provided in this example.'))
        lines.append(f'<{field_name}>\n{text_value}\n</{field_name}>\n')
    return '<input>\n\n' + '\n'.join(lines) + '\n</input>'


def format_output_xml(signature: Signature, data: dict[str, Any]) -> str:
    """
    Produce the ``<output> ... </output>`` block for the signature's output fields.

    Each field becomes ``<name>\\nvalue\\n</name>`` on consecutive lines.
    Fields absent from ``data`` are shown as "Not provided in this example.".
    """
    lines = []
    for field_name in signature.output_fields:
        text_value = _field_text(data.get(field_name, 'Not provided in this example.'))
        lines.append(f'<{field_name}>\n{text_value}\n</{field_name}>')
    return '<output>\n' + '\n'.join(lines) + '\n</output>'
def format_turn(
    signature: Signature, values: dict[str, Any], role: str, incomplete: bool = False
) -> dict[str, str]:
    """
    Build one chat message ("turn") whose content is XML-tagged for ``signature``.

    Args:
        signature: The DSPy signature whose fields are rendered.
        values: Field name -> value for the fields to render. Fields missing
            from ``values`` are shown as "Not provided in this example.".
        role: "user" renders the input fields. Any other value renders the
            output fields (intended for "assistant").
        incomplete: True when ``values`` is a demo that lacks some fields. The
            content is then prefixed with an explanatory sentence and labelled
            "Example input:" / "Example output:". A user turn also omits the
            closing request to produce the <output>.

    Returns:
        ``{'role': role, 'content': <text>}``, ready to append to a chat thread.
    """
    sections = (
        ['This is an example of the task, though some input or output fields are not supplied.']
        if incomplete
        else []
    )

    if role != 'user':
        block = format_output_xml(signature, values)
        sections.append('Example output:\n' + block if incomplete else block)
    else:
        block = format_input_xml(signature, values)
        sections.append('Example input:\n' + block if incomplete else block)
        # Only the real (non-demo) request ends by asking for the <output> block.
        if not incomplete and signature.output_fields:
            sections.append('Based on the above <input>, please produce the correct <output>.')

    return {'role': role, 'content': '\n\n'.join(sections)}
def parse_output_xml(signature: Signature, completion: str) -> dict[str, Any]:
    """
    Extract and type-convert every output field from an LLM completion.

    Reasoning models often echo the prompt's template (e.g. ``<answer>{answer}</answer>``)
    or draft answers before the final one. The search therefore prefers the LAST
    ``<output>...</output>`` block when one exists, falling back to the whole
    completion if a field is not found there. Within the searched text the LAST
    ``<fieldName>...</fieldName>`` occurrence wins.

    Args:
        signature: The DSPy signature whose output fields are extracted.
        completion: The raw text produced by the model.

    Returns:
        Mapping of each output field name to the value produced by parse_value.

    Raises:
        ValueError: If a field's tags are missing, or parse_value rejects its text.
    """
    output_blocks = re.findall(r'<output>(.*?)</output>', completion, re.DOTALL)
    region = output_blocks[-1] if output_blocks else completion

    results = {}
    for field_name, field_info in signature.output_fields.items():
        tag = re.escape(field_name)
        # DOTALL so field values may span multiple lines.
        pattern = re.compile(rf'<{tag}>(.*?)</{tag}>', re.DOTALL)
        matches = pattern.findall(region) or pattern.findall(completion)
        if not matches:
            raise ValueError(f'Missing <{field_name}>...</{field_name}> block in completion.')
        results[field_name] = parse_value(matches[-1], field_info.annotation)
    return results
# TODO: fix type s/t reddit_mental_health.pipeline.core.models.CriterionMatch becomes CriterionMatch
class SemanticAdapter(Adapter):
    """
    A DSPy Adapter that formats input/output using an XML/tag-based scheme:
        <task>...</task>
        <input>...</input>
        <output>...</output>
    Like ChatAdapter, it generates a sequence of messages (system + user/assistant
    pairs). The final user message asks the model to produce the <output> block,
    and ``parse`` extracts those fields from the model's reply.
    """

    def format(self, signature: Signature, demos: list[dict], inputs: dict[str, Any]):
        """
        Build the message list sent to the model:
          1. system: task instructions, field descriptions and templates.
          2. one user (<input>) + assistant (<output>) pair per usable demo.
             Incomplete demos come first, then complete ones.
          3. a final user message with <input> built from ``inputs``.

        Demo handling mirrors ChatAdapter:
          - A demo is complete when every input and output field key is present.
          - A demo missing some keys is incomplete. It is kept only if it has at
            least one input field and at least one output field; otherwise it
            cannot form a meaningful user -> assistant pair and is skipped.
        """
        messages = [{'role': 'system', 'content': create_system_message(signature)}]

        # Single pass partition. This avoids O(n^2) `demo in list` checks, which
        # compare whole dicts by value.
        input_names = list(signature.input_fields)
        output_names = list(signature.output_fields)
        incomplete_demos = []
        complete_demos = []
        for demo in demos:
            if all(k in demo for k in input_names) and all(k in demo for k in output_names):
                complete_demos.append(demo)
            elif any(k in demo for k in input_names) and any(k in demo for k in output_names):
                incomplete_demos.append(demo)

        ordered = [(demo, True) for demo in incomplete_demos] + [(demo, False) for demo in complete_demos]
        for demo, is_incomplete in ordered:
            messages.append(format_turn(signature, demo, role='user', incomplete=is_incomplete))
            messages.append(format_turn(signature, demo, role='assistant', incomplete=is_incomplete))

        # Finally, the real user input.
        messages.append(format_turn(signature, inputs, role='user', incomplete=False))

        # Expand any <image> tags, etc. (like ChatAdapter does).
        return try_expand_image_tags(messages)

    def parse(self, signature: Signature, completion: str) -> dict[str, Any]:
        """
        Extract every output field from the model's completion.

        Fields are read from their <fieldName>...</fieldName> tags and parsed by
        parse_output_xml. If that fails for any reason, the completion is handed
        to JSONAdapter instead.
        """
        try:
            return parse_output_xml(signature, completion)
        except Exception:
            # Deliberate best-effort fallback: the model may have answered in JSON.
            # If JSONAdapter also fails, its exception propagates with the XML
            # parse error attached as __context__.
            from dspy.adapters.json_adapter import JSONAdapter

            return JSONAdapter().parse(signature, completion)

    def format_finetune_data(
        self,
        signature: Signature,
        demos: list[dict],
        inputs: dict[str, Any],
        outputs: dict[str, Any],
    ):
        """
        Build one fine-tuning example: the messages from ``format`` followed by an
        assistant message containing the ground-truth <output> block. Similar to
        ChatAdapter.

        Returns:
            ``{'messages': [...]}``.
        """
        messages = self.format(signature, demos, inputs)
        messages.append(format_turn(signature, outputs, role='assistant', incomplete=False))
        return {'messages': messages}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment