@tomy-kyu
Last active December 15, 2025 08:39
Azure OpenAI Responses API Plugin for Dify

I tried to build the core plugin functionality for the Responses API by combining gpt-5.1-codex-max with Gemini 3, but the Responses API's specification really does differ significantly from the Completions API. When using Vision features, I confirmed that GPT responses become noticeably slower; they come back very sluggishly.

For interactions that don't involve images, it appears to work without issues. Since gpt-5.1-codex is hard to use in Azure AI Foundry's Chat Playground, being able to use it from Dify would be extremely helpful. I may need to narrow the plugin's functionality in places, though it may simply be that I lack the necessary expertise, so I'd like someone more experienced to review this and see whether it can be improved further.

With this setup I've created a plugin for my self-hosted Dify-CE tenant, currently restricted to GPT-5.x models. (Embedding, TTS, and STT models are excluded from this implementation.)
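For context, the sketch below shows roughly the request shape the plugin builds: a minimal standalone Responses API call with text plus an image data URI. The endpoint, API version, and deployment name are placeholders (not values taken from this plugin), and it assumes an openai Python SDK version with Responses support.

from openai import AzureOpenAI

client = AzureOpenAI(
    azure_endpoint="https://<your-resource>.openai.azure.com",  # placeholder
    api_key="<api-key>",  # or Microsoft Entra ID authentication
    api_version="2025-03-01-preview",  # assumption: a Responses-capable API version
)

response = client.responses.create(
    model="<gpt-5.x deployment name>",  # placeholder: your Azure deployment name
    input=[
        {"role": "developer", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {"type": "input_text", "text": "Describe this image."},
                {"type": "input_image", "image_url": "data:image/webp;base64,..."},  # data URI or https URL
            ],
        },
    ],
    max_output_tokens=1024,
)
print(response.output_text)

The full plugin source follows.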

import base64
import copy
import io
import json
import logging
import math
from collections.abc import Generator
from typing import Optional, Union, cast
import tiktoken
from PIL import Image
from dify_plugin.entities.model import AIModelEntity
from dify_plugin.entities.model.llm import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from dify_plugin.entities.model.message import (
AssistantPromptMessage,
AudioPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
PromptMessageTool,
)
from dify_plugin.errors.model import CredentialsValidateFailedError
from dify_plugin.interfaces.model.large_language_model import LargeLanguageModel
from openai import AzureOpenAI, Stream
from openai.types.responses import ResponseStreamEvent, Response
from ..common import _CommonAzureOpenAI
from ..constants import LLM_BASE_MODELS
logger = logging.getLogger(__name__)
class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
GPT-5系モデル専用のInvokeメソッド。
すべてのリクエストを Responses API (client.responses.create) で処理します。
"""
return self._chat_generate_with_responses(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
return self._num_tokens_from_messages(credentials, prompt_messages, tools)
def validate_credentials(self, model: str, credentials: dict) -> None:
if "openai_api_base" not in credentials:
raise CredentialsValidateFailedError(
"Azure OpenAI API Base Endpoint is required"
)
auth_method = credentials.get("auth_method", "api_key")
if auth_method == "api_key" and "openai_api_key" not in credentials:
raise CredentialsValidateFailedError(
"Azure OpenAI API key is required when using API Key authentication"
)
if "base_model_name" not in credentials:
raise CredentialsValidateFailedError("Base Model Name is required")
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(
base_model_name=base_model_name, model=model
)
if not ai_model_entity:
raise CredentialsValidateFailedError(
f"Base Model Name {credentials['base_model_name']} is invalid"
)
try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
# Verify the connection using the Responses API
client.responses.create(
model=model,
input=[{"role": "user", "content": "ping"}],
max_output_tokens=20,
stream=False
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(
self, model: str, credentials: dict
) -> Optional[AIModelEntity]:
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(
base_model_name=base_model_name, model=model
)
return ai_model_entity.entity if ai_model_entity else None
def _chat_generate_with_responses(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Generate chat responses with the OpenAI Responses API.
Only supported parameters for Responses API are forwarded.
"""
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
input_messages = self._convert_prompt_messages_to_responses_input(prompt_messages)
responses_params = {
"model": model,
"input": input_messages,
}
# Responses API: use only max_output_tokens (do not accept max_tokens/max_completion_tokens)
if "max_output_tokens" in model_parameters:
responses_params["max_output_tokens"] = model_parameters["max_output_tokens"]
# Tools
if tools:
responses_params["tools"] = []
for tool in tools:
parameters = tool.parameters
if isinstance(parameters, str):
try:
parameters = json.loads(parameters)
except json.JSONDecodeError:
parameters = {"type": "object", "properties": {}}
elif not isinstance(parameters, dict):
parameters = {"type": "object", "properties": {}}
tool_dict = {
"type": "function",
"name": tool.name,
"description": tool.description or "",
"parameters": parameters
}
responses_params["tools"].append(tool_dict)
responses_params["tool_choice"] = "auto"
if user:
responses_params["user"] = user
if stop:
responses_params["stop"] = stop # Responses API supports stop sequences
# Structured output via text.format
response_format = model_parameters.get("response_format")
if response_format:
if response_format == "json_schema":
json_schema_data = model_parameters.get("json_schema", {})
if isinstance(json_schema_data, str):
try:
json_schema_data = json.loads(json_schema_data)
except json.JSONDecodeError:
json_schema_data = {}
responses_params["text"] = {
"format": {
"type": "json_schema",
"name": json_schema_data.get("name", "response"),
"strict": json_schema_data.get("strict", True),
"schema": json_schema_data.get("json_schema", {})
}
}
else:
responses_params["text"] = {
"format": {"type": response_format}
}
# Reasoning effort (supported)
if "reasoning_effort" in model_parameters:
responses_params["reasoning"] = {"effort": model_parameters["reasoning_effort"]}
logger.info(
f"llm request with responses api: model={model}, stream={stream}, "
f"parameters={responses_params}"
)
response = client.responses.create(
**responses_params,
stream=stream,
)
if stream:
return self._handle_responses_stream_response(
model, credentials, response, prompt_messages, tools
)
else:
return self._handle_responses_response(
model, credentials, response, prompt_messages, tools
)
def _convert_prompt_messages_to_responses_input(
self, prompt_messages: list[PromptMessage]
) -> list[dict]:
"""
Convert Dify PromptMessage objects into Azure Responses API input format,
including image/audio support for VLM.
For data-URI images, resize to a maximum side of 1024 px and re-compress to WebP (quality=80).
"""
input_messages: list[dict] = []
for message in prompt_messages:
# System → developer
if isinstance(message, SystemPromptMessage):
input_messages.append({
"role": "developer",
"content": message.content
})
continue
# User
if isinstance(message, UserPromptMessage):
if isinstance(message.content, str):
input_messages.append({
"role": "user",
"content": message.content
})
continue
content_parts = []
assert message.content is not None
for content_item in message.content:
# TEXT
if getattr(content_item, "type", None) in (
PromptMessageContentType.TEXT,
"text",
"input_text",
):
content_parts.append({
"type": "input_text",
"text": content_item.data,
})
# IMAGE
elif getattr(content_item, "type", None) in (
PromptMessageContentType.IMAGE,
"image_url",
"image",
"input_image",
):
image_url = content_item.data
if isinstance(image_url, dict) and "url" in image_url:
image_url = image_url["url"]
# Compress & resize here (data URIs only)
image_url = self._prepare_image_url(image_url)
content_parts.append({
"type": "input_image",
"image_url": image_url, # string only (data URI or https URL)
})
# AUDIO
elif getattr(content_item, "type", None) in (
PromptMessageContentType.AUDIO,
"audio",
"input_audio",
):
content_parts.append({
"type": "input_audio",
"input_audio": {
"data": getattr(content_item, "base64_data", None) or content_item.data,
"format": getattr(content_item, "format", "wav"),
},
})
input_messages.append({
"role": "user",
"content": content_parts,
})
continue
# Assistant: add previous responses, etc. to the context
if isinstance(message, AssistantPromptMessage):
input_messages.append({
"role": "assistant",
"content": message.content,
})
continue
# Tool (function call result): flattened into an assistant message here; the Responses API also supports function_call_output input items
if isinstance(message, ToolPromptMessage):
input_messages.append({
"role": "assistant",
"content": message.content,
})
continue
return input_messages
def _handle_responses_response(
self,
model: str,
credentials: dict,
response: Response,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> LLMResult:
content = ""
image_outputs = []  # used to collect image URLs / data URIs from the output, if needed
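# Walk the Responses API output items: "message" items carry text/image parts,
# standalone text/image items contribute directly, and "function_call" items
# are collected as tool calls in a second pass further below.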
if hasattr(response, 'output') and response.output:
for item in response.output:
item_type = getattr(item, 'type', '')
if item_type == "message":
item_content = getattr(item, 'content', None)
if isinstance(item_content, str):
if item_content:
content += item_content
elif isinstance(item_content, list):
for part in item_content:
part_type = getattr(part, 'type', '')
if part_type in ("output_text", "text", "input_text"):
text_val = getattr(part, 'text', '')
if text_val:
content += text_val
elif part_type in ("output_image", "image"):
img = getattr(part, "image_url", None) or getattr(part, "content", None)
if img:
image_outputs.append(img)
content += f"[image:{img}]"
elif item_type in ("output_text", "text"):
text_val = getattr(item, 'text', '')
if text_val:
content += text_val
elif item_type in ("output_image", "image"):
img = getattr(item, "image_url", None) or getattr(item, "content", None)
if img:
image_outputs.append(img)
content += f"[image:{img}]"
elif hasattr(response, 'text') and response.text:
content = response.text
elif hasattr(response, 'content') and response.content:
content = response.content
tool_calls = []
if hasattr(response, 'output') and response.output:
for item in response.output:
item_type = getattr(item, 'type', '')
if item_type == "function_call":
function_name = getattr(item, 'name', '')
function_args = getattr(item, 'arguments', '')
call_id = getattr(item, 'call_id', '') or getattr(item, 'id', '')
if isinstance(function_args, dict):
args_str = json.dumps(function_args)
elif isinstance(function_args, str):
args_str = function_args
else:
args_str = "{}"
tool_call = AssistantPromptMessage.ToolCall(
id=call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=function_name,
arguments=args_str
)
)
tool_calls.append(tool_call)
assistant_prompt_message = AssistantPromptMessage(
content=content,
tool_calls=tool_calls
)
prompt_tokens = 0
completion_tokens = 0
prompt_tokens_details: Optional[dict] = None
completion_tokens_details: Optional[dict] = None
if hasattr(response, 'usage') and response.usage:
usage_obj = response.usage
prompt_tokens = getattr(usage_obj, 'input_tokens', None) or getattr(usage_obj, 'prompt_tokens', 0)
completion_tokens = getattr(usage_obj, 'output_tokens', None) or getattr(usage_obj, 'completion_tokens', 0)
if hasattr(usage_obj, 'prompt_tokens_details') and usage_obj.prompt_tokens_details:
_ptd = usage_obj.prompt_tokens_details
if hasattr(_ptd, 'to_dict'):
prompt_tokens_details = _ptd.to_dict()
elif isinstance(_ptd, dict):
prompt_tokens_details = _ptd
else:
prompt_tokens_details = {
'cached_tokens': getattr(_ptd, 'cached_tokens', None)
}
elif hasattr(usage_obj, 'input_tokens_details') and usage_obj.input_tokens_details:
it = usage_obj.input_tokens_details
if hasattr(it, 'to_dict'):
prompt_tokens_details = it.to_dict()
else:
prompt_tokens_details = {
'cached_tokens': getattr(it, 'cached_tokens', None)
}
if hasattr(usage_obj, 'completion_tokens_details') and usage_obj.completion_tokens_details:
completion_tokens_details = usage_obj.completion_tokens_details
elif hasattr(usage_obj, 'output_tokens_details') and usage_obj.output_tokens_details:
ot = usage_obj.output_tokens_details
if hasattr(ot, 'to_dict'):
completion_tokens_details = ot.to_dict()
else:
completion_tokens_details = {
'reasoning_tokens': getattr(ot, 'reasoning_tokens', None)
}
else:
prompt_tokens = self._num_tokens_from_messages(
credentials, prompt_messages, tools
)
completion_tokens = self._num_tokens_from_messages(
credentials, [assistant_prompt_message]
)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens,
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=completion_tokens_details
)
return LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
system_fingerprint=getattr(response, 'id', ''),
)
def _handle_responses_stream_response(
self,
model: str,
credentials: dict,
response: Stream[ResponseStreamEvent],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> Generator:
full_text = ""
index = 0
is_first = True
pending_tool_calls = {}
current_tool_call = None
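# Stream event types handled in this loop:
#   response.output_text.delta             -> incremental assistant text
#   response.output_item.added             -> a function_call item has started
#   response.function_call_arguments.delta -> incremental tool-call arguments
#   response.function_call_arguments.done  -> final tool-call arguments
#   response.output_item.done              -> completed function_call, emitted as a tool-call chunk
# Any other chunk exposing chunk.delta.text is treated as plain text.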
for chunk in response:
if is_first:
is_first = False
chunk_type = getattr(chunk, 'type', '')
if chunk_type == 'response.output_text.delta':
delta_text = getattr(chunk, 'delta', '')
if delta_text:
full_text += delta_text
assistant_prompt_message = AssistantPromptMessage(
content=delta_text,
tool_calls=[]
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=getattr(chunk, 'item_id', ''),
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
),
)
index += 1
elif chunk_type == 'response.output_item.added':
item = getattr(chunk, 'item', None)
if item and hasattr(item, 'type'):
item_type = getattr(item, 'type', '')
if item_type == 'function_call':
function_name = getattr(item, 'name', '')
call_id = getattr(item, 'call_id', '')
if function_name and call_id:
pending_tool_calls[call_id] = {
'id': call_id,
'name': function_name,
'arguments': ''
}
current_tool_call = call_id
elif chunk_type == 'response.function_call_arguments.delta':
delta_args = getattr(chunk, 'delta', '')
if current_tool_call and current_tool_call in pending_tool_calls:
pending_tool_calls[current_tool_call]['arguments'] += delta_args
elif chunk_type == 'response.function_call_arguments.done':
call_id = getattr(chunk, 'item_id', '')
final_args = getattr(chunk, 'arguments', '')
if call_id and call_id in pending_tool_calls:
pending_tool_calls[call_id]['arguments'] = final_args
elif chunk_type == 'response.output_item.done':
item = getattr(chunk, 'item', None)
if item and hasattr(item, 'type'):
item_type = getattr(item, 'type', '')
if item_type == 'function_call':
function_name = getattr(item, 'name', '')
function_args = getattr(item, 'arguments', '')
call_id = getattr(item, 'call_id', '')
if call_id in pending_tool_calls:
final_args = pending_tool_calls[call_id]['arguments'] or function_args
else:
final_args = function_args
if function_name:
tool_call = AssistantPromptMessage.ToolCall(
id=call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=function_name,
arguments=final_args or "{}"
)
)
assistant_prompt_message = AssistantPromptMessage(
content="",
tool_calls=[tool_call]
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=call_id,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
),
)
index += 1
if call_id in pending_tool_calls:
del pending_tool_calls[call_id]
if call_id == current_tool_call:
current_tool_call = None
elif hasattr(chunk, 'delta') and hasattr(chunk.delta, 'text'):
delta_text = chunk.delta.text or ""
if delta_text:
full_text += delta_text
assistant_prompt_message = AssistantPromptMessage(
content=delta_text,
tool_calls=[]
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=getattr(chunk, 'item_id', ''),
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
),
)
index += 1
prompt_tokens = self._num_tokens_from_messages(
credentials, prompt_messages, tools
)
full_assistant_prompt_message = AssistantPromptMessage(content=full_text)
completion_tokens = self._num_tokens_from_messages(
credentials, [full_assistant_prompt_message]
)
usage = self._calc_response_usage(
model, credentials, prompt_tokens, completion_tokens
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=""),
finish_reason="stop",
usage=usage,
),
)
@staticmethod
def _convert_prompt_message_to_dict(message: PromptMessage):
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
assert message.content is not None
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(
TextPromptMessageContent, message_content
)
sub_message_dict = {
"type": "text",
"text": message_content.data,
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(
ImagePromptMessageContent, message_content
)
sub_message_dict = {
"type": "image_url",
"image_url": {
"url": message_content.data,
"detail": message_content.detail.value,
},
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.AUDIO:
message_content = cast(
AudioPromptMessageContent, message_content
)
sub_message_dict = {
"type": "input_audio",
"input_audio": {
"data": message_content.base64_data,
"format": message_content.format,
},
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [
tool_call.model_dump(mode="json")
for tool_call in message.tool_calls
]
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "tool",
"name": message.name,
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["name"] = message.name
return message_dict
def _num_tokens_from_messages(
self,
credentials: dict,
messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
model = credentials["base_model_name"]
# Assume the GPT-5 family uses the same encoding as GPT-4o
try:
encoding = tiktoken.encoding_for_model("gpt-4o")
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
# Newer models all use this token-count formula
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
image_details: list[dict] = []
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
if isinstance(value, list):
text = ""
for item in value:
if isinstance(item, dict):
if item["type"] == "text":
text += item["text"]
elif item["type"] == "image_url":
image_details.append(item["image_url"])
value = text
if key == "tool_calls":
for tool_call in value:
assert isinstance(tool_call, dict)
for t_key, t_value in tool_call.items():
num_tokens += len(encoding.encode(t_key))
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(t_key))
num_tokens += len(encoding.encode(t_value))
else:
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
if len(image_details) > 0:
num_tokens += self._num_tokens_from_images(
image_details=image_details,
base_model_name=credentials["base_model_name"],
)
return num_tokens
@staticmethod
def _num_tokens_for_tools(
encoding: tiktoken.Encoding, tools: list[PromptMessageTool]
) -> int:
num_tokens = 0
for tool in tools:
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode("function"))
num_tokens += len(encoding.encode("name"))
num_tokens += len(encoding.encode(tool.name))
num_tokens += len(encoding.encode("description"))
num_tokens += len(encoding.encode(tool.description))
parameters = tool.parameters
num_tokens += len(encoding.encode("parameters"))
if "title" in parameters:
num_tokens += len(encoding.encode("title"))
num_tokens += len(encoding.encode(parameters["title"]))
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode(parameters["type"]))
if "properties" in parameters:
num_tokens += len(encoding.encode("properties"))
for key, value in parameters["properties"].items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if "required" in parameters:
num_tokens += len(encoding.encode("required"))
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str):
for ai_model_entity in LLM_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
ai_model_entity_copy.entity.model = model
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
def _get_base_model_name(self, credentials: dict) -> str:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
return base_model_name
def _get_image_patches(self, n: int) -> int:
# ceil(n / 32): number of 32px patches per side (currently unused)
return (n + 32 - 1) // 32
def _num_tokens_from_images(
self, base_model_name: str, image_details: list[dict]
) -> int:
num_tokens: int = 0
base_tokens: int = 85
tile_tokens: int = 170
# Apply the same image-token logic as gpt-4o for gpt-5 family models
# Add model-specific branches here if needed
if "mini" in base_model_name:
base_tokens = 2833
tile_tokens = 5667
for image_detail in image_details:
url = image_detail["url"]
if "base64" in url and "," in url:
try:
base64_str = url.split(",")[1]
image_data = base64.b64decode(base64_str)
image = Image.open(io.BytesIO(image_data))
width, height = image.size
if image_detail.get("detail") == "low":
num_tokens += 85
else:
if width > 2048 or height > 2048:
aspect_ratio = width / height
if aspect_ratio > 1:
width, height = 2048, int(2048 / aspect_ratio)
else:
width, height = int(2048 * aspect_ratio), 2048
if width >= height and height > 768:
width, height = int((768 / height) * width), 768
elif height > width and width > 768:
width, height = 768, int((768 / width) * height)
w_tiles = math.ceil(width / 512)
h_tiles = math.ceil(height / 512)
total_tiles = w_tiles * h_tiles
num_tokens += base_tokens + total_tiles * tile_tokens
except Exception:
# If image decoding fails, add the minimum token count
num_tokens += base_tokens
else:
# For plain URLs, add a rough minimum estimate as a simple fallback
num_tokens += base_tokens
return num_tokens
# --- Added: image compression (data URIs only) ---
def _prepare_image_url(self, image_url: str) -> str:
"""
data URI の画像だけを対象に、最大辺 1024px の WebP(quality=80) に再圧縮する。
失敗した場合は元の文字列を返す。
"""
try:
return self._compress_image_data_uri(image_url)
except Exception:
return image_url
@staticmethod
def _compress_image_data_uri(
data_uri: str,
max_side: int = 1024,
fmt: str = "WEBP",
quality: int = 80,
) -> str:
if not (isinstance(data_uri, str) and data_uri.startswith("data:image")):
return data_uri
if "," not in data_uri:
return data_uri
header, b64data = data_uri.split(",", 1)
img_bytes = base64.b64decode(b64data)
image = Image.open(io.BytesIO(img_bytes))
image = image.convert("RGB") # WebP で安定させるため
image.thumbnail((max_side, max_side), Image.Resampling.LANCZOS)
buf = io.BytesIO()
# method=6 selects the slowest / best-compression WebP encoding (Pillow's default is 4)
image.save(buf, format=fmt, quality=quality, method=6)
b64_new = base64.b64encode(buf.getvalue()).decode("utf-8")
return f"data:image/{fmt.lower()};base64,{b64_new}"