Spaces:
Sleeping
Sleeping
import logging | |
import os | |
import re | |
from typing import List, Optional, Tuple | |
from pydantic import BaseModel | |
from common.configuration import Configuration | |
from components.llm.common import ChatRequest, LlmParams, LlmPredictParams, Message | |
from components.llm.deepinfra_api import DeepInfraApi | |
from components.llm.prompts import PROMPT_QE | |
from components.services.dataset import DatasetService | |
from components.services.entity import EntityService | |
from components.services.llm_config import LLMConfigService | |
logger = logging.getLogger(__name__) | |
class QEResult(BaseModel): | |
use_search: bool | |
search_query: str | None | |
debug_message: Optional[str | None] = "" | |
class DialogueService: | |
def __init__( | |
self, | |
config: Configuration, | |
entity_service: EntityService, | |
dataset_service: DatasetService, | |
llm_api: DeepInfraApi, | |
llm_config_service: LLMConfigService, | |
) -> None: | |
self.prompt = PROMPT_QE | |
self.entity_service = entity_service | |
self.dataset_service = dataset_service | |
self.llm_api = llm_api | |
p = llm_config_service.get_default() | |
self.llm_params = LlmPredictParams( | |
temperature=p.temperature, | |
top_p=p.top_p, | |
min_p=p.min_p, | |
seed=p.seed, | |
frequency_penalty=p.frequency_penalty, | |
presence_penalty=p.presence_penalty, | |
n_predict=p.n_predict, | |
) | |
async def get_qe_result(self, history: List[Message]) -> QEResult: | |
""" | |
Получает результат QE. | |
Args: | |
history: История диалога в виде списка сообщений | |
Returns: | |
QEResult: Результат QE | |
""" | |
request = self._get_qe_request(history) | |
response = await self.llm_api.predict_chat_stream( | |
request, | |
"", | |
self.llm_params, | |
) | |
logger.info(f"QE response: {response}") | |
try: | |
return self._postprocess_qe(response) | |
except Exception as e: | |
logger.error(f"Error in _postprocess_qe: {e}") | |
from_chat = self._get_search_query(history) | |
return QEResult( | |
use_search=from_chat is not None, | |
search_query=from_chat.content if from_chat else None, | |
debug_message=response | |
) | |
def get_qe_result_from_chat(self, history: List[Message]) -> QEResult: | |
from_chat = self._get_search_query(history) | |
return QEResult( | |
use_search=from_chat is not None, | |
search_query=from_chat.content if from_chat else None, | |
) | |
def _get_qe_request(self, history: List[Message]) -> ChatRequest: | |
""" | |
Подготавливает полный промпт для QE запроса. | |
Args: | |
history: История диалога в виде списка сообщений | |
Returns: | |
str: Отформатированный промпт с историей диалога | |
""" | |
formatted_history = "\n".join( | |
[self._format_message(msg) for msg in history] | |
).strip() | |
message = self.prompt.format(history=formatted_history) | |
return ChatRequest( | |
history=[Message(role="user", content=message, searchResults='')] | |
) | |
def _format_message(self, message: Message) -> str: | |
""" | |
Форматирует сообщение для запроса QE. | |
Args: | |
message: Сообщение для форматирования | |
""" | |
if message.searchResults: | |
return f'{message.role}: {message.content}\n<search-results>\n{message.searchResults}\n</search-results>' | |
return f'{message.role}: {message.content}' | |
def _postprocess_qe(input_text: str) -> QEResult: | |
# Находим все вхождения квадратных скобок | |
matches = re.findall(r'\[([^\]]*)\]', input_text) | |
# Проверяем количество найденных скобок | |
if len(matches) != 2: | |
raise ValueError("В тексте должно быть ровно две пары квадратных скобок.") | |
# Извлекаем значения из скобок | |
first_part = matches[0].strip().lower() | |
second_part = matches[1].strip() | |
if first_part == "да": | |
bool_var = True | |
elif first_part == "нет": | |
bool_var = False | |
else: | |
raise ValueError("Первая часть текста должна содержать 'ДА' или 'НЕТ'.") | |
return QEResult(use_search=bool_var, search_query=second_part, | |
debug_message=input_text) | |
def _get_search_query(self, history: List[Message]) -> Message | None: | |
""" | |
Получает запрос для поиска на основе последнего сообщения пользователя. | |
""" | |
return next( | |
( | |
msg | |
for msg in reversed(history) | |
if msg.role == "user" | |
and (msg.searchResults is None or not msg.searchResults) | |
), | |
None, | |
) | |