Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| import re | |
| from typing import List | |
| from pydantic import BaseModel | |
| from common.configuration import Configuration | |
| from components.llm.common import ChatRequest, LlmParams, LlmPredictParams, Message | |
| from components.llm.deepinfra_api import DeepInfraApi | |
| from components.llm.prompts import PROMPT_QE | |
| from components.services.dataset import DatasetService | |
| from components.services.entity import EntityService | |
| from components.services.llm_config import LLMConfigService | |
| logger = logging.getLogger(__name__) | |
| class QEResult(BaseModel): | |
| use_search: bool | |
| search_query: str | None | |
| class DialogueService: | |
| def __init__( | |
| self, | |
| config: Configuration, | |
| entity_service: EntityService, | |
| dataset_service: DatasetService, | |
| llm_api: DeepInfraApi, | |
| llm_config_service: LLMConfigService, | |
| ) -> None: | |
| self.prompt = PROMPT_QE | |
| self.entity_service = entity_service | |
| self.dataset_service = dataset_service | |
| self.llm_api = llm_api | |
| p = llm_config_service.get_default() | |
| self.llm_params = LlmPredictParams( | |
| temperature=p.temperature, | |
| top_p=p.top_p, | |
| min_p=p.min_p, | |
| seed=p.seed, | |
| frequency_penalty=p.frequency_penalty, | |
| presence_penalty=p.presence_penalty, | |
| n_predict=p.n_predict, | |
| ) | |
| async def get_qe_result(self, history: List[Message]) -> QEResult: | |
| """ | |
| Получает результат QE. | |
| Args: | |
| history: История диалога в виде списка сообщений | |
| Returns: | |
| QEResult: Результат QE | |
| """ | |
| request = self._get_qe_request(history) | |
| response = await self.llm_api.predict_chat_stream( | |
| request, | |
| "", | |
| self.llm_params, | |
| ) | |
| logger.info(f"QE response: {response}") | |
| try: | |
| return self._postprocess_qe(response) | |
| except Exception as e: | |
| logger.error(f"Error in _postprocess_qe: {e}") | |
| from_chat = self._get_search_query(history) | |
| return QEResult(use_search=from_chat is not None, search_query=from_chat) | |
| def _get_qe_request(self, history: List[Message]) -> ChatRequest: | |
| """ | |
| Подготавливает полный промпт для QE запроса. | |
| Args: | |
| history: История диалога в виде списка сообщений | |
| Returns: | |
| str: Отформатированный промпт с историей диалога | |
| """ | |
| formatted_history = "\n".join( | |
| [self._format_message(msg) for msg in history] | |
| ).strip() | |
| message = self.prompt.format(history=formatted_history) | |
| return ChatRequest( | |
| history=[Message(role="user", content=message, searchResults='')] | |
| ) | |
| def _format_message(self, message: Message) -> str: | |
| """ | |
| Форматирует сообщение для запроса QE. | |
| Args: | |
| message: Сообщение для форматирования | |
| """ | |
| if message.searchResults: | |
| return f'{message.role}: {message.content}\n<search-results>\n{message.searchResults}\n</search-results>' | |
| return f'{message.role}: {message.content}' | |
| def _postprocess_qe(input_text: str) -> QEResult: | |
| # Находим все вхождения квадратных скобок | |
| matches = re.findall(r'\[([^\]]*)\]', input_text) | |
| # Проверяем количество найденных скобок | |
| if len(matches) != 2: | |
| raise ValueError("В тексте должно быть ровно две пары квадратных скобок.") | |
| # Извлекаем значения из скобок | |
| first_part = matches[0].strip().lower() | |
| second_part = matches[1].strip() | |
| if first_part == "да": | |
| bool_var = True | |
| elif first_part == "нет": | |
| bool_var = False | |
| else: | |
| raise ValueError("Первая часть текста должна содержать 'ДА' или 'НЕТ'.") | |
| return QEResult(use_search=bool_var, search_query=second_part) | |
| def _get_search_query(self, history: List[Message]) -> str | None: | |
| """ | |
| Получает запрос для поиска на основе последнего сообщения пользователя. | |
| """ | |
| return next( | |
| ( | |
| msg | |
| for msg in reversed(history) | |
| if msg.role == "user" | |
| and (msg.searchResults is None or not msg.searchResults) | |
| ), | |
| None, | |
| ) | |