File size: 3,301 Bytes
57cf043
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from pydantic import BaseModel, Field
from typing import Optional, List, Protocol

class LlmPredictParams(BaseModel):
    """
    Параметры для предсказания LLM.
    """
    system_prompt: Optional[str] = Field(None, description="Системный промпт.")
    user_prompt: Optional[str] = Field(None, description="Шаблон промпта для передачи от роли user.")
    n_predict: Optional[int] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    min_p: Optional[float] = None
    seed: Optional[int] = None
    repeat_penalty: Optional[float] = None
    repeat_last_n: Optional[int] = None
    retry_if_text_not_present: Optional[str] = None
    retry_count: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    n_keep: Optional[int] = None
    cache_prompt: Optional[bool] = None
    stop: Optional[List[str]] = None


class LlmParams(BaseModel):
    """
    Основные параметры для LLM.
    """
    url: str
    model: Optional[str] = Field(None, description="Предполагается, что для локального API этот параметр не будет указываться, т.к. будем брать первую модель из списка потому, что модель доступна всего одна. Для deepinfra такой подход не подойдет и модель нужно задавать явно.")
    tokenizer: Optional[str]  = Field(None, description="При использовании стороннего API, не поддерживающего токенизацию, будет использован AutoTokenizer для модели из этого поля. Используется в случае, если название модели в API не совпадает с оригинальным названием на Huggingface.")
    type: Optional[str] = None
    default: Optional[bool] = None
    template: Optional[str] = None
    predict_params: Optional[LlmPredictParams] = None
    api_key: Optional[str] = None
    context_length: Optional[int] = None
    
class LlmApiProtocol(Protocol):
    async def tokenize(self, prompt: str) -> Optional[dict]:
        ...
    async def detokenize(self, tokens: List[int]) -> Optional[str]:
        ...
    async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
        ...
    async def predict(self, prompt: str) -> str:
        ...
        
class LlmApi:
    """
    Базовый клас для работы с API LLM.
    """
    params: LlmParams = None
    
    def __init__(self):
        self.params = None
    
    def set_params(self, params: LlmParams):
        self.params = params
        
    def create_headers(self) -> dict[str, str]:
        headers = {"Content-Type": "application/json"}
                
        if self.params.api_key is not None:
            headers["Authorization"] = self.params.api_key
            
        return headers
        
        
class Message(BaseModel):
    role: str
    content: str
    searchResults: List[str]

class ChatRequest(BaseModel):
    history: List[Message]