muryshev's picture
update
fd78d64
raw
history blame
4.95 kB
import logging
import re
from typing import List, Optional
from pydantic import BaseModel
from components.llm.common import ChatRequest, LlmPredictParams, Message
from components.llm.deepinfra_api import DeepInfraApi
from components.llm.prompts import PROMPT_QE
from components.services.llm_config import LLMConfigService
logger = logging.getLogger(__name__)
class QEResult(BaseModel):
use_search: bool
search_query: str | None
debug_message: Optional[str | None] = ""
class DialogueService:
def __init__(
self,
llm_api: DeepInfraApi,
llm_config_service: LLMConfigService,
) -> None:
self.prompt = PROMPT_QE
self.llm_api = llm_api
p = llm_config_service.get_default()
self.llm_params = LlmPredictParams(
temperature=p.temperature,
top_p=p.top_p,
min_p=p.min_p,
seed=p.seed,
frequency_penalty=p.frequency_penalty,
presence_penalty=p.presence_penalty,
n_predict=p.n_predict,
)
async def get_qe_result(self, history: List[Message]) -> QEResult:
"""
Получает результат QE.
Args:
history: История диалога в виде списка сообщений
Returns:
QEResult: Результат QE
"""
request = self._get_qe_request(history)
response = await self.llm_api.predict_chat_stream(
request,
"",
self.llm_params,
)
logger.info(f"QE response: {response}")
try:
return self._postprocess_qe(response)
except Exception as e:
logger.error(f"Error in _postprocess_qe: {e}")
from_chat = self._get_search_query(history)
return QEResult(
use_search=from_chat is not None,
search_query=from_chat.content if from_chat else None,
debug_message=response,
)
def get_qe_result_from_chat(self, history: List[Message]) -> QEResult:
from_chat = self._get_search_query(history)
return QEResult(
use_search=from_chat is not None,
search_query=from_chat.content if from_chat else None,
)
def _get_qe_request(self, history: List[Message]) -> ChatRequest:
"""
Подготавливает полный промпт для QE запроса.
Args:
history: История диалога в виде списка сообщений
Returns:
str: Отформатированный промпт с историей диалога
"""
formatted_history = "\n".join(
[self._format_message(msg) for msg in history]
).strip()
message = self.prompt.format(history=formatted_history)
return ChatRequest(
history=[Message(role="user", content=message, searchResults='')]
)
def _format_message(self, message: Message) -> str:
"""
Форматирует сообщение для запроса QE.
Args:
message: Сообщение для форматирования
"""
# if message.searchResults:
# return f'{message.role}: {message.content}\n<search-results>\n{message.searchResults}\n</search-results>'
return f'{message.role}: {message.content}'
@staticmethod
def _postprocess_qe(input_text: str) -> QEResult:
# Находим все вхождения квадратных скобок
matches = re.findall(r'\[([^\]]*)\]', input_text)
# Проверяем количество найденных скобок
if len(matches) != 2:
raise ValueError("В тексте должно быть ровно две пары квадратных скобок.")
# Извлекаем значения из скобок
first_part = matches[0].strip().lower()
second_part = matches[1].strip()
if first_part == "да":
bool_var = True
elif first_part == "нет":
bool_var = False
else:
raise ValueError("Первая часть текста должна содержать 'ДА' или 'НЕТ'.")
return QEResult(
use_search=bool_var, search_query=second_part, debug_message=input_text
)
def _get_search_query(self, history: List[Message]) -> Message | None:
"""
Получает запрос для поиска на основе последнего сообщения пользователя.
"""
return next(
(
msg
for msg in reversed(history)
if msg.role == "user"
and (msg.searchResults is None or not msg.searchResults)
),
None,
)