# Spaces: Running
# (Page-header residue from the web scrape; not part of the module.)
import time | |
import logging | |
import sys | |
import os | |
import base64 | |
import asyncio | |
from aiocache import cached | |
from typing import Any, Optional | |
import random | |
import json | |
import html | |
import inspect | |
import re | |
import ast | |
from uuid import uuid4 | |
from concurrent.futures import ThreadPoolExecutor | |
from fastapi import Request | |
from fastapi import BackgroundTasks | |
from starlette.responses import Response, StreamingResponse | |
from open_webui.models.chats import Chats | |
from open_webui.models.users import Users | |
from open_webui.socket.main import ( | |
get_event_call, | |
get_event_emitter, | |
get_active_status_by_user_id, | |
) | |
from open_webui.routers.tasks import ( | |
generate_queries, | |
generate_title, | |
generate_image_prompt, | |
generate_chat_tags, | |
) | |
from open_webui.routers.retrieval import process_web_search, SearchForm | |
from open_webui.routers.images import image_generations, GenerateImageForm | |
from open_webui.routers.pipelines import ( | |
process_pipeline_inlet_filter, | |
process_pipeline_outlet_filter, | |
) | |
from open_webui.utils.webhook import post_webhook | |
from open_webui.models.users import UserModel | |
from open_webui.models.functions import Functions | |
from open_webui.models.models import Models | |
from open_webui.retrieval.utils import get_sources_from_files | |
from open_webui.utils.chat import generate_chat_completion | |
from open_webui.utils.task import ( | |
get_task_model_id, | |
rag_template, | |
tools_function_calling_generation_template, | |
) | |
from open_webui.utils.misc import ( | |
deep_update, | |
get_message_list, | |
add_or_update_system_message, | |
add_or_update_user_message, | |
get_last_user_message, | |
get_last_assistant_message, | |
prepend_to_first_user_message_content, | |
convert_logit_bias_input_to_json, | |
) | |
from open_webui.utils.tools import get_tools | |
from open_webui.utils.plugin import load_function_module_by_id | |
from open_webui.utils.filter import ( | |
get_sorted_filter_ids, | |
process_filter_functions, | |
) | |
from open_webui.utils.code_interpreter import execute_code_jupyter | |
from open_webui.tasks import create_task | |
from open_webui.config import ( | |
CACHE_DIR, | |
DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, | |
DEFAULT_CODE_INTERPRETER_PROMPT, | |
) | |
from open_webui.env import ( | |
SRC_LOG_LEVELS, | |
GLOBAL_LOG_LEVEL, | |
BYPASS_MODEL_ACCESS_CONTROL, | |
ENABLE_REALTIME_CHAT_SAVE, | |
) | |
from open_webui.constants import TASKS | |
# Send root-logger output to stdout at the globally configured level.
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)

