File size: 4,733 Bytes
57cf043
 
86c402d
 
57cf043
 
86c402d
 
 
 
 
 
 
 
57cf043
 
 
 
 
 
 
86c402d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57cf043
 
 
 
86c402d
 
 
 
 
 
 
 
 
57cf043
 
 
86c402d
57cf043
86c402d
 
 
 
 
 
 
 
57cf043
86c402d
 
57cf043
 
 
86c402d
 
 
 
57cf043
86c402d
57cf043
86c402d
 
 
 
57cf043
86c402d
 
 
57cf043
 
 
86c402d
57cf043
 
 
86c402d
 
57cf043
86c402d
 
 
 
 
 
 
 
 
 
 
 
57cf043
86c402d
 
57cf043
 
 
86c402d
 
 
 
 
57cf043
 
 
86c402d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import logging
import os
from typing import Annotated, Optional
from uuid import UUID

from components.services.dataset import DatasetService
from components.services.entity import EntityService
from fastapi import APIRouter, Depends, HTTPException

import common.dependencies as DI
from common.configuration import Configuration, Query
from components.llm.common import ChatRequest, LlmParams, LlmPredictParams, Message
from components.llm.deepinfra_api import DeepInfraApi
from components.llm.utils import append_llm_response_to_history
from components.services.llm_config import LLMConfigService
from components.services.llm_prompt import LlmPromptService

# Router for the LLM endpoints; every route registered on it lives under /llm.
router = APIRouter(prefix='/llm')
logger = logging.getLogger(__name__)

# Default LLM client parameters, built once at import time from the
# application configuration.
conf = DI.get_config()
llm_params = LlmParams(
    url=conf.llm_config.base_url,
    model=conf.llm_config.model,
    tokenizer="unsloth/Llama-3.3-70B-Instruct",
    type="deepinfra",
    default=True,
    predict_params=LlmPredictParams(
        temperature=0.15,
        top_p=0.95,
        min_p=0.05,
        seed=42,
        repetition_penalty=1.2,
        presence_penalty=1.1,
        n_predict=2000,
    ),
    # The config holds the NAME of the environment variable that stores the
    # key; the lookup yields None when that variable is not set.
    api_key=os.getenv(conf.llm_config.api_key_env),
    context_length=128000,
)
# TODO: move into DI
llm_api = DeepInfraApi(params=llm_params)


def _find_pending_user_message(chat_request: ChatRequest) -> Optional[Message]:
    """Return the newest user message that has no search results attached.

    The history is scanned from the last message to the first. A message
    qualifies when its role is ``"user"`` and its ``searchResults`` is
    ``None`` or otherwise falsy. Returns ``None`` when nothing qualifies.
    """
    return next(
        (
            msg
            for msg in reversed(chat_request.history)
            if msg.role == "user" and not msg.searchResults
        ),
        None,
    )


@router.post("/chat")
async def chat(
    request: ChatRequest,
    config: Annotated[Configuration, Depends(DI.get_config)],
    llm_api: Annotated[DeepInfraApi, Depends(DI.get_llm_service)],
    prompt_service: Annotated[LlmPromptService, Depends(DI.get_llm_prompt_service)],
    llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
    entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
    dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
):
    """Run one chat turn, enriching the latest user message with search results.

    The newest user message without search results is used as a query for
    ``entity_service.search_similar`` against the current dataset. The text
    built from the found chunks is appended to that message inside a
    ``<search-results>`` section, and the (mutated) request is then sent to
    the LLM together with the default system prompt and default prediction
    parameters.

    Args:
        request: Chat history to answer; the matching user message is
            modified in place.
        config: Application configuration (injected; not used in the body).
        llm_api: LLM client used for the prediction.
        prompt_service: Provides the default system prompt.
        llm_config_service: Provides the default prediction parameters.
        entity_service: Performs the similarity search and builds chunk text.
        dataset_service: Provides the current dataset.

    Returns:
        The value of ``append_llm_response_to_history(request, response)``,
        or ``{"error": <message>}`` when an unexpected exception occurs.

    Raises:
        HTTPException: 400 when a query is present but no current dataset
            exists.
    """
    try:
        p = llm_config_service.get_default()
        system_prompt = prompt_service.get_default()

        predict_params = LlmPredictParams(
            temperature=p.temperature,
            top_p=p.top_p,
            min_p=p.min_p,
            seed=p.seed,
            frequency_penalty=p.frequency_penalty,
            presence_penalty=p.presence_penalty,
            n_predict=p.n_predict,
            stop=[],
        )

        last_query = _find_pending_user_message(request)

        logger.info(f"last_query: {last_query}")

        if last_query:
            dataset = dataset_service.get_current_dataset()
            if dataset is None:
                raise HTTPException(status_code=400, detail="Dataset not found")
            logger.info(f"last_query: {last_query.content}")
            _, scores, chunk_ids = entity_service.search_similar(
                last_query.content, dataset.id
            )

            chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids)

            logger.info(f"chunk_ids: {chunk_ids[:3]}...{chunk_ids[-3:]}")
            logger.info(f"scores: {scores[:3]}...{scores[-3:]}")

            text_chunks = entity_service.build_text(chunks, scores)

            logger.info(f"text_chunks: {text_chunks[:3]}...{text_chunks[-3:]}")

            # `last_query` is the very object stored in request.history, so
            # updating it here updates the request sent to the LLM.
            # Real newlines separate the sections (the previous literal "/n"
            # was a mistyped escape sequence).
            last_query.content = (
                f'{last_query.content} \n<search-results>\n{text_chunks}\n</search-results>'
            )

        logger.info(f"request: {request}")

        response = await llm_api.predict_chat_stream(
            request, system_prompt.text, predict_params
        )
        return append_llm_response_to_history(request, response)
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # with their status code instead of being turned into a 200 body.
        raise
    except Exception as e:
        # logger.exception attaches the traceback of the caught exception.
        logger.exception("Error processing LLM request: %s", e)
        return {"error": str(e)}