import copy
import os
import json
import logging
from urllib.parse import urljoin

import httpx
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from dotenv import load_dotenv

# ========================
# Configuration and initialization
# ========================

load_dotenv()

LOG_LEVEL = os.getenv("LOG_LEVEL", "info").upper()
logging.basicConfig(level=LOG_LEVEL)
logger = logging.getLogger(__name__)

app = FastAPI()

OPENWEBUI_URL = os.getenv("OPENWEBUI_URL", "your_endpoint_here")
API_KEY = os.getenv("OPENWEBUI_API_KEY", "your_api_key_here")
TIMEOUT = 30.0

ZED_SYSTEM_PROMPT_FILE = os.getenv("ZED_SYSTEM_PROMPT_FILE")
ZED_SYSTEM_PROMPT_MODE = os.getenv("ZED_SYSTEM_PROMPT_MODE", "default").lower()
# Environment variables are strings, so parse the flag explicitly instead of
# relying on the truthiness of a non-empty string.
EMULATE_TOOLS_CALLING = os.getenv("EMULATE_TOOLS_CALLING", "true").lower() in ("1", "true", "yes")

START_MARKER = "<|tools▁calls▁start|>"
END_MARKER = "<|tools▁calls▁end|>"
MARKER_MAX_LEN = max(len(START_MARKER), len(END_MARKER))
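
# A sample .env for local testing; the values below are illustrative placeholders,
# not defaults shipped with this proxy:
#
#   OPENWEBUI_URL=http://localhost:8080
#   OPENWEBUI_API_KEY=sk-...
#   LOG_LEVEL=info
#   ZED_SYSTEM_PROMPT_MODE=replace
#   ZED_SYSTEM_PROMPT_FILE=./zed_system_prompt.txt
#   EMULATE_TOOLS_CALLING=true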

# ========================
# Loading and processing the system prompt
# ========================

def load_system_prompt() -> str | None:
    """Return the custom system prompt from ZED_SYSTEM_PROMPT_FILE, if the file exists."""
    if ZED_SYSTEM_PROMPT_FILE and os.path.exists(ZED_SYSTEM_PROMPT_FILE):
        with open(ZED_SYSTEM_PROMPT_FILE, encoding="utf-8") as f:
            return f.read().strip()
    return None

def apply_system_prompt_policy(messages: list[dict], mode: str, custom_prompt: str | None) -> list[dict]:
    """Apply ZED_SYSTEM_PROMPT_MODE: 'disable' drops system messages, 'replace' swaps in the custom prompt."""
    if mode == "disable":
        return [m for m in messages if m.get("role") != "system"]
    if mode == "replace" and custom_prompt:
        filtered = [m for m in messages if m.get("role") != "system"]
        filtered.append({"role": "system", "content": custom_prompt})
        return filtered
    return messages

def inject_tools_as_prompt(tools: dict, messages: list[dict]) -> None:
    """Embed the tool definitions as an extra system message (after the first system message, or appended)."""
    if not tools:
        return

    tools_text = json.dumps(tools, indent=2, ensure_ascii=False)
    tools_message = {
        "role": "system",
        "content": f"Available tools:\n{tools_text}"
    }

    for i, msg in enumerate(messages):
        if msg.get("role") == "system":
            messages.insert(i + 1, tools_message)
            return
    messages.append(tools_message)
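
# Illustrative example: if the request carried a `tools` payload describing a
# `get_weather` function (OpenAI-style schema), inject_tools_as_prompt adds a
# system message whose content is "Available tools:" followed by that payload
# pretty-printed as JSON, so the upstream model sees the tools as plain text.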

# ========================
# Main endpoint
# ========================

@app.post("/v1/chat/completions")
async def openai_proxy(request: Request):
    logger.info(">>> openai_proxy called")
    body = await request.json()
    original_body = copy.deepcopy(body)

    # System prompt
    system_prompt = load_system_prompt()
    body["messages"] = apply_system_prompt_policy(body.get("messages", []), ZED_SYSTEM_PROMPT_MODE, system_prompt)

    # Fold tools into messages
    if EMULATE_TOOLS_CALLING:
        tools = body.pop("tools", None)
        if tools:
            inject_tools_as_prompt(tools, body.get("messages", []))
            logger.info("Tools embedded into messages, 'tools' key removed")

    if body != original_body:
        logger.info(f"Request body modified: {json.dumps(body, ensure_ascii=False)}")
    else:
        logger.info(f"Request body unchanged: {json.dumps(body, ensure_ascii=False)}")

    # Reuse the Authorization header from the incoming request, if present
    auth_header = request.headers.get("Authorization", f"Bearer {API_KEY}")
    headers = {
        "Authorization": auth_header,
        "Content-Type": "application/json",
        "Accept": "text/event-stream" if body.get("stream") else "application/json",
    }

    generator = func_calling_event_generator if EMULATE_TOOLS_CALLING else default_event_generator
    return StreamingResponse(generator(body, headers), media_type="text/event-stream")
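
# Example request against the proxy (illustrative values; the model name and
# port depend on your setup):
#
#   curl -N http://127.0.0.1:5000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -H "Authorization: Bearer $OPENWEBUI_API_KEY" \
#     -d '{"model": "my-model", "stream": true,
#          "messages": [{"role": "user", "content": "Hello"}]}'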

# ========================
# Proxy for all /v1/* paths
# ========================

@app.api_route("/v1/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy_all(request: Request, path: str):
    if path == "chat/completions":
        return await openai_proxy(request)

    target_url = urljoin(f"{OPENWEBUI_URL}/", path)
    try:
        request_body = None
        if request.method in ["POST", "PUT"]:
            try:
                request_body = await request.json()
            except json.JSONDecodeError:
                request_body = None

        async with httpx.AsyncClient(timeout=TIMEOUT) as client:
            response = await client.request(
                method=request.method,
                url=target_url,
                headers={
                    "Authorization": f"Bearer {API_KEY}",
                    "Content-Type": "application/json",
                },
                json=request_body,
                params=dict(request.query_params),
            )

        filtered_headers = {
            k: v for k, v in response.headers.items()
            if k.lower() not in ["content-encoding", "content-length", "transfer-encoding", "connection"]
        }

        return JSONResponse(
            content=response.json(),
            status_code=response.status_code,
            headers=filtered_headers,
        )

    except httpx.ReadTimeout:
        logger.error("Timed out while contacting Open WebUI")
        raise HTTPException(status_code=504, detail="Connection to Open WebUI timed out")
    except Exception as e:
        logger.error(f"Proxy error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

# ========================
# Event generators
# ========================

async def default_event_generator(body: dict, headers: dict):
    max_log_chunk = 200
    try:
        async with httpx.AsyncClient(timeout=TIMEOUT) as client:
            async with client.stream("POST", f"{OPENWEBUI_URL}/api/chat/completions", json=body, headers=headers) as response:
                if response.status_code != 200:
                    text = await response.aread()
                    logger.error(f"OpenWebUI error: {text.decode()}")
                    yield format_error_event(text.decode())
                    return

                async for line in response.aiter_lines():
                    if not line.strip():
                        continue
                    if line.startswith("data: "):
                        json_str = line[len("data: "):].strip()
                        try:
                            data = json.loads(json_str)
                            if "sources" in data:
                                snippet = json_str[:max_log_chunk].replace("\n", " ")
                                logger.info(f"Skipped chunk with 'sources': {snippet}...")
                                continue
                        except json.JSONDecodeError:
                            pass
                    logger.info(line)
                    # aiter_lines strips line endings; re-add the blank line so SSE event framing is preserved
                    yield f"{line}\n\n"
    except Exception as e:
        logger.error(f"Streaming error: {e}")
        yield format_error_event("Internal server error")

async def func_calling_event_generator(body: dict, headers: dict):
    text_accumulator = []

    try:
        async with httpx.AsyncClient(timeout=60) as client:
            async with client.stream("POST", f"{OPENWEBUI_URL}/api/chat/completions", json=body, headers=headers) as response:
                if response.status_code != 200:
                    text = await response.aread()
                    logger.error(f"API error: {text.decode()}")
                    yield format_error_event(text.decode())
                    return

                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue

                    data_part = line[len("data: "):].strip()
                    if not data_part:
                        continue

                    try:
                        data_json = json.loads(data_part)
                        choice = data_json.get("choices", [{}])[0]
                        delta = choice.get("delta", {})
                        content = delta.get("content", "")
                        if not content:
                            continue
                    except Exception:
                        continue

                    if "sources" in data_json:
                        # Skip Open WebUI's internal service chunks
                        continue

                    # Append the text to the accumulator
                    text_accumulator.append(content)
                    accumulated = "".join(text_accumulator)

                    start_idx = accumulated.find(START_MARKER)
                    end_idx = accumulated.find(END_MARKER)

                    if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
                        pre_text = accumulated[:start_idx].strip()
                        json_block = accumulated[start_idx + len(START_MARKER):end_idx]

                        if pre_text:
                            yield emit_with_content(data_json, pre_text)

                        first_brace = json_block.find("{")
                        last_brace = json_block.rfind("}")

                        if first_brace != -1 and last_brace != -1:
                            json_str = json_block[first_brace : last_brace + 1]
                            try:
                                parsed = json.loads(json_str)
                                functions = parsed.get("functions", [])
                                if functions:
                                    # Emit every function call as a tool_calls chunk
                                    async for tool_chunk in stream_tool_calls(data_json, functions):
                                        logger.info(tool_chunk[:-2])
                                        yield tool_chunk
                            except json.JSONDecodeError:
                                pass
                        logger.info("data: [DONE]")
                        yield "data: [DONE]\n\n"
                        return
                    else:
                        # No complete marker pair yet: emit only the part that cannot
                        # belong to a (possibly split) start marker, keep the rest buffered.
                        safe_idx = find_safe_cutoff(accumulated, START_MARKER)
                        safe_part = accumulated[:safe_idx]
                        remaining_tail = accumulated[safe_idx:]

                        if safe_part:
                            chunk = emit_with_content(data_json, safe_part)
                            logger.info(chunk[:-2])
                            yield chunk

                        text_accumulator = [remaining_tail]

        yield "data: [DONE]\n\n"

    except Exception as e:
        logger.error(f"Streaming error: {e}")
        yield format_error_event("Internal server error")

# ========================
# Streaming tool-call output
# ========================
def emit_with_content(base_json: dict, content: str) -> str:
    # Deep-copy so the caller's chunk is not mutated through the nested "choices" list
    base = copy.deepcopy(base_json)
    base["choices"][0]["delta"] = {"role": "assistant", "content": content}
    return f"data: {json.dumps(base)}\n\n"

async def stream_tool_calls(base_json: dict, functions: list):
    for i, func in enumerate(functions):
        fc = func.get("function_call", {})
        chunk = {
            "id": base_json["id"],
            "object": "chat.completion.chunk",
            "model": base_json["model"],
            "created": base_json["created"],
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "tool_calls": [
                            {
                                "id": f"call_{i}",
                                "index": i,
                                "type": "function",
                                "function": {
                                    "name": fc.get("name", ""),
                                    "arguments": fc.get("arguments", "")
                                }
                            }
                        ]
                    },
                    "finish_reason": "tool_calls" if i == len(functions) - 1 else None,
                    "native_finish_reason": "tool_calls" if i == len(functions) - 1 else None
                }
            ]
        }
        yield f"data: {json.dumps(chunk)}\n\n"
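
# Illustrative output (assumed input): for
#   functions = [{"function_call": {"name": "get_weather", "arguments": "{\"city\": \"Berlin\"}"}}]
# stream_tool_calls yields a single SSE chunk whose delta carries
#   {"tool_calls": [{"id": "call_0", "index": 0, "type": "function",
#                    "function": {"name": "get_weather", "arguments": "{\"city\": \"Berlin\"}"}}]}
# with finish_reason set to "tool_calls" on the last (here, only) element.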

# ========================
# Utilities
# ========================

def format_error_event(message: str) -> str:
    # json.dumps escapes quotes/newlines so the error payload stays valid JSON
    return f"data: {json.dumps({'error': message})}\n\n"

def find_safe_cutoff(buffer: str, marker: str) -> int:
    """
    Return the index up to which `buffer` can be safely emitted:
    - the start of `marker` if it occurs in the buffer,
    - the start of a marker prefix if the buffer ends with one,
    - otherwise the full buffer length.
    """
    pos = buffer.find(marker)
    if pos != -1:
        return pos
    for i in range(len(marker) - 1, 0, -1):
        if buffer.endswith(marker[:i]):
            return len(buffer) - i
    return len(buffer)
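
# Worked example: with START_MARKER as the marker,
#   find_safe_cutoff("Hello <|tools", START_MARKER) == 6
# because the buffer ends with a prefix of the marker, so only "Hello " is safe to emit.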

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=5000, log_level=LOG_LEVEL.lower())