# Module-level logger, filtered by the level configured under the "MAIN" key.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
async def chat_completion_tools_handler(
    request: Request, body: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
    """Ask the task model which tool(s) to call, run them, and collect the output.

    A function-calling prompt built from the specs in ``tools`` is sent to the
    task model. The JSON object found in its reply is interpreted either as a
    single call (``{"name": ..., "parameters": ...}``) or as a batch under a
    ``"tool_calls"`` key. Every string result is appended to ``sources``.

    Args:
        request: Incoming request; ``request.app.state.config`` supplies the
            task-model and prompt-template settings.
        body: Chat payload. Reads ``body["model"]`` and ``body["messages"]``;
            ``body["metadata"]["files"]`` is deleted when an executed tool is
            flagged as a file handler.
        user: User on whose behalf the completion is generated.
        models: Model registry passed through to ``get_task_model_id``.
        tools: Mapping of tool name to a dict with at least ``spec``,
            ``callable``, ``citation``, ``toolkit_id`` and ``file_handler``.

    Returns:
        ``(body, {})`` when the task model returns no content, otherwise
        ``(body, {"sources": sources})``. Errors raised while generating or
        parsing the reply are logged, not propagated.
    """

    async def get_content_from_response(response) -> Optional[str]:
        """Pull the assistant message text out of a streaming or dict response."""
        if not hasattr(response, "body_iterator"):
            return response["choices"][0]["message"]["content"]

        content = None
        async for chunk in response.body_iterator:
            data = json.loads(chunk.decode("utf-8"))
            content = data["choices"][0]["message"]["content"]

        # Cleanup any remaining background tasks if necessary
        if response.background is not None:
            await response.background()
        return content

    def get_tools_function_calling_payload(messages, task_model_id, content):
        """Build the non-streaming request sent to the task model."""
        user_message = get_last_user_message(messages)
        # At most the four most recent messages, newest first.
        history = "\n".join(
            f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
            for message in messages[::-1][:4]
        )
        prompt = f"History:\n{history}\nQuery: {user_message}"

        return {
            "model": task_model_id,
            "messages": [
                {"role": "system", "content": content},
                {"role": "user", "content": f"Query: {prompt}"},
            ],
            "stream": False,
            "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
        }

    skip_files = False
    sources = []

    async def tool_call_handler(tool_call):
        """Run one tool call and record its string output in ``sources``."""
        nonlocal skip_files

        log.debug(f"{tool_call=}")

        tool_function_name = tool_call.get("name", None)
        if tool_function_name not in tools:
            # Unknown (or missing) tool name: nothing to execute.
            return

        tool = tools[tool_function_name]
        tool_function_params = tool_call.get("parameters", {})

        try:
            parameters_spec = tool.get("spec", {}).get("parameters", {})
            # Accept every parameter the spec declares. Filtering on the
            # "required" list alone silently dropped optional arguments that
            # the model had supplied.
            allowed_params = set(parameters_spec.get("properties", {})) | set(
                parameters_spec.get("required", [])
            )
            tool_function = tool["callable"]
            tool_function_params = {
                k: v for k, v in tool_function_params.items() if k in allowed_params
            }
            tool_output = await tool_function(**tool_function_params)
        except Exception as e:
            # The error text is handed to the model as the tool's output.
            tool_output = str(e)

        if isinstance(tool_output, str):
            source_label = f"TOOL:{tool['toolkit_id']}/{tool_function_name}"
            sources.append(
                {
                    # Only tools with citation enabled get a named source.
                    "source": {"name": source_label} if tool["citation"] else {},
                    "document": [tool_output],
                    "metadata": [{"source": source_label}],
                }
            )

            if tool["file_handler"]:
                skip_files = True

    task_model_id = get_task_model_id(
        body["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    specs = [tool["spec"] for tool in tools.values()]
    tools_specs = json.dumps(specs)

    if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

    tools_function_calling_prompt = tools_function_calling_generation_template(
        template, tools_specs
    )
    log.info(f"{tools_function_calling_prompt=}")
    payload = get_tools_function_calling_payload(
        body["messages"], task_model_id, tools_function_calling_prompt
    )

    try:
        response = await generate_chat_completion(request, form_data=payload, user=user)
        log.debug(f"{response=}")
        content = await get_content_from_response(response)
        log.debug(f"{content=}")

        if not content:
            return body, {}

        try:
            # Keep only the outermost {...} span of the reply.
            content = content[content.find("{") : content.rfind("}") + 1]
            if not content:
                raise Exception("No JSON object found in the response")

            result = json.loads(content)

            # check if "tool_calls" in result
            if result.get("tool_calls"):
                for tool_call in result.get("tool_calls"):
                    await tool_call_handler(tool_call)
            else:
                await tool_call_handler(result)
        except Exception as e:
            log.exception(f"Error: {e}")
    except Exception as e:
        log.exception(f"Error: {e}")

    log.debug(f"tool_contexts: {sources}")

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {"sources": sources}
async def chat_web_search_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Generate search queries, run a web search for each, and attach the results.

    Progress is reported through ``extra_params["__event_emitter__"]`` as
    ``"status"`` events whose ``action`` is ``"web_search"``. Each search that
    returns a collection name or documents adds one entry of type
    ``"web_search"`` to ``form_data["files"]``.

    Args:
        request: Incoming request, forwarded to the query generator and search.
        form_data: Chat payload; reads ``"model"`` and ``"messages"`` and
            updates ``"files"`` in place.
        extra_params: Must contain the ``"__event_emitter__"`` coroutine.
        user: User on whose behalf the search runs.

    Returns:
        The (mutated) ``form_data``.
    """
    event_emitter = extra_params["__event_emitter__"]

    async def emit_status(**data):
        """Emit one web-search status event carrying ``data``."""
        await event_emitter(
            {"type": "status", "data": {"action": "web_search", **data}}
        )

    await emit_status(description="Generating search query", done=False)

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    queries = []
    try:
        res = await generate_queries(
            request,
            {
                "model": form_data["model"],
                "messages": messages,
                "prompt": user_message,
                "type": "web_search",
            },
            user,
        )
        response = res["choices"][0]["message"]["content"]

        try:
            bracket_start = response.find("{")
            bracket_end = response.rfind("}") + 1
            # rfind() returns -1 when "}" is absent, so after the +1 the
            # "not found" sentinel is 0, not -1.
            if bracket_start == -1 or bracket_end == 0:
                raise Exception("No JSON object found in the response")

            # Parse the slice without overwriting `response`, so the fallback
            # below still sees the full model output.
            parsed = json.loads(response[bracket_start:bracket_end])
            queries = parsed.get("queries", []) or []
        except Exception:
            # No usable JSON: treat the raw model output as the single query.
            queries = [response]
    except Exception as e:
        log.exception(e)
        queries = [user_message]

    if not queries:
        await emit_status(description="No search query generated", done=True)
        return form_data

    all_results = []

    for searchQuery in queries:
        await emit_status(
            description='Searching "{{searchQuery}}"',
            query=searchQuery,
            done=False,
        )

        try:
            results = await process_web_search(
                request,
                SearchForm(query=searchQuery),
                user=user,
            )

            if results:
                all_results.append(results)
                # Tolerate an explicit None stored under "files".
                files = form_data.get("files") or []

                if results.get("collection_name"):
                    files.append(
                        {
                            "collection_name": results["collection_name"],
                            "name": searchQuery,
                            "type": "web_search",
                            "urls": results["filenames"],
                        }
                    )
                elif results.get("docs"):
                    files.append(
                        {
                            "docs": results.get("docs", []),
                            "name": searchQuery,
                            "type": "web_search",
                            "urls": results["filenames"],
                        }
                    )

                form_data["files"] = files
        except Exception as e:
            # One failed query must not abort the remaining ones.
            log.exception(e)
            await emit_status(
                description='Error searching "{{searchQuery}}"',
                query=searchQuery,
                done=True,
                error=True,
            )

    if all_results:
        urls = []
        for results in all_results:
            if "filenames" in results:
                urls.extend(results["filenames"])

        await emit_status(description="Searched {{count}} sites", urls=urls, done=True)
    else:
        await emit_status(
            description="No search results found", done=True, error=True
        )

    return form_data
async def chat_image_generation_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Generate an image for the last user message and tell the model the outcome.

    When ``ENABLE_IMAGE_PROMPT_GENERATION`` is set, an image prompt is first
    requested from ``generate_image_prompt``; any failure there falls back to
    the last user message. Status and message events are sent through
    ``extra_params["__event_emitter__"]``, and a ``<context>`` system message
    describing success or failure is added to ``form_data["messages"]``.

    Args:
        request: Incoming request; ``request.app.state.config`` is consulted.
        form_data: Chat payload; reads ``"model"`` and ``"messages"``.
        extra_params: Must contain the ``"__event_emitter__"`` coroutine.
        user: User on whose behalf the image is generated.

    Returns:
        The (mutated) ``form_data``.
    """
    __event_emitter__ = extra_params["__event_emitter__"]
    await __event_emitter__(
        {
            "type": "status",
            "data": {"description": "Generating an image", "done": False},
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    prompt = user_message

    if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
        try:
            res = await generate_image_prompt(
                request,
                {
                    "model": form_data["model"],
                    "messages": messages,
                },
                user,
            )
            response = res["choices"][0]["message"]["content"]

            try:
                bracket_start = response.find("{")
                bracket_end = response.rfind("}") + 1
                # rfind() returns -1 when "}" is absent, so after the +1 the
                # "not found" sentinel is 0, not -1.
                if bracket_start == -1 or bracket_end == 0:
                    raise Exception("No JSON object found in the response")

                parsed = json.loads(response[bracket_start:bracket_end])
                # A missing or empty "prompt" falls back to the user's message
                # rather than handing a list to the image form.
                prompt = parsed.get("prompt") or user_message
            except Exception:
                prompt = user_message
        except Exception as e:
            log.exception(e)
            prompt = user_message

    system_message_content = ""

    try:
        images = await image_generations(
            request=request,
            form_data=GenerateImageForm(prompt=prompt),
            user=user,
        )

        await __event_emitter__(
            {
                "type": "status",
                "data": {"description": "Generated an image", "done": True},
            }
        )

        for image in images:
            # Emit the image as Markdown so the client renders it.
            # NOTE(review): assumes each entry is a dict with a "url" key —
            # confirm against image_generations.
            await __event_emitter__(
                {
                    "type": "message",
                    "data": {"content": f"![Generated Image]({image['url']})\n"},
                }
            )

        system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
    except Exception as e:
        log.exception(e)
        await __event_emitter__(
            {
                "type": "status",
                "data": {
                    "description": "An error occurred while generating an image",
                    "done": True,
                },
            }
        )

        system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"

    if system_message_content:
        form_data["messages"] = add_or_update_system_message(
            system_message_content, form_data["messages"]
        )

    return form_data
async def chat_completion_files_handler(
    request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
    """Retrieve sources for the files attached under ``body["metadata"]["files"]``.

    Retrieval queries are requested from ``generate_queries``; when none can be
    obtained, the last user message is used as the only query. The lookup
    itself (``get_sources_from_files``) runs in a worker thread.

    Args:
        request: Incoming request; ``request.app.state`` supplies the embedding
            function, reranker and retrieval settings.
        body: Chat payload; reads ``"model"``, ``"messages"`` and
            ``"metadata"``.
        user: User passed to the query generator and the embedding function.

    Returns:
        ``(body, {"sources": sources})``; ``sources`` is empty when there are
        no files or when retrieval fails (the failure is logged).
    """
    sources = []

    if files := body.get("metadata", {}).get("files", None):
        queries = []
        try:
            queries_response = await generate_queries(
                request,
                {
                    "model": body["model"],
                    "messages": body["messages"],
                    "type": "retrieval",
                },
                user,
            )
            raw_response = queries_response["choices"][0]["message"]["content"]

            try:
                bracket_start = raw_response.find("{")
                bracket_end = raw_response.rfind("}") + 1
                # rfind() returns -1 when "}" is absent, so after the +1 the
                # "not found" sentinel is 0, not -1.
                if bracket_start == -1 or bracket_end == 0:
                    raise Exception("No JSON object found in the response")

                parsed = json.loads(raw_response[bracket_start:bracket_end])
            except Exception:
                # No usable JSON: use the full raw reply as the single query.
                parsed = {"queries": [raw_response]}

            queries = parsed.get("queries", []) or []
        except Exception as e:
            # Best effort: fall through to the last-user-message query below.
            # `except Exception` (not a bare `except`) lets task cancellation
            # propagate instead of being swallowed.
            log.debug(f"Error generating retrieval queries: {e}")

        if not queries:
            queries = [get_last_user_message(body["messages"])]

        try:
            # Offload get_sources_from_files to a separate thread
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as executor:
                sources = await loop.run_in_executor(
                    executor,
                    lambda: get_sources_from_files(
                        request=request,
                        files=files,
                        queries=queries,
                        embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
                            query, user=user
                        ),
                        k=request.app.state.config.TOP_K,
                        reranking_function=request.app.state.rf,
                        r=request.app.state.config.RELEVANCE_THRESHOLD,
                        hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                        full_context=request.app.state.config.RAG_FULL_CONTEXT,
                    ),
                )
        except Exception as e:
            log.exception(e)

        log.debug(f"rag_contexts:sources: {sources}")

    return body, {"sources": sources}
def apply_params_to_form_data(form_data, model):
    """Move the entries of ``form_data["params"]`` onto backend-specific fields.

    ``"params"`` is always removed from ``form_data``. If ``model`` has a
    truthy ``"ollama"`` entry, the whole params dict becomes
    ``form_data["options"]`` and ``format`` / ``keep_alive`` are additionally
    copied to the top level. Otherwise only a fixed set of keys is copied to
    the top level, and ``logit_bias`` is converted and JSON-decoded first.

    Args:
        form_data: Chat payload; mutated in place.
        model: Model description dict; only its ``"ollama"`` key is read.

    Returns:
        The same ``form_data`` object.
    """
    params = form_data.pop("params", None)
    if params is None:
        # Covers both a missing key and an explicit null from the client.
        params = {}

    if model.get("ollama"):
        form_data["options"] = params
        for key in ("format", "keep_alive"):
            if key in params:
                form_data[key] = params[key]
        return form_data

    passthrough_keys = (
        "seed",
        "stop",
        "temperature",
        "max_tokens",
        "top_p",
        "frequency_penalty",
        "reasoning_effort",
    )
    for key in passthrough_keys:
        if key in params:
            form_data[key] = params[key]

    if "logit_bias" in params:
        try:
            form_data["logit_bias"] = json.loads(
                convert_logit_bias_input_to_json(params["logit_bias"])
            )
        except Exception as e:
            # A malformed logit_bias is dropped rather than failing the request.
            log.error(f"Error parsing logit_bias: {e}")

    return form_data
async def process_chat_payload(request, form_data, user, metadata, model):
    """Prepare a chat-completion payload before it is sent to the model.

    In order: applies ``params``, merges the model's knowledge entries into
    ``files``, runs the pipeline inlet filter and the inlet filter functions,
    handles the ``web_search`` / ``image_generation`` / ``code_interpreter``
    features, resolves tools (native specs or prompt-based tool calling),
    retrieves file sources, and injects the collected source context into the
    messages through the RAG template.

    Args:
        request: Incoming request; ``request.app.state`` and ``request.state``
            are read for configuration and the model registry.
        form_data: Chat payload; mutated and/or replaced along the way.
        user: Requesting user (``id``, ``email``, ``name``, ``role`` are read).
        metadata: Request metadata; a copy extended with ``tool_ids`` and
            ``files`` is stored under ``form_data["metadata"]`` and returned.
        model: Model description dict.

    Returns:
        ``(form_data, metadata, events)`` where ``events`` holds at most one
        ``{"sources": [...]}`` entry listing the sources that carry a name.

    Raises:
        Exception: re-raised from the pipeline inlet filter; wrapped as
            ``"Error: ..."`` when a filter function fails; or
            ``"No user message found"`` when sources exist but no user message
            does.
    """
    form_data = apply_params_to_form_data(form_data, model)
    log.debug(f"form_data: {form_data}")

    event_emitter = get_event_emitter(metadata)
    event_call = get_event_call(metadata)

    extra_params = {
        "__event_emitter__": event_emitter,
        "__event_call__": event_call,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
        "__model__": model,
    }

    # A "direct" request carries its own model; otherwise use the app registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {request.state.model["id"]: request.state.model}
    else:
        models = request.app.state.MODELS

    task_model_id = get_task_model_id(
        form_data["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    # Additional events to be sent to the client, and collected RAG sources.
    events = []
    sources = []

    user_message = get_last_user_message(form_data["messages"])
    model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": False,
                },
            }
        )

        knowledge_files = []
        for item in model_knowledge:
            if item.get("collection_name"):
                knowledge_files.append(
                    {
                        "id": item.get("collection_name"),
                        "name": item.get("name"),
                        "legacy": True,
                    }
                )
            elif item.get("collection_names"):
                knowledge_files.append(
                    {
                        "name": item.get("name"),
                        "type": "collection",
                        "collection_names": item.get("collection_names"),
                        "legacy": True,
                    }
                )
            else:
                knowledge_files.append(item)

        # Tolerate an explicit None stored under "files".
        files = form_data.get("files") or []
        files.extend(knowledge_files)
        form_data["files"] = files

    # "variables" is stripped from the payload; its value is not used here.
    form_data.pop("variables", None)

    # Process the form_data through the pipeline; errors propagate unchanged.
    form_data = await process_pipeline_inlet_filter(request, form_data, user, models)

    try:
        form_data, _ = await process_filter_functions(
            request=request,
            filter_ids=get_sorted_filter_ids(model),
            filter_type="inlet",
            form_data=form_data,
            extra_params=extra_params,
        )
    except Exception as e:
        # Same type and message as before; the cause is now chained.
        raise Exception(f"Error: {e}") from e

    features = form_data.pop("features", None)
    if features:
        if features.get("web_search"):
            form_data = await chat_web_search_handler(
                request, form_data, extra_params, user
            )

        if features.get("image_generation"):
            form_data = await chat_image_generation_handler(
                request, form_data, extra_params, user
            )

        if features.get("code_interpreter"):
            prompt_template = request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE
            form_data["messages"] = add_or_update_user_message(
                (
                    prompt_template
                    if prompt_template != ""
                    else DEFAULT_CODE_INTERPRETER_PROMPT
                ),
                form_data["messages"],
            )

    tool_ids = form_data.pop("tool_ids", None)
    files = form_data.pop("files", None)

    # Remove files duplicates (keyed on their canonical JSON form)
    if files:
        files = list({json.dumps(f, sort_keys=True): f for f in files}.values())

    metadata = {
        **metadata,
        "tool_ids": tool_ids,
        "files": files,
    }
    form_data["metadata"] = metadata

    log.debug(f"{tool_ids=}")

    if tool_ids:
        # If tool_ids field is present, then get the tools
        tools = get_tools(
            request,
            tool_ids,
            user,
            {
                **extra_params,
                "__model__": models[task_model_id],
                "__messages__": form_data["messages"],
                # "files" is always present in metadata and may hold None, so
                # a .get() default would never apply.
                "__files__": metadata.get("files") or [],
            },
        )
        log.info(f"{tools=}")

        if metadata.get("function_calling") == "native":
            # Native function calling: hand the tool specs to the model itself.
            metadata["tools"] = tools
            form_data["tools"] = [
                {"type": "function", "function": tool.get("spec", {})}
                for tool in tools.values()
            ]
        else:
            # Otherwise run the prompt-based tools function calling handler.
            try:
                form_data, flags = await chat_completion_tools_handler(
                    request, form_data, user, models, tools
                )
                sources.extend(flags.get("sources", []))
            except Exception as e:
                log.exception(e)

    try:
        form_data, flags = await chat_completion_files_handler(request, form_data, user)
        sources.extend(flags.get("sources", []))
    except Exception as e:
        log.exception(e)

    # If context is not empty, insert it into the messages
    if sources:
        context_string = "\n".join(
            f"<source><source_id>{source_idx}</source_id><source_context>{doc_context}</source_context></source>"
            for source_idx, source in enumerate(sources)
            if "document" in source
            for doc_context in source["document"]
        ).strip()

        prompt = get_last_user_message(form_data["messages"])
        if prompt is None:
            raise Exception("No user message found")

        if request.app.state.config.RELEVANCE_THRESHOLD == 0 and context_string == "":
            log.debug(
                "With a 0 relevancy threshold for RAG, the context cannot be empty"
            )

        rag_content = rag_template(
            request.app.state.config.RAG_TEMPLATE, context_string, prompt
        )

        # Workaround for Ollama 2.0+ system prompt issue
        # TODO: replace with add_or_update_system_message
        if model.get("owned_by") == "ollama":
            form_data["messages"] = prepend_to_first_user_message_content(
                rag_content,
                form_data["messages"],
            )
        else:
            form_data["messages"] = add_or_update_system_message(
                rag_content,
                form_data["messages"],
            )

    # Only sources that carry a name are reported to the client as citations.
    sources = [source for source in sources if source.get("source", {}).get("name", "")]

    if sources:
        events.append({"sources": sources})

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": True,
                    "hidden": True,
                },
            }
        )

    return form_data, metadata, events
async def process_chat_response( | |
request, response, form_data, user, metadata, model, events, tasks | |
): | |
async def background_tasks_handler(): | |
message_map = Chats.get_messages_by_chat_id(metadata["chat_id"]) | |
message = message_map.get(metadata["message_id"]) if message_map else None | |
if message: | |
messages = get_message_list(message_map, message.get("id")) | |
if tasks and messages: | |
if TASKS.TITLE_GENERATION in tasks: | |
if tasks[TASKS.TITLE_GENERATION]: | |
res = await generate_title( | |
request, | |
{ | |
"model": message["model"], | |
"messages": messages, | |
"chat_id": metadata["chat_id"], | |
}, | |
user, | |
) | |
if res and isinstance(res, dict): | |
if len(res.get("choices", [])) == 1: | |
title_string = ( | |
res.get("choices", [])[0] | |
.get("message", {}) | |
.get("content", message.get("content", "New Chat")) | |
) | |
else: | |
title_string = "" | |
title_string = title_string[ | |
title_string.find("{") : title_string.rfind("}") + 1 | |
] | |
try: | |
title = json.loads(title_string).get( | |
"title", "New Chat" | |
) | |
except Exception as e: | |
title = "" | |
if not title: | |
title = messages[0].get("content", "New Chat") | |
Chats.update_chat_title_by_id(metadata["chat_id"], title) | |
await event_emitter( | |
{ | |
"type": "chat:title", | |
"data": title, | |
} | |
) | |
elif len(messages) == 2: | |
title = messages[0].get("content", "New Chat") | |
Chats.update_chat_title_by_id(metadata["chat_id"], title) | |
await event_emitter( | |
{ | |
"type": "chat:title", | |
"data": message.get("content", "New Chat"), | |
} | |
) | |
if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]: | |
res = await generate_chat_tags( | |
request, | |
{ | |
"model": message["model"], | |
"messages": messages, | |
"chat_id": metadata["chat_id"], | |
}, | |
user, | |
) | |
if res and isinstance(res, dict): | |
if len(res.get("choices", [])) == 1: | |
tags_string = ( | |
res.get("choices", [])[0] | |
.get("message", {}) | |
.get("content", "") | |
) | |
else: | |
tags_string = "" | |
tags_string = tags_string[ | |
tags_string.find("{") : tags_string.rfind("}") + 1 | |
] | |
try: | |
tags = json.loads(tags_string).get("tags", []) | |
Chats.update_chat_tags_by_id( | |
metadata["chat_id"], tags, user | |
) | |
await event_emitter( | |
{ | |
"type": "chat:tags", | |
"data": tags, | |
} | |
) | |
except Exception as e: | |
pass | |
event_emitter = None | |
event_caller = None | |
if ( | |
"session_id" in metadata | |
and metadata["session_id"] | |
and "chat_id" in metadata | |
and metadata["chat_id"] | |
and "message_id" in metadata | |
and metadata["message_id"] | |
): | |
event_emitter = get_event_emitter(metadata) | |
event_caller = get_event_call(metadata) | |
# Non-streaming response | |
if not isinstance(response, StreamingResponse): | |
if event_emitter: | |
if "selected_model_id" in response: | |
Chats.upsert_message_to_chat_by_id_and_message_id( | |
metadata["chat_id"], | |
metadata["message_id"], | |
{ | |
"selectedModelId": response["selected_model_id"], | |
}, | |
) | |
if response.get("choices", [])[0].get("message", {}).get("content"): | |
content = response["choices"][0]["message"]["content"] | |
if content: | |
await event_emitter( | |
{ | |
"type": "chat:completion", | |
"data": response, | |
} | |
) | |
title = Chats.get_chat_title_by_id(metadata["chat_id"]) | |
await event_emitter( | |
{ | |
"type": "chat:completion", | |
"data": { | |
"done": True, | |
"content": content, | |
"title": title, | |
}, | |
} | |
) | |
# Save message in the database | |
Chats.upsert_message_to_chat_by_id_and_message_id( | |
metadata["chat_id"], | |
metadata["message_id"], | |
{ | |
"content": content, | |
}, | |
) | |
# Send a webhook notification if the user is not active | |
if get_active_status_by_user_id(user.id) is None: | |
webhook_url = Users.get_user_webhook_url_by_id(user.id) | |
if webhook_url: | |
post_webhook( | |
request.app.state.WEBUI_NAME, | |
webhook_url, | |
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", | |
{ | |
"action": "chat", | |
"message": content, | |
"title": title, | |
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", | |
}, | |
) | |
await background_tasks_handler() | |
return response | |
else: | |
return response | |
# Non standard response | |
if not any( | |
content_type in response.headers["Content-Type"] | |
for content_type in ["text/event-stream", "application/x-ndjson"] | |
): | |
return response | |
extra_params = { | |
"__event_emitter__": event_emitter, | |
"__event_call__": event_caller, | |
"__user__": { | |
"id": user.id, | |
"email": user.email, | |
"name": user.name, | |
"role": user.role, | |
}, | |
"__metadata__": metadata, | |
"__request__": request, | |
"__model__": model, | |
} | |
filter_ids = get_sorted_filter_ids(model) | |
print(f"{filter_ids=}") | |
# Streaming response | |
if event_emitter and event_caller: | |
task_id = str(uuid4()) # Create a unique task ID. | |
model_id = form_data.get("model", "") | |
Chats.upsert_message_to_chat_by_id_and_message_id( | |
metadata["chat_id"], | |
metadata["message_id"], | |
{ | |
"model": model_id, | |
}, | |
) | |
def split_content_and_whitespace(content): | |
content_stripped = content.rstrip() | |
original_whitespace = ( | |
content[len(content_stripped) :] | |
if len(content) > len(content_stripped) | |
else "" | |
) | |
return content_stripped, original_whitespace | |
def is_opening_code_block(content): | |
backtick_segments = content.split("```") | |
# Even number of segments means the last backticks are opening a new block | |
return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0 | |
# Handle as a background task | |
async def post_response_handler(response, events): | |
def serialize_content_blocks(content_blocks, raw=False): | |
content = "" | |
for block in content_blocks: | |
if block["type"] == "text": | |
content = f"{content}{block['content'].strip()}\n" | |
elif block["type"] == "tool_calls": | |
attributes = block.get("attributes", {}) | |
block_content = block.get("content", []) | |
results = block.get("results", []) | |
if results: | |
result_display_content = "" | |
for result in results: | |
tool_call_id = result.get("tool_call_id", "") | |
tool_name = "" | |
for tool_call in block_content: | |
if tool_call.get("id", "") == tool_call_id: | |
tool_name = tool_call.get("function", {}).get( | |
"name", "" | |
) | |
break | |
result_display_content = f"{result_display_content}\n> {tool_name}: {result.get('content', '')}" | |
if not raw: | |
content = f'{content}\n<details type="tool_calls" done="true" content="{html.escape(json.dumps(block_content))}" results="{html.escape(json.dumps(results))}">\n<summary>Tool Executed</summary>\n{result_display_content}\n</details>\n' | |
else: | |
tool_calls_display_content = "" | |
for tool_call in block_content: | |
tool_calls_display_content = f"{tool_calls_display_content}\n> Executing {tool_call.get('function', {}).get('name', '')}" | |
if not raw: | |
content = f'{content}\n<details type="tool_calls" done="false" content="{html.escape(json.dumps(block_content))}">\n<summary>Tool Executing...</summary>\n{tool_calls_display_content}\n</details>\n' | |
elif block["type"] == "reasoning": | |
reasoning_display_content = "\n".join( | |
(f"> {line}" if not line.startswith(">") else line) | |
for line in block["content"].splitlines() | |
) | |
reasoning_duration = block.get("duration", None) | |
if reasoning_duration is not None: | |
if raw: | |
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n' | |
else: | |
content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n' | |
else: | |
if raw: | |
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n' | |
else: | |
content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n' | |
elif block["type"] == "code_interpreter": | |
attributes = block.get("attributes", {}) | |
output = block.get("output", None) | |
lang = attributes.get("lang", "") | |
content_stripped, original_whitespace = ( | |
split_content_and_whitespace(content) | |
) | |
if is_opening_code_block(content_stripped): | |
# Remove trailing backticks that would open a new block | |
content = ( | |
content_stripped.rstrip("`").rstrip() | |
+ original_whitespace | |
) | |
else: | |
# Keep content as is - either closing backticks or no backticks | |
content = content_stripped + original_whitespace | |
if output: | |
output = html.escape(json.dumps(output)) | |
if raw: | |
content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n' | |
else: | |
content = f'{content}\n<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n' | |
else: | |
if raw: | |
content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n' | |
else: | |
content = f'{content}\n<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n' | |
else: | |
block_content = str(block["content"]).strip() | |
content = f"{content}{block['type']}: {block_content}\n" | |
return content.strip() | |
def convert_content_blocks_to_messages(content_blocks):
    """Turn a list of content blocks into chat-completion style messages.

    Blocks are accumulated until a ``tool_calls`` block is met. At that
    point one assistant message is emitted whose ``content`` is the
    serialization of the accumulated blocks and whose ``tool_calls`` is the
    block's ``content``, followed by one ``tool`` message per entry in the
    block's ``results``. Blocks remaining after the last ``tool_calls``
    block are emitted as a final assistant message, but only when their
    serialization is non-empty.

    Args:
        content_blocks: Iterable of block dicts; each must have a ``type``
            key. ``tool_calls`` blocks may carry ``content`` and ``results``.

    Returns:
        A list of message dicts with ``role`` set to ``assistant`` or ``tool``.
    """
    messages = []
    pending_blocks = []

    for block in content_blocks:
        if block["type"] != "tool_calls":
            pending_blocks.append(block)
            continue

        # Everything gathered so far becomes the assistant turn that issued
        # the tool calls.
        messages.append(
            {
                "role": "assistant",
                "content": serialize_content_blocks(pending_blocks),
                "tool_calls": block.get("content"),
            }
        )

        # `or []` also covers a block whose "results" key is present but None.
        for result in block.get("results") or []:
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": result["tool_call_id"],
                    "content": result["content"],
                }
            )

        pending_blocks = []

    if pending_blocks:
        content = serialize_content_blocks(pending_blocks)
        if content:
            messages.append(
                {
                    "role": "assistant",
                    "content": content,
                }
            )

    return messages
