import logging
from typing import Annotated, Optional, Tuple
import os
from fastapi import APIRouter, BackgroundTasks, HTTPException, Response, UploadFile, Depends
from components.llm.common import LlmParams, LlmPredictParams, Message
from components.llm.deepinfra_api import DeepInfraApi
from components.llm.llm_api import LlmApi
from components.llm.common import ChatRequest
from common.constants import PROMPT
from components.llm.prompts import SYSTEM_PROMPT
from components.llm.utils import append_llm_response_to_history, convert_to_openai_format
from components.nmd.aggregate_answers import preprocessed_chunks
from components.nmd.llm_chunk_search import LLMChunkSearch
from components.services.dataset import DatasetService
from common.configuration import Configuration, Query, SummaryChunks
from components.datasets.dispatcher import Dispatcher
from common.exceptions import LLMResponseException
from components.dbo.models.log import Log
from components.services.llm_config import LLMConfigService
from components.services.llm_prompt import LlmPromptService
from schemas.dataset import (Dataset, DatasetExpanded, DatasetProcessing,
SortQuery, SortQueryList)
import common.dependencies as DI
from sqlalchemy.orm import Session
router = APIRouter(prefix='/llm')
logger = logging.getLogger(__name__)
conf = DI.get_config()
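
# Default DeepInfra client parameters: model and base URL come from the app
# configuration, the API key is read from the environment, and the sampling
# defaults are fixed here.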
llm_params = LlmParams(
    url=conf.llm_config.base_url,
    model=conf.llm_config.model,
    tokenizer="unsloth/Llama-3.3-70B-Instruct",
    type="deepinfra",
    default=True,
    predict_params=LlmPredictParams(
        temperature=0.15, top_p=0.95, min_p=0.05, seed=42,
        repetition_penalty=1.2, presence_penalty=1.1, n_predict=2000,
    ),
    api_key=os.environ.get(conf.llm_config.api_key_env),
    context_length=128000,
)
# TODO: move this into the DI container
llm_api = DeepInfraApi(params=llm_params)
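
# POST /llm/chunks: run the retrieval step only and return the candidate
# chunks found by the Dispatcher (vector search and Elasticsearch).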
@router.post("/chunks")
def get_chunks(query: Query, dispatcher: Annotated[Dispatcher, Depends(DI.get_dispatcher)]) -> SummaryChunks:
logger.info(f"Handling POST request to /chunks with query: {query.query}")
try:
result = dispatcher.search_answer(query)
logger.info("Successfully retrieved chunks")
return result
except Exception as e:
logger.error(f"Error retrieving chunks: {str(e)}")
        raise
def llm_answer(
    query: str, answer_chunks: SummaryChunks, config: Configuration
) -> Tuple[str, str, str, int]:
    """
    Finds the correct answer with the help of the LLM.

    Args:
        query: The user query.
        answer_chunks: Chunks returned by the vector search and Elasticsearch.
        config: Application configuration holding the LLM settings.

    Returns:
        The source chunks from the searches and the chunk selected by the model.
    """
prompt = PROMPT
llm_search = LLMChunkSearch(config.llm_config, PROMPT, logger)
return llm_search.llm_chunk_search(query, answer_chunks, prompt)
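
# POST /llm/answer_llm: let the LLM pick the best chunk for the query and
# store the prompt/response pair as a Log entry in the database.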
@router.post("/answer_llm")
def get_llm_answer(
    query: Query,
    chunks: SummaryChunks,
    db: Annotated[Session, Depends(DI.get_db)],
    config: Annotated[Configuration, Depends(DI.get_config)],
):
logger.info(f"Handling POST request to /answer_llm with query: {query.query}")
try:
text_chunks, answer_llm, llm_prompt, _ = llm_answer(query.query, chunks, config)
if not answer_llm:
logger.error("LLM returned empty response")
raise LLMResponseException()
log_entry = Log(
llmPrompt=llm_prompt,
llmResponse=answer_llm,
userRequest=query.query,
query_type=chunks.query_type,
userName=query.userName,
)
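        # Persist the exchange so it can later be referenced by log_id.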
with db() as session:
session.add(log_entry)
session.commit()
session.refresh(log_entry)
logger.info(f"Successfully processed LLM request, log_id: {log_entry.id}")
return {
"answer_llm": answer_llm,
"log_id": log_entry.id,
}
except Exception as e:
logger.error(f"Error processing LLM request: {str(e)}")
        raise
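
# POST /llm/chat: enrich the latest user message with retrieved search
# results, call the LLM via the DeepInfra API and append its response to the
# chat history.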
@router.post("/chat")
async def chat(
    request: ChatRequest,
    config: Annotated[Configuration, Depends(DI.get_config)],
    llm_api: Annotated[DeepInfraApi, Depends(DI.get_llm_service)],
    prompt_service: Annotated[LlmPromptService, Depends(DI.get_llm_prompt_service)],
    llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
    dispatcher: Annotated[Dispatcher, Depends(DI.get_dispatcher)],
):
try:
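        # Resolve the stored default generation parameters and system prompt.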
p = llm_config_service.get_default()
system_prompt = prompt_service.get_default()
predict_params = LlmPredictParams(
temperature=p.temperature, top_p=p.top_p, min_p=p.min_p, seed=p.seed,
frequency_penalty=p.frequency_penalty, presence_penalty=p.presence_penalty, n_predict=p.n_predict, stop=[]
)
        # TODO: move these helpers out of the request handler
def get_last_user_message(chat_request: ChatRequest) -> Optional[Message]:
return next(
(
msg for msg in reversed(chat_request.history)
if msg.role == "user" and (msg.searchResults is None or not msg.searchResults)
),
None
)
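
        # Overwrite the most recent plain user message (the last one without
        # attached search results) with content that embeds the retrieved chunks.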
def insert_search_results_to_message(chat_request: ChatRequest, new_content: str) -> bool:
for msg in reversed(chat_request.history):
if msg.role == "user" and (msg.searchResults is None or not msg.searchResults):
msg.content = new_content
return True
return False
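
        # If there is a fresh user message, run retrieval and embed the found
        # chunks into it before sending the conversation to the LLM.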
last_query = get_last_user_message(request)
search_result = None
if last_query:
search_result = dispatcher.search_answer(Query(query=last_query.content, query_abbreviation=last_query.content))
text_chunks = preprocessed_chunks(search_result, None, logger)
            new_message = f'{last_query.content}\n<search-results>\n{text_chunks}\n</search-results>'
insert_search_results_to_message(request, new_message)
response = await llm_api.predict_chat_stream(request, system_prompt.text, predict_params)
result = append_llm_response_to_history(request, response)
return result
except Exception as e:
logger.error(f"Error processing LLM request: {str(e)}", stack_info=True, stacklevel=10)
return {"error": str(e)} |