import json
import logging
import os
from typing import Annotated, Any, AsyncGenerator, List, Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse

import common.dependencies as DI
from common import auth
from common.configuration import Configuration
from components.llm.common import (ChatRequest, LlmParams, LlmPredictParams,
                                   Message)
from components.llm.deepinfra_api import DeepInfraApi
from components.llm.utils import append_llm_response_to_history
from components.services.dataset import DatasetService
from components.services.dialogue import DialogueService, QEResult
from components.services.entity import EntityService
from components.services.llm_config import LLMConfigService
from components.services.llm_prompt import LlmPromptService

router = APIRouter(prefix='/llm', tags=['LLM chat'])
logger = logging.getLogger(__name__)
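
# Default LLM connection and generation parameters for the DeepInfra backend,
# built from the application configuration at import time.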
conf = DI.get_config()
llm_params = LlmParams(
    **{
        "url": conf.llm_config.base_url,
        "model": conf.llm_config.model,
        "tokenizer": "unsloth/Llama-3.3-70B-Instruct",
        "type": "deepinfra",
        "default": True,
        "predict_params": LlmPredictParams(
            temperature=0.15,
            top_p=0.95,
            min_p=0.05,
            seed=42,
            repetition_penalty=1.2,
            presence_penalty=1.1,
            n_predict=2000,
        ),
        "api_key": os.environ.get(conf.llm_config.api_key_env),
        "context_length": 128000,
    }
)
# TODO: move this into the DI container
llm_api = DeepInfraApi(params=llm_params)

# TODO: move out into a shared helper module
def get_last_user_message(chat_request: ChatRequest) -> Optional[Message]:
    """Return the most recent user message that has no attached search results."""
    return next(
        (
            msg
            for msg in reversed(chat_request.history)
            if msg.role == "user" and not msg.searchResults
        ),
        None,
    )
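
# Replaces the content of the most recent user message that does not yet
# carry search results; returns True if such a message was found.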
def insert_search_results_to_message(
    chat_request: ChatRequest, new_content: str
) -> bool:
    for msg in reversed(chat_request.history):
        if msg.role == "user" and not msg.searchResults:
            msg.content = new_content
            return True
    return False
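
# Walks the history from newest to oldest and attaches the collected search
# results and entities to user messages that do not have them yet.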
def try_insert_search_results(
    chat_request: ChatRequest, search_results: List[str], entities: List[List[str]]
) -> bool:
    i = 0
    for msg in reversed(chat_request.history):
        if msg.role == "user" and not msg.searchResults:
            msg.searchResults = search_results[i]
            msg.searchEntities = entities[i]
            i += 1
            if i == len(search_results):
                return True
    return False

def collapse_history_to_first_message(chat_request: ChatRequest) -> ChatRequest:
    """
    Collapses the history into a single first message and returns a new ChatRequest.
    Format:
        <search-results>[Source] - text</search-results>
        role: message text
    """
    if not chat_request.history:
        return ChatRequest(history=[])

    # Collect the history into a single string
    collapsed_content = []
    for msg in chat_request.history:
        # Append the search results, if any
        if msg.searchResults:
            collapsed_content.append(f"<search-results>{msg.searchResults}</search-results>")
        # Append the message text prefixed with its role
        if msg.content.strip():
            collapsed_content.append(f"{msg.role}: {msg.content.strip()}")

    # Join everything with newlines
    new_content = "\n".join(collapsed_content)

    # Build a new message and a new ChatRequest object
    new_message = Message(
        role='user',
        content=new_content,
        searchResults=''
    )
    return ChatRequest(history=[new_message])

async def sse_generator(request: ChatRequest, llm_api: DeepInfraApi, system_prompt: str,
                        predict_params: LlmPredictParams,
                        dataset_service: DatasetService,
                        entity_service: EntityService,
                        dialogue_service: DialogueService) -> AsyncGenerator[str, None]:
    """
    Generator that streams the LLM response over SSE.
    """
    try:
        qe_result = await dialogue_service.get_qe_result(request.history)
        qe_event = {
            "event": "debug",
            "data": {
                "text": qe_result.debug_message
            }
        }
        yield f"data: {json.dumps(qe_event, ensure_ascii=False)}\n\n"
    except Exception as e:
        logger.error(f"Error in SSE chat stream while dialogue_service.get_qe_result: {str(e)}", stack_info=True)
        # Serialize via json.dumps so the error message is properly escaped
        error_event = {"event": "error", "data": str(e)}
        yield f"data: {json.dumps(error_event, ensure_ascii=False)}\n\n"
        qe_result = dialogue_service.get_qe_result_from_chat(request.history)

    try:
        if qe_result.use_search and qe_result.search_query is not None:
            dataset = dataset_service.get_current_dataset()
            if dataset is None:
                raise HTTPException(status_code=400, detail="Dataset not found")
            previous_entities = [msg.searchEntities for msg in request.history if msg.searchEntities is not None]
            previous_entities, chunk_ids, scores = entity_service.search_similar(qe_result.search_query,
                                                                                 dataset.id, previous_entities)
            text_chunks = entity_service.build_text(chunk_ids, scores)
            all_text_chunks = [text_chunks] + [entity_service.build_text(entities) for entities in previous_entities]
            all_entities = [chunk_ids] + previous_entities
            search_results_event = {
                "event": "search_results",
                "data": {
                    "text": text_chunks,
                    "ids": chunk_ids
                }
            }
            yield f"data: {json.dumps(search_results_event, ensure_ascii=False)}\n\n"
            # new_message = f'<search-results>\n{text_chunks}\n</search-results>\n{last_query.content}'
            try_insert_search_results(request, all_text_chunks, all_entities)
    except Exception as e:
        logger.error(f"Error in SSE chat stream while searching: {str(e)}", stack_info=True)
        error_event = {"event": "error", "data": str(e)}
        yield f"data: {json.dumps(error_event, ensure_ascii=False)}\n\n"

    try:
        # Collapse the history into a single first message
        collapsed_request = collapse_history_to_first_message(request)

        # Stream the response tokens
        async for token in llm_api.get_predict_chat_generator(collapsed_request, system_prompt, predict_params):
            token_event = {"event": "token", "data": token}
            # logger.info(f"Streaming token: {token}")
            yield f"data: {json.dumps(token_event, ensure_ascii=False)}\n\n"

        # Final event
        yield "data: {\"event\": \"done\"}\n\n"
    except Exception as e:
        logger.error(f"Error in SSE chat stream while generating response: {str(e)}", stack_info=True)
        error_event = {"event": "error", "data": str(e)}
        yield f"data: {json.dumps(error_event, ensure_ascii=False)}\n\n"
async def chat_stream(
    request: ChatRequest,
    config: Annotated[Configuration, Depends(DI.get_config)],
    llm_api: Annotated[DeepInfraApi, Depends(DI.get_llm_service)],
    prompt_service: Annotated[LlmPromptService, Depends(DI.get_llm_prompt_service)],
    llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
    entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
    dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
    dialogue_service: Annotated[DialogueService, Depends(DI.get_dialogue_service)],
    current_user: Annotated[Any, Depends(auth.get_current_user)]
):
    try:
        p = llm_config_service.get_default()
        system_prompt = prompt_service.get_default()
        predict_params = LlmPredictParams(
            temperature=p.temperature,
            top_p=p.top_p,
            min_p=p.min_p,
            seed=p.seed,
            frequency_penalty=p.frequency_penalty,
            presence_penalty=p.presence_penalty,
            n_predict=p.n_predict,
            stop=[],
        )
        headers = {
            "Content-Type": "text/event-stream",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
        }
        return StreamingResponse(
            sse_generator(request, llm_api, system_prompt.text, predict_params,
                          dataset_service, entity_service, dialogue_service),
            media_type="text/event-stream",
            headers=headers
        )
    except Exception as e:
        logger.error(f"Error in SSE chat stream: {str(e)}", stack_info=True)
        raise HTTPException(status_code=500, detail=str(e))
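
# NOTE: assumed route registration (see the note above); this path is a guess.
@router.post("/chat")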
async def chat(
    request: ChatRequest,
    config: Annotated[Configuration, Depends(DI.get_config)],
    llm_api: Annotated[DeepInfraApi, Depends(DI.get_llm_service)],
    prompt_service: Annotated[LlmPromptService, Depends(DI.get_llm_prompt_service)],
    llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
    entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
    dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
    dialogue_service: Annotated[DialogueService, Depends(DI.get_dialogue_service)],
):
    try:
        p = llm_config_service.get_default()
        system_prompt = prompt_service.get_default()
        predict_params = LlmPredictParams(
            temperature=p.temperature,
            top_p=p.top_p,
            min_p=p.min_p,
            seed=p.seed,
            frequency_penalty=p.frequency_penalty,
            presence_penalty=p.presence_penalty,
            n_predict=p.n_predict,
            stop=[],
        )

        try:
            qe_result = await dialogue_service.get_qe_result(request.history)
        except Exception as e:
            logger.error(f"Error in chat while dialogue_service.get_qe_result: {str(e)}", stack_info=True)
            qe_result = dialogue_service.get_qe_result_from_chat(request.history)

        last_message = get_last_user_message(request)
        logger.info(f"qe_result: {qe_result}")
        if qe_result.use_search and qe_result.search_query is not None:
            dataset = dataset_service.get_current_dataset()
            if dataset is None:
                raise HTTPException(status_code=400, detail="Dataset not found")
            logger.info(f"qe_result.search_query: {qe_result.search_query}")

            # Keep only messages that actually carry entities, as in the streaming handler
            previous_entities = [msg.searchEntities for msg in request.history if msg.searchEntities is not None]
            previous_entities, chunk_ids, scores = entity_service.search_similar(
                qe_result.search_query, dataset.id, previous_entities
            )

            chunks = entity_service.chunk_repository.get_entities_by_ids(chunk_ids)
            logger.info(f"chunk_ids: {chunk_ids[:3]}...{chunk_ids[-3:]}")
            logger.info(f"scores: {scores[:3]}...{scores[-3:]}")
            text_chunks = entity_service.build_text(chunks, scores)
            logger.info(f"text_chunks: {text_chunks[:3]}...{text_chunks[-3:]}")
            new_message = f'{last_message.content}\n<search-results>\n{text_chunks}\n</search-results>'
            insert_search_results_to_message(request, new_message)
            logger.info(f"request: {request}")

        response = await llm_api.predict_chat_stream(
            request, system_prompt.text, predict_params
        )
        result = append_llm_response_to_history(request, response)
        return result
    except Exception as e:
        logger.error(
            f"Error processing LLM request: {str(e)}", stack_info=True, stacklevel=10
        )
        return {"error": str(e)}