def tag_content_handler(content_type, tags, content, content_blocks):
    """Detect opening/closing tags in streamed text and update the blocks.

    The behaviour depends on the type of the last content block:

    * ``text``: look for the first ``<start_tag ...>`` from ``tags`` in
      ``content``. When found, the text block is cut at the tag (and dropped
      if nothing remains) and a new block of ``content_type`` is appended,
      holding whatever followed the tag.
    * ``content_type``: look for ``<end_tag>`` in ``content``. When found,
      the block is finalized (``ended_at`` / ``duration`` are set) or removed
      if it is empty, a fresh text block holding any trailing text is
      appended, and the whole tagged span is removed from ``content``. For
      ``code_interpreter`` blocks with content no text block is appended.

    Args:
        content_type: Block type to create / close (e.g. ``"reasoning"``).
        tags: Iterable of ``(start_tag, end_tag)`` name pairs without the
            angle brackets, e.g. ``("think", "/think")``.
        content: Accumulated response text.
        content_blocks: List of block dicts; mutated in place.

    Returns:
        ``(content, content_blocks, end_flag)`` where ``end_flag`` is True
        only when an end tag was found during this call.
    """
    end_flag = False

    def extract_attributes(tag_content):
        """Extract key="value" attributes from the inside of a start tag."""
        if not tag_content:  # Covers both None and ""
            return {}
        # Only double-quoted, non-empty values are recognised.
        return dict(re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content))

    # Nothing to inspect; avoids an IndexError on content_blocks[-1].
    if not content_blocks:
        return content, content_blocks, end_flag

    current_block = content_blocks[-1]

    if current_block["type"] == "text":
        for start_tag, end_tag in tags:
            # Match start tag e.g., <tag> or <tag attr="value">
            start_tag_pattern = rf"<{re.escape(start_tag)}(\s.*?)?>"
            match = re.search(start_tag_pattern, content)
            if not match:
                continue

            block_text = current_block["content"]
            own_match = re.search(start_tag_pattern, block_text)

            if own_match:
                # Split the block's OWN text at the tag. `content` also holds
                # the text of earlier blocks, so using it for the "before"
                # part would duplicate that earlier text into this block.
                attr_source = own_match.group(1)
                after_tag = block_text[own_match.end() :]
                current_block["content"] = block_text[: own_match.start()]
            else:
                # The tag is only visible in the accumulated content; keep
                # the previous behaviour for this case.
                attr_source = match.group(1)
                after_tag = content[match.end() :]
                before_tag = content[: match.start()]
                if before_tag:
                    current_block["content"] = before_tag

            attributes = extract_attributes(attr_source or "")

            # Drop the text block if the tag was all it contained
            if not current_block["content"]:
                content_blocks.pop()

            # Append the new block, seeded with whatever followed the tag
            content_blocks.append(
                {
                    "type": content_type,
                    "start_tag": start_tag,
                    "end_tag": end_tag,
                    "attributes": attributes,
                    "content": after_tag,
                    "started_at": time.time(),
                }
            )
            break

    elif current_block["type"] == content_type:
        start_tag = current_block["start_tag"]
        end_tag = current_block["end_tag"]

        # Match end tag e.g., </tag>
        end_tag_pattern = rf"<{re.escape(end_tag)}>"

        if re.search(end_tag_pattern, content):
            end_flag = True

            # Strip any start tag from the block text, then split once at
            # the end tag into "inside the tag" and "after the tag".
            start_tag_pattern = rf"<{re.escape(start_tag)}(.*?)>"
            block_text = re.sub(
                start_tag_pattern, "", current_block["content"]
            ).strip()
            parts = re.split(end_tag_pattern, block_text, maxsplit=1)

            inner_content = parts[0].strip()
            leftover_content = parts[1].strip() if len(parts) > 1 else ""

            if inner_content:
                current_block["content"] = inner_content
                current_block["ended_at"] = time.time()
                current_block["duration"] = int(
                    current_block["ended_at"] - current_block["started_at"]
                )
                # A finished code_interpreter block must stay last so the
                # caller can detect it; no text block is appended for it.
                append_text_block = content_type != "code_interpreter"
            else:
                # Remove the block if content is empty
                content_blocks.pop()
                append_text_block = True

            if append_text_block:
                content_blocks.append(
                    {
                        "type": "text",
                        "content": leftover_content,
                    }
                )

            # Clean processed content. With DOTALL, `.` already matches
            # newlines, so `.*?` is equivalent to the ambiguous `(.|\n)*?`
            # without its exponential backtracking on a failed match.
            content = re.sub(
                rf"<{re.escape(start_tag)}(.*?)>.*?<{re.escape(end_tag)}>",
                "",
                content,
                flags=re.DOTALL,
            )

    return content, content_blocks, end_flag
