import base64
import copy
import io
import json
import logging
import math
from collections.abc import Generator
from typing import Optional, Union, cast

import tiktoken
from PIL import Image
from dify_plugin.entities.model import AIModelEntity
from dify_plugin.entities.model.llm import (
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
)
from dify_plugin.entities.model.message import (
    AssistantPromptMessage,
    AudioPromptMessageContent,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from dify_plugin.errors.model import CredentialsValidateFailedError
from dify_plugin.interfaces.model.large_language_model import LargeLanguageModel
from openai import AzureOpenAI, Stream
from openai.types.responses import Response, ResponseStreamEvent

from ..common import _CommonAzureOpenAI
from ..constants import LLM_BASE_MODELS

logger = logging.getLogger(__name__)


class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke method dedicated to the GPT-5 model family.
        Every request is handled through the Responses API (client.responses.create).
        """
        return self._chat_generate_with_responses(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
        )

    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        return self._num_tokens_from_messages(credentials, prompt_messages, tools)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        if "openai_api_base" not in credentials:
            raise CredentialsValidateFailedError(
                "Azure OpenAI API Base Endpoint is required"
            )

        auth_method = credentials.get("auth_method", "api_key")
        if auth_method == "api_key" and "openai_api_key" not in credentials:
            raise CredentialsValidateFailedError(
                "Azure OpenAI API key is required when using API Key authentication"
            )

        if "base_model_name" not in credentials:
            raise CredentialsValidateFailedError("Base Model Name is required")

        base_model_name = self._get_base_model_name(credentials)
        ai_model_entity = self._get_ai_model_entity(
            base_model_name=base_model_name, model=model
        )
        if not ai_model_entity:
            raise CredentialsValidateFailedError(
                f"Base Model Name {credentials['base_model_name']} is invalid"
            )

        try:
            client = AzureOpenAI(**self._to_credential_kwargs(credentials))
            # Verify connectivity with a minimal Responses API call
            client.responses.create(
                model=model,
                input=[{"role": "user", "content": "ping"}],
                max_output_tokens=20,
                stream=False,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))
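
    # Illustrative only: the credential keys validated above typically arrive in a
    # dict shaped like the sketch below (values are placeholders, not real settings).
    #
    #     credentials = {
    #         "openai_api_base": "https://my-resource.openai.azure.com",
    #         "auth_method": "api_key",
    #         "openai_api_key": "<azure-openai-key>",
    #         "base_model_name": "gpt-5",
    #     }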

    def get_customizable_model_schema(
        self, model: str, credentials: dict
    ) -> Optional[AIModelEntity]:
        base_model_name = self._get_base_model_name(credentials)
        ai_model_entity = self._get_ai_model_entity(
            base_model_name=base_model_name, model=model
        )
        return ai_model_entity.entity if ai_model_entity else None

    def _chat_generate_with_responses(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Generate chat responses with the OpenAI Responses API.
        Only parameters supported by the Responses API are forwarded.
        """
        client = AzureOpenAI(**self._to_credential_kwargs(credentials))
        input_messages = self._convert_prompt_messages_to_responses_input(prompt_messages)

        responses_params = {
            "model": model,
            "input": input_messages,
        }

        # Responses API: only max_output_tokens is used
        # (max_tokens / max_completion_tokens are not accepted)
        if "max_output_tokens" in model_parameters:
            responses_params["max_output_tokens"] = model_parameters["max_output_tokens"]

        # Tools
        if tools:
            responses_params["tools"] = []
            for tool in tools:
                parameters = tool.parameters
                if isinstance(parameters, str):
                    try:
                        parameters = json.loads(parameters)
                    except json.JSONDecodeError:
                        parameters = {"type": "object", "properties": {}}
                elif not isinstance(parameters, dict):
                    parameters = {"type": "object", "properties": {}}

                tool_dict = {
                    "type": "function",
                    "name": tool.name,
                    "description": tool.description or "",
                    "parameters": parameters,
                }
                responses_params["tools"].append(tool_dict)
            responses_params["tool_choice"] = "auto"

        if user:
            responses_params["user"] = user

        if stop:
            responses_params["stop"] = stop  # forward stop sequences when provided

        # Structured output via text.format
        response_format = model_parameters.get("response_format")
        if response_format:
            if response_format == "json_schema":
                json_schema_data = model_parameters.get("json_schema", {})
                if isinstance(json_schema_data, str):
                    try:
                        json_schema_data = json.loads(json_schema_data)
                    except json.JSONDecodeError:
                        json_schema_data = {}

                responses_params["text"] = {
                    "format": {
                        "type": "json_schema",
                        "name": json_schema_data.get("name", "response"),
                        "strict": json_schema_data.get("strict", True),
                        "schema": json_schema_data.get("json_schema", {}),
                    }
                }
            else:
                responses_params["text"] = {
                    "format": {"type": response_format}
                }

        # Reasoning effort (supported by the Responses API)
        if "reasoning_effort" in model_parameters:
            responses_params["reasoning"] = {"effort": model_parameters["reasoning_effort"]}

        logger.info(
            f"llm request with responses api: model={model}, stream={stream}, "
            f"parameters={responses_params}"
        )

        response = client.responses.create(
            **responses_params,
            stream=stream,
        )

        if stream:
            return self._handle_responses_stream_response(
                model, credentials, response, prompt_messages, tools
            )
        else:
            return self._handle_responses_response(
                model, credentials, response, prompt_messages, tools
            )
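
    # For illustration, a typical payload assembled above might look like the
    # hedged sketch below (deployment name, tool, and schema are placeholders):
    #
    #     responses_params = {
    #         "model": "my-gpt5-deployment",
    #         "input": [{"role": "user", "content": "Hello"}],
    #         "max_output_tokens": 1024,
    #         "tools": [{
    #             "type": "function",
    #             "name": "get_weather",
    #             "description": "Look up the weather",
    #             "parameters": {"type": "object",
    #                            "properties": {"city": {"type": "string"}}},
    #         }],
    #         "tool_choice": "auto",
    #         "reasoning": {"effort": "medium"},
    #     }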

    def _convert_prompt_messages_to_responses_input(
        self, prompt_messages: list[PromptMessage]
    ) -> list[dict]:
        """
        Convert Dify PromptMessage objects into Azure Responses API input format,
        including image/audio support for VLM.
        Images supplied as data URIs are resized to a maximum side of 1024px
        and re-encoded as WebP (quality=80).
        """
        input_messages: list[dict] = []

        for message in prompt_messages:
            # System -> developer
            if isinstance(message, SystemPromptMessage):
                input_messages.append({
                    "role": "developer",
                    "content": message.content
                })
                continue

            # User
            if isinstance(message, UserPromptMessage):
                if isinstance(message.content, str):
                    input_messages.append({
                        "role": "user",
                        "content": message.content
                    })
                    continue

                content_parts = []
                assert message.content is not None
                for content_item in message.content:
                    # TEXT
                    if getattr(content_item, "type", None) in (
                        PromptMessageContentType.TEXT,
                        "text",
                        "input_text",
                    ):
                        content_parts.append({
                            "type": "input_text",
                            "text": content_item.data,
                        })
                    # IMAGE
                    elif getattr(content_item, "type", None) in (
                        PromptMessageContentType.IMAGE,
                        "image_url",
                        "image",
                        "input_image",
                    ):
                        image_url = content_item.data
                        if isinstance(image_url, dict) and "url" in image_url:
                            image_url = image_url["url"]
                        # Only data URIs are compressed and resized here
                        image_url = self._prepare_image_url(image_url)
                        content_parts.append({
                            "type": "input_image",
                            "image_url": image_url,  # string only (data URI or https URL)
                        })
                    # AUDIO
                    elif getattr(content_item, "type", None) in (
                        PromptMessageContentType.AUDIO,
                        "audio",
                        "input_audio",
                    ):
                        content_parts.append({
                            "type": "input_audio",
                            "input_audio": {
                                "data": getattr(content_item, "base64_data", None) or content_item.data,
                                "format": getattr(content_item, "format", "wav"),
                            },
                        })

                input_messages.append({
                    "role": "user",
                    "content": content_parts,
                })
                continue

            # Assistant: carry prior responses forward as conversation context
            if isinstance(message, AssistantPromptMessage):
                input_messages.append({
                    "role": "assistant",
                    "content": message.content,
                })
                continue

            # Tool (result of a function call): passed back as assistant-role context
            if isinstance(message, ToolPromptMessage):
                input_messages.append({
                    "role": "assistant",
                    "content": message.content,
                })
                continue

        return input_messages
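
    # Illustrative mapping only (field names follow the conversion above; the
    # data URI is a placeholder): a UserPromptMessage carrying text plus an image
    # becomes a single Responses API input item such as
    #
    #     {
    #         "role": "user",
    #         "content": [
    #             {"type": "input_text", "text": "Describe this picture"},
    #             {"type": "input_image", "image_url": "data:image/webp;base64,..."},
    #         ],
    #     }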

    def _handle_responses_response(
        self,
        model: str,
        credentials: dict,
        response: Response,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> LLMResult:
        content = ""
        image_outputs = []  # collects image URLs / data URIs emitted by the model

        if hasattr(response, 'output') and response.output:
            for item in response.output:
                item_type = getattr(item, 'type', '')
                if item_type == "message":
                    item_content = getattr(item, 'content', None)
                    if isinstance(item_content, str):
                        if item_content:
                            content += item_content
                    elif isinstance(item_content, list):
                        for part in item_content:
                            part_type = getattr(part, 'type', '')
                            if part_type in ("output_text", "text", "input_text"):
                                text_val = getattr(part, 'text', '')
                                if text_val:
                                    content += text_val
                            elif part_type in ("output_image", "image"):
                                img = getattr(part, "image_url", None) or getattr(part, "content", None)
                                if img:
                                    image_outputs.append(img)
                                    content += f"[image:{img}]"
                elif item_type in ("output_text", "text"):
                    text_val = getattr(item, 'text', '')
                    if text_val:
                        content += text_val
                elif item_type in ("output_image", "image"):
                    img = getattr(item, "image_url", None) or getattr(item, "content", None)
                    if img:
                        image_outputs.append(img)
                        content += f"[image:{img}]"
        elif hasattr(response, 'text') and response.text:
            content = response.text
        elif hasattr(response, 'content') and response.content:
            content = response.content

        tool_calls = []
        if hasattr(response, 'output') and response.output:
            for item in response.output:
                item_type = getattr(item, 'type', '')
                if item_type == "function_call":
                    function_name = getattr(item, 'name', '')
                    function_args = getattr(item, 'arguments', '')
                    call_id = getattr(item, 'call_id', '') or getattr(item, 'id', '')

                    if isinstance(function_args, dict):
                        args_str = json.dumps(function_args)
                    elif isinstance(function_args, str):
                        args_str = function_args
                    else:
                        args_str = "{}"

                    tool_call = AssistantPromptMessage.ToolCall(
                        id=call_id,
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=function_name,
                            arguments=args_str
                        )
                    )
                    tool_calls.append(tool_call)

        assistant_prompt_message = AssistantPromptMessage(
            content=content,
            tool_calls=tool_calls
        )

        prompt_tokens = 0
        completion_tokens = 0
        prompt_tokens_details: Optional[dict] = None
        completion_tokens_details: Optional[dict] = None

        if hasattr(response, 'usage') and response.usage:
            usage_obj = response.usage
            prompt_tokens = getattr(usage_obj, 'input_tokens', None) or getattr(usage_obj, 'prompt_tokens', 0)
            completion_tokens = getattr(usage_obj, 'output_tokens', None) or getattr(usage_obj, 'completion_tokens', 0)

            if hasattr(usage_obj, 'prompt_tokens_details') and usage_obj.prompt_tokens_details:
                _ptd = usage_obj.prompt_tokens_details
                if hasattr(_ptd, 'to_dict'):
                    prompt_tokens_details = _ptd.to_dict()
                elif isinstance(_ptd, dict):
                    prompt_tokens_details = _ptd
                else:
                    prompt_tokens_details = {
                        'cached_tokens': getattr(_ptd, 'cached_tokens', None)
                    }
            elif hasattr(usage_obj, 'input_tokens_details') and usage_obj.input_tokens_details:
                it = usage_obj.input_tokens_details
                if hasattr(it, 'to_dict'):
                    prompt_tokens_details = it.to_dict()
                else:
                    prompt_tokens_details = {
                        'cached_tokens': getattr(it, 'cached_tokens', None)
                    }

            if hasattr(usage_obj, 'completion_tokens_details') and usage_obj.completion_tokens_details:
                # Normalize to a plain dict, mirroring the prompt_tokens_details handling
                ctd = usage_obj.completion_tokens_details
                if hasattr(ctd, 'to_dict'):
                    completion_tokens_details = ctd.to_dict()
                elif isinstance(ctd, dict):
                    completion_tokens_details = ctd
                else:
                    completion_tokens_details = {
                        'reasoning_tokens': getattr(ctd, 'reasoning_tokens', None)
                    }
            elif hasattr(usage_obj, 'output_tokens_details') and usage_obj.output_tokens_details:
                ot = usage_obj.output_tokens_details
                if hasattr(ot, 'to_dict'):
                    completion_tokens_details = ot.to_dict()
                else:
                    completion_tokens_details = {
                        'reasoning_tokens': getattr(ot, 'reasoning_tokens', None)
                    }
        else:
            prompt_tokens = self._num_tokens_from_messages(
                credentials, prompt_messages, tools
            )
            completion_tokens = self._num_tokens_from_messages(
                credentials, [assistant_prompt_message]
            )

        usage = self._calc_response_usage(
            model, credentials, prompt_tokens, completion_tokens,
            prompt_tokens_details=prompt_tokens_details,
            completion_tokens_details=completion_tokens_details
        )

        return LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage,
            system_fingerprint=getattr(response, 'id', ''),
        )

    def _handle_responses_stream_response(
        self,
        model: str,
        credentials: dict,
        response: Stream[ResponseStreamEvent],
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> Generator:
        full_text = ""
        index = 0
        is_first = True

        pending_tool_calls = {}
        current_tool_call = None

        for chunk in response:
            if is_first:
                is_first = False

            chunk_type = getattr(chunk, 'type', '')

            if chunk_type == 'response.output_text.delta':
                delta_text = getattr(chunk, 'delta', '')
                if delta_text:
                    full_text += delta_text
                    assistant_prompt_message = AssistantPromptMessage(
                        content=delta_text,
                        tool_calls=[]
                    )
                    yield LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        system_fingerprint=getattr(chunk, 'item_id', ''),
                        delta=LLMResultChunkDelta(
                            index=index,
                            message=assistant_prompt_message
                        ),
                    )
                    index += 1

            elif chunk_type == 'response.output_item.added':
                item = getattr(chunk, 'item', None)
                if item and hasattr(item, 'type'):
                    item_type = getattr(item, 'type', '')
                    if item_type == 'function_call':
                        function_name = getattr(item, 'name', '')
                        call_id = getattr(item, 'call_id', '')
                        if function_name and call_id:
                            pending_tool_calls[call_id] = {
                                'id': call_id,
                                'name': function_name,
                                'arguments': ''
                            }
                            current_tool_call = call_id

            elif chunk_type == 'response.function_call_arguments.delta':
                delta_args = getattr(chunk, 'delta', '')
                if current_tool_call and current_tool_call in pending_tool_calls:
                    pending_tool_calls[current_tool_call]['arguments'] += delta_args

            elif chunk_type == 'response.function_call_arguments.done':
                call_id = getattr(chunk, 'item_id', '')
                final_args = getattr(chunk, 'arguments', '')
                if call_id and call_id in pending_tool_calls:
                    pending_tool_calls[call_id]['arguments'] = final_args

            elif chunk_type == 'response.output_item.done':
                item = getattr(chunk, 'item', None)
                if item and hasattr(item, 'type'):
                    item_type = getattr(item, 'type', '')
                    if item_type == 'function_call':
                        function_name = getattr(item, 'name', '')
                        function_args = getattr(item, 'arguments', '')
                        call_id = getattr(item, 'call_id', '')

                        if call_id in pending_tool_calls:
                            final_args = pending_tool_calls[call_id]['arguments'] or function_args
                        else:
                            final_args = function_args

                        if function_name:
                            tool_call = AssistantPromptMessage.ToolCall(
                                id=call_id,
                                type="function",
                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                    name=function_name,
                                    arguments=final_args or "{}"
                                )
                            )
                            assistant_prompt_message = AssistantPromptMessage(
                                content="",
                                tool_calls=[tool_call]
                            )
                            yield LLMResultChunk(
                                model=model,
                                prompt_messages=prompt_messages,
                                system_fingerprint=call_id,
                                delta=LLMResultChunkDelta(
                                    index=index,
                                    message=assistant_prompt_message
                                ),
                            )
                            index += 1
                        if call_id in pending_tool_calls:
                            del pending_tool_calls[call_id]
                        if call_id == current_tool_call:
                            current_tool_call = None

            elif hasattr(chunk, 'delta') and hasattr(chunk.delta, 'text'):
                delta_text = chunk.delta.text or ""
                if delta_text:
                    full_text += delta_text
                    assistant_prompt_message = AssistantPromptMessage(
                        content=delta_text,
                        tool_calls=[]
                    )
                    yield LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        system_fingerprint=getattr(chunk, 'item_id', ''),
                        delta=LLMResultChunkDelta(
                            index=index,
                            message=assistant_prompt_message
                        ),
                    )
                    index += 1

        prompt_tokens = self._num_tokens_from_messages(
            credentials, prompt_messages, tools
        )
        full_assistant_prompt_message = AssistantPromptMessage(content=full_text)
        completion_tokens = self._num_tokens_from_messages(
            credentials, [full_assistant_prompt_message]
        )
        usage = self._calc_response_usage(
            model, credentials, prompt_tokens, completion_tokens
        )
        yield LLMResultChunk(
            model=model,
            prompt_messages=prompt_messages,
            system_fingerprint="",
            delta=LLMResultChunkDelta(
                index=index,
                message=AssistantPromptMessage(content=""),
                finish_reason="stop",
                usage=usage,
            ),
        )
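
    # For orientation, a streamed tool call typically arrives as the following
    # event sequence (event names as handled above; ordering is the usual case,
    # not a guarantee):
    #   response.output_item.added             -> item.type == "function_call"
    #   response.function_call_arguments.delta -> partial JSON argument text
    #   response.function_call_arguments.done  -> final argument string
    #   response.output_item.done              -> emit the completed ToolCall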

    @staticmethod
    def _convert_prompt_message_to_dict(message: PromptMessage):
        if isinstance(message, UserPromptMessage):
            message = cast(UserPromptMessage, message)
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                sub_messages = []
                assert message.content is not None
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(
                            TextPromptMessageContent, message_content
                        )
                        sub_message_dict = {
                            "type": "text",
                            "text": message_content.data,
                        }
                        sub_messages.append(sub_message_dict)
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(
                            ImagePromptMessageContent, message_content
                        )
                        sub_message_dict = {
                            "type": "image_url",
                            "image_url": {
                                "url": message_content.data,
                                "detail": message_content.detail.value,
                            },
                        }
                        sub_messages.append(sub_message_dict)
                    elif message_content.type == PromptMessageContentType.AUDIO:
                        message_content = cast(
                            AudioPromptMessageContent, message_content
                        )
                        sub_message_dict = {
                            "type": "input_audio",
                            "input_audio": {
                                "data": message_content.base64_data,
                                "format": message_content.format,
                            },
                        }
                        sub_messages.append(sub_message_dict)
                message_dict = {"role": "user", "content": sub_messages}
        elif isinstance(message, AssistantPromptMessage):
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls:
                message_dict["tool_calls"] = [
                    tool_call.model_dump(mode="json")
                    for tool_call in message.tool_calls
                ]
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ToolPromptMessage):
            message = cast(ToolPromptMessage, message)
            message_dict = {
                "role": "tool",
                "name": message.name,
                "content": message.content,
                "tool_call_id": message.tool_call_id,
            }
        else:
            raise ValueError(f"Got unknown type {message}")
        if message.name:
            message_dict["name"] = message.name
        return message_dict

    def _num_tokens_from_messages(
        self,
        credentials: dict,
        messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        model = credentials["base_model_name"]
        # Assume the GPT-5 family uses the same encoding as GPT-4o
        try:
            encoding = tiktoken.encoding_for_model("gpt-4o")
        except KeyError:
            logger.warning("Warning: model not found. Using cl100k_base encoding.")
            encoding = tiktoken.get_encoding("cl100k_base")

        # Newer models all use this per-message overhead
        tokens_per_message = 3
        tokens_per_name = 1

        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        image_details: list[dict] = []
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                if isinstance(value, list):
                    text = ""
                    for item in value:
                        if isinstance(item, dict):
                            if item["type"] == "text":
                                text += item["text"]
                            elif item["type"] == "image_url":
                                image_details.append(item["image_url"])
                    value = text
                if key == "tool_calls":
                    for tool_call in value:
                        assert isinstance(tool_call, dict)
                        for t_key, t_value in tool_call.items():
                            num_tokens += len(encoding.encode(t_key))
                            if t_key == "function":
                                for f_key, f_value in t_value.items():
                                    num_tokens += len(encoding.encode(f_key))
                                    num_tokens += len(encoding.encode(f_value))
                            else:
                                num_tokens += len(encoding.encode(t_key))
                                num_tokens += len(encoding.encode(t_value))
                else:
                    num_tokens += len(encoding.encode(str(value)))
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3
        if tools:
            num_tokens += self._num_tokens_for_tools(encoding, tools)
        if len(image_details) > 0:
            num_tokens += self._num_tokens_from_images(
                image_details=image_details,
                base_model_name=credentials["base_model_name"],
            )
        return num_tokens
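
    # Rough worked example of the overhead applied above (counts are approximate
    # and tokenizer-dependent): two messages contribute 2 * 3 = 6 overhead tokens,
    # plus 3 reply-priming tokens, plus the encoded length of each content string,
    # plus 1 extra token per message that carries a "name" field.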

    @staticmethod
    def _num_tokens_for_tools(
        encoding: tiktoken.Encoding, tools: list[PromptMessageTool]
    ) -> int:
        num_tokens = 0
        for tool in tools:
            num_tokens += len(encoding.encode("type"))
            num_tokens += len(encoding.encode("function"))
            num_tokens += len(encoding.encode("name"))
            num_tokens += len(encoding.encode(tool.name))
            num_tokens += len(encoding.encode("description"))
            num_tokens += len(encoding.encode(tool.description))
            parameters = tool.parameters
            num_tokens += len(encoding.encode("parameters"))
            if "title" in parameters:
                num_tokens += len(encoding.encode("title"))
                num_tokens += len(encoding.encode(parameters["title"]))
            num_tokens += len(encoding.encode("type"))
            num_tokens += len(encoding.encode(parameters["type"]))
            if "properties" in parameters:
                num_tokens += len(encoding.encode("properties"))
                for key, value in parameters["properties"].items():
                    num_tokens += len(encoding.encode(key))
                    for field_key, field_value in value.items():
                        num_tokens += len(encoding.encode(field_key))
                        if field_key == "enum":
                            for enum_field in field_value:
                                num_tokens += 3
                                num_tokens += len(encoding.encode(enum_field))
                        else:
                            num_tokens += len(encoding.encode(field_key))
                            num_tokens += len(encoding.encode(str(field_value)))
            if "required" in parameters:
                num_tokens += len(encoding.encode("required"))
                for required_field in parameters["required"]:
                    num_tokens += 3
                    num_tokens += len(encoding.encode(required_field))
        return num_tokens

    @staticmethod
    def _get_ai_model_entity(base_model_name: str, model: str):
        for ai_model_entity in LLM_BASE_MODELS:
            if ai_model_entity.base_model_name == base_model_name:
                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
                ai_model_entity_copy.entity.model = model
                ai_model_entity_copy.entity.label.en_US = model
                ai_model_entity_copy.entity.label.zh_Hans = model
                return ai_model_entity_copy

    def _get_base_model_name(self, credentials: dict) -> str:
        base_model_name = credentials.get("base_model_name")
        if not base_model_name:
            raise ValueError("Base Model Name is required")
        return base_model_name

    def _get_image_patches(self, n: int) -> int:
        return (n + 32 - 1) // 32

    def _num_tokens_from_images(
        self, base_model_name: str, image_details: list[dict]
    ) -> int:
        num_tokens: int = 0
        base_tokens: int = 85
        tile_tokens: int = 170

        # GPT-5 family image token counting reuses the GPT-4o logic;
        # add branches here if a model needs different constants
        if "mini" in base_model_name:
            base_tokens = 2833
            tile_tokens = 5667

        for image_detail in image_details:
            url = image_detail["url"]
            if "base64" in url and "," in url:
                try:
                    base64_str = url.split(",")[1]
                    image_data = base64.b64decode(base64_str)
                    image = Image.open(io.BytesIO(image_data))
                    width, height = image.size

                    if image_detail.get("detail") == "low":
                        num_tokens += 85
                    else:
                        if width > 2048 or height > 2048:
                            aspect_ratio = width / height
                            if aspect_ratio > 1:
                                width, height = 2048, int(2048 / aspect_ratio)
                            else:
                                width, height = int(2048 * aspect_ratio), 2048

                        if width >= height and height > 768:
                            width, height = int((768 / height) * width), 768
                        elif height > width and width > 768:
                            width, height = 768, int((768 / width) * height)

                        w_tiles = math.ceil(width / 512)
                        h_tiles = math.ceil(height / 512)
                        total_tiles = w_tiles * h_tiles

                        num_tokens += base_tokens + total_tiles * tile_tokens
                except Exception:
                    # If the image cannot be decoded, add only the minimum token cost
                    num_tokens += base_tokens
                else:
                    pass
            else:
                # Plain URLs get a rough minimum estimate
                num_tokens += base_tokens

        return num_tokens
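
    # Worked example of the tile math above for a 1600x900 data-URI image with
    # default detail: no side exceeds 2048, the shorter side is scaled to 768
    # (roughly 1365x768), so ceil(1365/512) * ceil(768/512) = 3 * 2 = 6 tiles
    # and 85 + 6 * 170 = 1105 tokens with the non-"mini" constants.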

    # --- Image compression helpers (data URIs only) ---
    def _prepare_image_url(self, image_url: str) -> str:
        """
        Re-compress data-URI images to WebP (quality=80) with a maximum side of
        1024px. On any failure, return the original string unchanged.
        """
        try:
            return self._compress_image_data_uri(image_url)
        except Exception:
            return image_url

    @staticmethod
    def _compress_image_data_uri(
        data_uri: str,
        max_side: int = 1024,
        fmt: str = "WEBP",
        quality: int = 80,
    ) -> str:
        if not (isinstance(data_uri, str) and data_uri.startswith("data:image")):
            return data_uri
        if "," not in data_uri:
            return data_uri

        header, b64data = data_uri.split(",", 1)
        img_bytes = base64.b64decode(b64data)
        image = Image.open(io.BytesIO(img_bytes))
        image = image.convert("RGB")  # normalize the mode so WebP encoding is reliable
        image.thumbnail((max_side, max_side), Image.Resampling.LANCZOS)

        buf = io.BytesIO()
        # method=6 selects the slowest, highest-effort WebP compression
        image.save(buf, format=fmt, quality=quality, method=6)
        b64_new = base64.b64encode(buf.getvalue()).decode("utf-8")
        return f"data:image/{fmt.lower()};base64,{b64_new}"