import logging
import os
from typing import Annotated, Optional, Tuple

from fastapi import APIRouter, BackgroundTasks, HTTPException, Response, UploadFile, Depends
from sqlalchemy.orm import Session

import common.dependencies as DI
from common.configuration import Configuration, Query, SummaryChunks
from common.constants import PROMPT
from common.exceptions import LLMResponseException
from components.datasets.dispatcher import Dispatcher
from components.dbo.models.log import Log
from components.llm.common import ChatRequest, LlmParams, LlmPredictParams, Message
from components.llm.deepinfra_api import DeepInfraApi
from components.llm.llm_api import LlmApi
from components.llm.prompts import SYSTEM_PROMPT
from components.llm.utils import append_llm_response_to_history, convert_to_openai_format
from components.nmd.aggregate_answers import preprocessed_chunks
from components.nmd.llm_chunk_search import LLMChunkSearch
from components.services.dataset import DatasetService
from components.services.llm_config import LLMConfigService
from components.services.llm_prompt import LlmPromptService
from schemas.dataset import (Dataset, DatasetExpanded, DatasetProcessing,
                             SortQuery, SortQueryList)

router = APIRouter(prefix='/llm')
logger = logging.getLogger(__name__)

conf = DI.get_config()
llm_params = LlmParams(**{
    "url": conf.llm_config.base_url,
    "model": conf.llm_config.model,
    "tokenizer": "unsloth/Llama-3.3-70B-Instruct",
    "type": "deepinfra",
    "default": True,
    "predict_params": LlmPredictParams(
        temperature=0.15, top_p=0.95, min_p=0.05, seed=42,
        repetition_penalty=1.2, presence_penalty=1.1, n_predict=2000
    ),
    "api_key": os.environ.get(conf.llm_config.api_key_env),
    "context_length": 128000
})
# TODO: move this into DI
llm_api = DeepInfraApi(params=llm_params)
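# The chat endpoint below already obtains a DeepInfraApi via Depends(DI.get_llm_service);
# retiring this module-level instance in favour of that dependency is presumably what the
# TODO above refers to (assumption based on the existing DI usage in this module).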


def get_chunks(query: Query,
               dispatcher: Annotated[Dispatcher, Depends(DI.get_dispatcher)]) -> SummaryChunks:
    """Run the dispatcher search for the given query and return the collected chunks."""
    logger.info(f"Handling POST request to /chunks with query: {query.query}")
    try:
        result = dispatcher.search_answer(query)
        logger.info("Successfully retrieved chunks")
        return result
    except Exception as e:
        logger.error(f"Error retrieving chunks: {str(e)}")
        raise e


def llm_answer(query: str, answer_chunks: SummaryChunks, config: Configuration
               ) -> Tuple[str, str, str, int]:
    """
    Picks the correct answer with the help of the LLM.

    Args:
        query: The user query.
        answer_chunks: Results returned by the vector search and Elasticsearch.
        config: Application configuration holding the LLM settings.

    Returns:
        The source chunks from both searches, the chunk selected by the model,
        the prompt that was sent, and an auxiliary integer value.
    """
    prompt = PROMPT
    llm_search = LLMChunkSearch(config.llm_config, PROMPT, logger)
    return llm_search.llm_chunk_search(query, answer_chunks, prompt)


def get_llm_answer(query: Query, chunks: SummaryChunks,
                   db: Annotated[Session, Depends(DI.get_db)],
                   config: Annotated[Configuration, Depends(DI.get_config)]):
    """Run the LLM chunk search for the query, persist the exchange to the log table and return the answer."""
    logger.info(f"Handling POST request to /answer_llm with query: {query.query}")
    try:
        text_chunks, answer_llm, llm_prompt, _ = llm_answer(query.query, chunks, config)

        if not answer_llm:
            logger.error("LLM returned empty response")
            raise LLMResponseException()

        log_entry = Log(
            llmPrompt=llm_prompt,
            llmResponse=answer_llm,
            userRequest=query.query,
            query_type=chunks.query_type,
            userName=query.userName,
        )

        # NOTE: db is invoked as a factory/context manager here, so DI.get_db presumably
        # returns a sessionmaker rather than a Session (the annotation above may be imprecise).
        with db() as session:
            session.add(log_entry)
            session.commit()
            session.refresh(log_entry)

        logger.info(f"Successfully processed LLM request, log_id: {log_entry.id}")
        return {
            "answer_llm": answer_llm,
            "log_id": log_entry.id,
        }
    except Exception as e:
        logger.error(f"Error processing LLM request: {str(e)}")
        raise e


async def chat(request: ChatRequest,
               config: Annotated[Configuration, Depends(DI.get_config)],
               llm_api: Annotated[DeepInfraApi, Depends(DI.get_llm_service)],
               prompt_service: Annotated[LlmPromptService, Depends(DI.get_llm_prompt_service)],
               llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
               dispatcher: Annotated[Dispatcher, Depends(DI.get_dispatcher)]):
    """Answer a chat request: enrich the last user message with search results and stream it through the LLM."""
    try:
        p = llm_config_service.get_default()
        system_prompt = prompt_service.get_default()

        predict_params = LlmPredictParams(
            temperature=p.temperature, top_p=p.top_p, min_p=p.min_p, seed=p.seed,
            frequency_penalty=p.frequency_penalty, presence_penalty=p.presence_penalty,
            n_predict=p.n_predict, stop=[]
        )

        # TODO: move these helpers out of the endpoint
        def get_last_user_message(chat_request: ChatRequest) -> Optional[Message]:
            """Return the most recent user message that has no search results attached yet."""
            return next(
                (
                    msg for msg in reversed(chat_request.history)
                    if msg.role == "user" and (msg.searchResults is None or not msg.searchResults)
                ),
                None
            )

        def insert_search_results_to_message(chat_request: ChatRequest, new_content: str) -> bool:
            """Replace the content of that same message with the search-enriched version."""
            for msg in reversed(chat_request.history):
                if msg.role == "user" and (msg.searchResults is None or not msg.searchResults):
                    msg.content = new_content
                    return True
            return False

        last_query = get_last_user_message(request)
        search_result = None
        if last_query:
            search_result = dispatcher.search_answer(
                Query(query=last_query.content, query_abbreviation=last_query.content)
            )
            text_chunks = preprocessed_chunks(search_result, None, logger)
            new_message = f'{last_query.content}\n<search-results>\n{text_chunks}\n</search-results>'
            insert_search_results_to_message(request, new_message)

        response = await llm_api.predict_chat_stream(request, system_prompt.text, predict_params)
        result = append_llm_response_to_history(request, response)
        return result
    except Exception as e:
        logger.error(f"Error processing LLM request: {str(e)}", stack_info=True, stacklevel=10)
        return {"error": str(e)}