message = Chats.get_message_by_id_and_message_id( | |
metadata["chat_id"], metadata["message_id"] | |
) | |
tool_calls = [] | |
last_assistant_message = None | |
try: | |
if form_data["messages"][-1]["role"] == "assistant": | |
last_assistant_message = get_last_assistant_message( | |
form_data["messages"] | |
) | |
except Exception as e: | |
pass | |
content = ( | |
message.get("content", "") | |
if message | |
else last_assistant_message if last_assistant_message else "" | |
) | |
content_blocks = [ | |
{ | |
"type": "text", | |
"content": content, | |
} | |
] | |
# We might want to disable this by default | |
DETECT_REASONING = True | |
DETECT_SOLUTION = True | |
DETECT_CODE_INTERPRETER = metadata.get("features", {}).get( | |
"code_interpreter", False | |
) | |
reasoning_tags = [ | |
("think", "/think"), | |
("thinking", "/thinking"), | |
("reason", "/reason"), | |
("reasoning", "/reasoning"), | |
("thought", "/thought"), | |
("Thought", "/Thought"), | |
("|begin_of_thought|", "|end_of_thought|"), | |
] | |
code_interpreter_tags = [("code_interpreter", "/code_interpreter")] | |
solution_tags = [("|begin_of_solution|", "|end_of_solution|")] | |
try: | |
for event in events: | |
await event_emitter( | |
{ | |
"type": "chat:completion", | |
"data": event, | |
} | |
) | |
# Save message in the database | |
Chats.upsert_message_to_chat_by_id_and_message_id( | |
metadata["chat_id"], | |
metadata["message_id"], | |
{ | |
**event, | |
}, | |
) | |
async def stream_body_handler(response): | |
nonlocal content | |
nonlocal content_blocks | |
response_tool_calls = [] | |
async for line in response.body_iterator: | |
line = line.decode("utf-8") if isinstance(line, bytes) else line | |
data = line | |
# Skip empty lines | |
if not data.strip(): | |
continue | |
# "data:" is the prefix for each event | |
if not data.startswith("data:"): | |
continue | |
# Remove the prefix | |
data = data[len("data:") :].strip() | |
try: | |
data = json.loads(data) | |
data, _ = await process_filter_functions( | |
request=request, | |
filter_ids=filter_ids, | |
filter_type="stream", | |
form_data=data, | |
extra_params=extra_params, | |
) | |
if data: | |
if "selected_model_id" in data: | |
model_id = data["selected_model_id"] | |
Chats.upsert_message_to_chat_by_id_and_message_id( | |
metadata["chat_id"], | |
metadata["message_id"], | |
{ | |
"selectedModelId": model_id, | |
}, | |
) | |
else: | |
choices = data.get("choices", []) | |
if not choices: | |
usage = data.get("usage", {}) | |
if usage: | |
await event_emitter( | |
{ | |
"type": "chat:completion", | |
"data": { | |
"usage": usage, | |
}, | |
} | |
) | |
continue | |
delta = choices[0].get("delta", {}) | |
delta_tool_calls = delta.get("tool_calls", None) | |
if delta_tool_calls: | |
for delta_tool_call in delta_tool_calls: | |
tool_call_index = delta_tool_call.get( | |
"index" | |
) | |
if tool_call_index is not None: | |
if ( | |
len(response_tool_calls) | |
<= tool_call_index | |
): | |
response_tool_calls.append( | |
delta_tool_call | |
) | |
else: | |
delta_name = delta_tool_call.get( | |
"function", {} | |
).get("name") | |
delta_arguments = ( | |
delta_tool_call.get( | |
"function", {} | |
).get("arguments") | |
) | |
if delta_name: | |
response_tool_calls[ | |
tool_call_index | |
]["function"][ | |
"name" | |
] += delta_name | |
if delta_arguments: | |
response_tool_calls[ | |
tool_call_index | |
]["function"][ | |
"arguments" | |
] += delta_arguments | |
value = delta.get("content") | |
if value: | |
content = f"{content}{value}" | |
if not content_blocks: | |
content_blocks.append( | |
{ | |
"type": "text", | |
"content": "", | |
} | |
) | |
content_blocks[-1]["content"] = ( | |
content_blocks[-1]["content"] + value | |
) | |
if DETECT_REASONING: | |
content, content_blocks, _ = ( | |
tag_content_handler( | |
"reasoning", | |
reasoning_tags, | |
content, | |
content_blocks, | |
) | |
) | |
if DETECT_CODE_INTERPRETER: | |
content, content_blocks, end = ( | |
tag_content_handler( | |
"code_interpreter", | |
code_interpreter_tags, | |
content, | |
content_blocks, | |
) | |
) | |
if end: | |
break | |
if DETECT_SOLUTION: | |
content, content_blocks, _ = ( | |
tag_content_handler( | |
"solution", | |
solution_tags, | |
content, | |
content_blocks, | |
) | |
) | |
if ENABLE_REALTIME_CHAT_SAVE: | |
# Save message in the database | |
Chats.upsert_message_to_chat_by_id_and_message_id( | |
metadata["chat_id"], | |
metadata["message_id"], | |
{ | |
"content": serialize_content_blocks( | |
content_blocks | |
), | |
}, | |
) | |
else: | |
data = { | |
"content": serialize_content_blocks( | |
content_blocks | |
), | |
} | |
await event_emitter( | |
{ | |
"type": "chat:completion", | |
"data": data, | |
} | |
) | |
except Exception as e: | |
done = "data: [DONE]" in line | |
if done: | |
pass | |
else: | |
log.debug("Error: ", e) | |
continue | |
if content_blocks: | |
# Clean up the last text block | |
if content_blocks[-1]["type"] == "text": | |
content_blocks[-1]["content"] = content_blocks[-1][ | |
"content" | |
].strip() | |
if not content_blocks[-1]["content"]: | |
content_blocks.pop() | |
if not content_blocks: | |
content_blocks.append( | |
{ | |
"type": "text", | |
"content": "", | |
} | |
) | |
if response_tool_calls: | |
tool_calls.append(response_tool_calls) | |
if response.background: | |
await response.background() | |
await stream_body_handler(response) | |
MAX_TOOL_CALL_RETRIES = 5 | |
tool_call_retries = 0 | |
while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES: | |
tool_call_retries += 1 | |
response_tool_calls = tool_calls.pop(0) | |
content_blocks.append( | |
{ | |
"type": "tool_calls", | |
"content": response_tool_calls, | |
} | |
) | |
await event_emitter( | |
{ | |
"type": "chat:completion", | |
"data": { | |
"content": serialize_content_blocks(content_blocks), | |
}, | |
} | |
) | |
tools = metadata.get("tools", {}) | |
results = [] | |
for tool_call in response_tool_calls: | |
tool_call_id = tool_call.get("id", "") | |
tool_name = tool_call.get("function", {}).get("name", "") | |
tool_function_params = {} | |
try: | |
# json.loads cannot be used because some models do not produce valid JSON | |
tool_function_params = ast.literal_eval( | |
tool_call.get("function", {}).get("arguments", "{}") | |
) | |
except Exception as e: | |
log.debug(e) | |
tool_result = None | |
if tool_name in tools: | |
tool = tools[tool_name] | |
spec = tool.get("spec", {}) | |
try: | |
required_params = spec.get("parameters", {}).get( | |
"required", [] | |
) | |
tool_function = tool["callable"] | |
tool_function_params = { | |
k: v | |
for k, v in tool_function_params.items() | |
if k in required_params | |
} | |
tool_result = await tool_function( | |
**tool_function_params | |
) | |
except Exception as e: | |
tool_result = str(e) | |
results.append( | |
{ | |
"tool_call_id": tool_call_id, | |
"content": tool_result, | |
} | |
) | |
content_blocks[-1]["results"] = results | |
content_blocks.append( | |
{ | |
"type": "text", | |
"content": "", | |
} | |
) | |
await event_emitter( | |
{ | |
"type": "chat:completion", | |
"data": { | |
"content": serialize_content_blocks(content_blocks), | |
}, | |
} | |
) | |
try: | |
res = await generate_chat_completion( | |
request, | |
{ | |
"model": model_id, | |
"stream": True, | |
"tools": form_data["tools"], | |
"messages": [ | |
*form_data["messages"], | |
*convert_content_blocks_to_messages(content_blocks), | |
], | |
}, | |
user, | |
) | |
if isinstance(res, StreamingResponse): | |
await stream_body_handler(res) | |
else: | |
break | |
except Exception as e: | |
log.debug(e) | |
break | |
if DETECT_CODE_INTERPRETER: | |
MAX_RETRIES = 5 | |
retries = 0 | |
while ( | |
content_blocks[-1]["type"] == "code_interpreter" | |
and retries < MAX_RETRIES | |
): | |
await event_emitter( | |
{ | |
"type": "chat:completion", | |
"data": { | |
"content": serialize_content_blocks(content_blocks), | |
}, | |
} | |
) | |
retries += 1 | |
log.debug(f"Attempt count: {retries}") | |
output = "" | |
try: | |
if content_blocks[-1]["attributes"].get("type") == "code": | |
code = content_blocks[-1]["content"] | |
if ( | |
request.app.state.config.CODE_INTERPRETER_ENGINE | |
== "pyodide" | |
): | |
output = await event_caller( | |
{ | |
"type": "execute:python", | |
"data": { | |
"id": str(uuid4()), | |
"code": code, | |
"session_id": metadata.get( | |
"session_id", None | |
), | |
}, | |
} | |
) | |
elif ( | |
request.app.state.config.CODE_INTERPRETER_ENGINE | |
== "jupyter" | |
): | |
output = await execute_code_jupyter( | |
request.app.state.config.CODE_INTERPRETER_JUPYTER_URL, | |
code, | |
( | |
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN | |
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH | |
== "token" | |
else None | |
), | |
( | |
request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD | |
if request.app.state.config.CODE_INTERPRETER_JUPYTER_AUTH | |
== "password" | |
else None | |
), | |
request.app.state.config.CODE_INTERPRETER_JUPYTER_TIMEOUT, | |
) | |
else: | |
output = { | |
"stdout": "Code interpreter engine not configured." | |
} | |
log.debug(f"Code interpreter output: {output}") | |
if isinstance(output, dict): | |
stdout = output.get("stdout", "") | |
if isinstance(stdout, str): | |
stdoutLines = stdout.split("\n") | |
for idx, line in enumerate(stdoutLines): | |
if "data:image/png;base64" in line: | |
id = str(uuid4()) | |
# ensure the path exists | |
os.makedirs( | |
os.path.join(CACHE_DIR, "images"), | |
exist_ok=True, | |
) | |
image_path = os.path.join( | |
CACHE_DIR, | |
f"images/{id}.png", | |
) | |
with open(image_path, "wb") as f: | |
f.write( | |
base64.b64decode( | |
line.split(",")[1] | |
) | |
) | |
stdoutLines[idx] = ( | |
f"" | |
) | |
output["stdout"] = "\n".join(stdoutLines) | |
result = output.get("result", "") | |
if isinstance(result, str): | |
resultLines = result.split("\n") | |
for idx, line in enumerate(resultLines): | |
if "data:image/png;base64" in line: | |
id = str(uuid4()) | |
# ensure the path exists | |
os.makedirs( | |
os.path.join(CACHE_DIR, "images"), | |
exist_ok=True, | |
) | |
image_path = os.path.join( | |
CACHE_DIR, | |
f"images/{id}.png", | |
) | |
with open(image_path, "wb") as f: | |
f.write( | |
base64.b64decode( | |
line.split(",")[1] | |
) | |
) | |
resultLines[idx] = ( | |
f"" | |
) | |
output["result"] = "\n".join(resultLines) | |
except Exception as e: | |
output = str(e) | |
content_blocks[-1]["output"] = output | |
content_blocks.append( | |
{ | |
"type": "text", | |
"content": "", | |
} | |
) | |
await event_emitter( | |
{ | |
"type": "chat:completion", | |
"data": { | |
"content": serialize_content_blocks(content_blocks), | |
}, | |
} | |
) | |
log.info(f"content_blocks={content_blocks}") | |
log.info( | |
f"serialize_content_blocks={serialize_content_blocks(content_blocks)}" | |
) | |
try: | |
res = await generate_chat_completion( | |
request, | |
{ | |
"model": model_id, | |
"stream": True, | |
"messages": [ | |
*form_data["messages"], | |
{ | |
"role": "assistant", | |
"content": serialize_content_blocks( | |
content_blocks, raw=True | |
), | |
}, | |
], | |
}, | |
user, | |
) | |
if isinstance(res, StreamingResponse): | |
await stream_body_handler(res) | |
else: | |
break | |
except Exception as e: | |
log.debug(e) | |
break | |
title = Chats.get_chat_title_by_id(metadata["chat_id"]) | |
data = { | |
"done": True, | |
"content": serialize_content_blocks(content_blocks), | |
"title": title, | |
} | |
if not ENABLE_REALTIME_CHAT_SAVE: | |
# Save message in the database | |
Chats.upsert_message_to_chat_by_id_and_message_id( | |
metadata["chat_id"], | |
metadata["message_id"], | |
{ | |
"content": serialize_content_blocks(content_blocks), | |
}, | |
) | |
# Send a webhook notification if the user is not active | |
if get_active_status_by_user_id(user.id) is None: | |
webhook_url = Users.get_user_webhook_url_by_id(user.id) | |
if webhook_url: | |
post_webhook( | |
request.app.state.WEBUI_NAME, | |
webhook_url, | |
f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", | |
{ | |
"action": "chat", | |
"message": content, | |
"title": title, | |
"url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", | |
}, | |
) | |
await event_emitter( | |
{ | |
"type": "chat:completion", | |
"data": data, | |
} | |
) | |
await background_tasks_handler() | |
except asyncio.CancelledError: | |
log.warning("Task was cancelled!") | |
await event_emitter({"type": "task-cancelled"}) | |
if not ENABLE_REALTIME_CHAT_SAVE: | |
# Save message in the database | |
Chats.upsert_message_to_chat_by_id_and_message_id( | |
metadata["chat_id"], | |
metadata["message_id"], | |
{ | |
"content": serialize_content_blocks(content_blocks), | |
}, | |
) | |
if response.background is not None: | |
await response.background() | |
# background_tasks.add_task(post_response_handler, response, events) | |
task_id, _ = create_task(post_response_handler(response, events)) | |
return {"status": True, "task_id": task_id} | |
else: | |
# Fallback to the original response | |
async def stream_wrapper(original_generator, events):
    """Yield filtered events first, then the filtered original stream.

    Every item is passed through the ``stream`` filter functions and is
    skipped when the filtered result is falsy. Entries of ``events`` are
    JSON-encoded and framed as ``data: ...`` followed by a blank line;
    items from ``original_generator`` are yielded as returned by the filters.
    """

    async def run_stream_filters(payload):
        # Only the filtered payload is needed; the second value is ignored.
        filtered, _ = await process_filter_functions(
            request=request,
            filter_ids=filter_ids,
            filter_type="stream",
            form_data=payload,
            extra_params=extra_params,
        )
        return filtered

    for pending_event in events:
        filtered_event = await run_stream_filters(pending_event)
        if not filtered_event:
            continue
        item = json.dumps(filtered_event)
        yield f"data: {item}\n\n"

    async for chunk in original_generator:
        filtered_chunk = await run_stream_filters(chunk)
        if filtered_chunk:
            yield filtered_chunk
return StreamingResponse( | |
stream_wrapper(response.body_iterator, events), | |
headers=dict(response.headers), | |
background=response.background, | |
) | |