import logging
import os
from typing import Annotated, Optional, Tuple

from fastapi import APIRouter, Depends
from sqlalchemy.orm import sessionmaker

import common.dependencies as DI
from common.configuration import Configuration, Query, SummaryChunks
from common.constants import PROMPT
from common.exceptions import LLMResponseException
from components.datasets.dispatcher import Dispatcher
from components.dbo.models.log import Log
from components.llm.common import ChatRequest, LlmParams, LlmPredictParams, Message
from components.llm.deepinfra_api import DeepInfraApi
from components.llm.utils import append_llm_response_to_history
from components.nmd.aggregate_answers import preprocessed_chunks
from components.nmd.llm_chunk_search import LLMChunkSearch
from components.services.llm_config import LLMConfigService
from components.services.llm_prompt import LlmPromptService

router = APIRouter(prefix='/llm')
logger = logging.getLogger(__name__)

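# Module-level DeepInfra client built directly from the application config.
# The /chat endpoint receives its own client instance via DI instead.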
conf = DI.get_config()
llm_params = LlmParams(**{
    "url": conf.llm_config.base_url,
    "model": conf.llm_config.model,
    "tokenizer": "unsloth/Llama-3.3-70B-Instruct",
    "type": "deepinfra",
    "default": True,
    "predict_params": LlmPredictParams(
        temperature=0.15, top_p=0.95, min_p=0.05, seed=42,
        repetition_penalty=1.2, presence_penalty=1.1, n_predict=2000
    ),
    "api_key": os.environ.get(conf.llm_config.api_key_env),
    "context_length": 128000
})
# TODO: move this into the DI container
llm_api = DeepInfraApi(params=llm_params)

@router.post("/chunks")
def get_chunks(query: Query, dispatcher: Annotated[Dispatcher, Depends(DI.get_dispatcher)]) -> SummaryChunks:
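    """Run the search dispatcher for the query and return the aggregated chunks."""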
    logger.info(f"Handling POST request to /chunks with query: {query.query}")
    try:
        result = dispatcher.search_answer(query)
        logger.info("Successfully retrieved chunks")
        return result
    except Exception as e:
        logger.error(f"Error retrieving chunks: {str(e)}")
        raise


def llm_answer(query: str, answer_chunks: SummaryChunks,
               config: Configuration) -> Tuple[str, str, str, int]:
    """
    Find the correct answer with the help of the LLM.

    Args:
        query: The query.
        answer_chunks: Results from the vector search and Elasticsearch.

    Returns:
        The original chunks from the searches and the chunk chosen by the model.
    """
    llm_search = LLMChunkSearch(config.llm_config, PROMPT, logger)
    return llm_search.llm_chunk_search(query, answer_chunks, PROMPT)


@router.post("/answer_llm")
def get_llm_answer(query: Query, chunks: SummaryChunks,
                   db: Annotated[sessionmaker, Depends(DI.get_db)],
                   config: Annotated[Configuration, Depends(DI.get_config)]):
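    """Ask the LLM to choose the best answer chunk for the query, log the
    exchange, and return the answer together with the log entry id."""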
    logger.info(f"Handling POST request to /answer_llm with query: {query.query}")
    try:
        text_chunks, answer_llm, llm_prompt, _ = llm_answer(query.query, chunks, config)

        if not answer_llm:
            logger.error("LLM returned empty response")
            raise LLMResponseException()

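        # Record the prompt, response and user query in the log table.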
        log_entry = Log(
            llmPrompt=llm_prompt,
            llmResponse=answer_llm,
            userRequest=query.query,
            query_type=chunks.query_type,
            userName=query.userName,
        )
        with db() as session:
            session.add(log_entry)
            session.commit()
            session.refresh(log_entry)

        logger.info(f"Successfully processed LLM request, log_id: {log_entry.id}")
        return {
            "answer_llm": answer_llm,
            "log_id": log_entry.id,
        }

    except Exception as e:
        logger.error(f"Error processing LLM request: {str(e)}")
        raise


@router.post("/chat")
async def chat(request: ChatRequest,
               config: Annotated[Configuration, Depends(DI.get_config)],
               llm_api: Annotated[DeepInfraApi, Depends(DI.get_llm_service)],
               prompt_service: Annotated[LlmPromptService, Depends(DI.get_llm_prompt_service)],
               llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
               dispatcher: Annotated[Dispatcher, Depends(DI.get_dispatcher)]):
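    """Attach search results to the latest user message, query the LLM, and
    return the chat history extended with the model's reply."""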
    try:
        p = llm_config_service.get_default()
        system_prompt = prompt_service.get_default()
        
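        # Build generation parameters from the default LLM config provided by the config service.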
        predict_params = LlmPredictParams(
            temperature=p.temperature, top_p=p.top_p, min_p=p.min_p, seed=p.seed,
            frequency_penalty=p.frequency_penalty, presence_penalty=p.presence_penalty, n_predict=p.n_predict, stop=[]
        )
        
        # TODO: move these helpers out of the endpoint
        def get_last_user_message(chat_request: ChatRequest) -> Optional[Message]:
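            """Return the most recent user message that does not yet carry search results."""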
            return next(
                (
                    msg for msg in reversed(chat_request.history) 
                    if msg.role == "user" and not msg.searchResults
                ),
                None
            )
            
        def insert_search_results_to_message(chat_request: ChatRequest, new_content: str) -> bool:
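            """Overwrite the content of the most recent user message without search results; return True on success."""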
            for msg in reversed(chat_request.history):
                if msg.role == "user" and not msg.searchResults:
                    msg.content = new_content
                    return True
            return False
    
        last_query = get_last_user_message(request)
        search_result = None
        
        if last_query:
            search_result = dispatcher.search_answer(Query(query=last_query.content, query_abbreviation=last_query.content))
            text_chunks = preprocessed_chunks(search_result, None, logger)
            
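            # Wrap the retrieved chunks in <search-results> tags and splice them into the last user message.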
            new_message = f'{last_query.content}\n<search-results>\n{text_chunks}\n</search-results>'
            insert_search_results_to_message(request, new_message)
            
        response = await llm_api.predict_chat_stream(request, system_prompt.text, predict_params)
        result = append_llm_response_to_history(request, response)
        return result
    except Exception as e:
        logger.error(f"Error processing LLM request: {str(e)}", exc_info=True)
        return {"error": str(e)}