File size: 3,141 Bytes
a844653
a1708f0
e00eb41
56205e9
152d4a2
1cc747b
a9daaf3
56205e9
a9daaf3
a844653
a9daaf3
152d4a2
a844653
 
 
152d4a2
a844653
 
 
 
47af0fc
a9daaf3
12f5431
a844653
 
a9daaf3
12f5431
a844653
 
12f5431
a844653
 
12f5431
 
 
 
 
 
a844653
 
12f5431
17d2ff9
 
fb6e36a
2b0aafa
61bd0ae
 
12f5431
 
 
 
 
a844653
 
 
 
 
 
 
 
3618620
a844653
 
 
 
 
 
 
a9daaf3
a844653
 
 
 
 
 
3618620
a844653
 
17d2ff9
a844653
12f5431
a844653
 
 
 
 
 
3618620
a844653
 
 
 
 
 
12f5431
a844653
 
152d4a2
12f5431
2bf7b09
a9daaf3
 
f04979d
a9daaf3
 
 
 
 
 
 
 
 
 
 
 
a1708f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import logging
from typing import Union, Optional

from fastapi import FastAPI
from llama_cpp import Llama


app = FastAPI()

CHAT_TEMPLATE = '<|system|> {system_prompt}<|end|><|user|> {prompt}<|end|><|assistant|>'.strip()
SYSTEM_PROMPT = ''

logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

logger.info("Запускаемся... 🥳🥳🥳")

REPO_ID = "Vikhrmodels/Vikhr-Qwen-2.5-1.5B-Instruct-GGUF"
FILE_NAME = "Vikhr-Qwen-2.5-1.5b-Instruct-Q8_0.gguf",

# Инициализация модели
try:
    logger.info(f"Загрузка модели {FILE_NAME}...")

    # загрузка модели для локального хранилища
    # llm = Llama(
    #     model_path=f"./models/{model_name}.gguf",
    #     verbose=False,
    #     n_gpu_layers=-1,
    #     n_ctx=1512,
    #     temperature=0.3,
    #     num_return_sequences=1,
    #     no_repeat_ngram_size=2,
    #     top_k=50,
    #     top_p=0.95,
    # )

    # if not llm:
    LLM = Llama.from_pretrained(
        repo_id=REPO_ID,
        filename='Vikhr-Qwen-2.5-1.5b-Instruct-Q8_0.gguf',
        n_gpu_layers=-1,
        n_threads=2,
        n_ctx=4096,
        temperature=0.3,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        top_k=50,
        top_p=0.95,
    )

except Exception as e:
    logger.error(f"Ошибка загрузки модели: {str(e)}")
    raise


# составление промта для модели
def create_prompt(text: str) -> Union[str, None]:
    try:
        user_input = text
        logger.info(f"Получено сообщение: {user_input}")


        # Генерация шаблона
        return CHAT_TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT or 'Ответ должен быть точным, кратким и с юмором.',
            prompt=user_input,
        )
    except Exception as e:
        logger.error(e)


def generate_response(prompt: str) -> Optional[str]:
    try:
        # Обработка текстового сообщения
        output = LLM(
            prompt,
            max_tokens=64,
            stop=["<|end|>"],
        )

        logger.info('Output:')
        logger.info(output)

        response = output['choices'][0]['text']

        # Отправка ответа
        if response:
            return response

        return 'Произошла ошибка при обработке запроса'

    except Exception as e:
        logger.error(f"Ошибка обработки сообщения: {str(e)}")


@app.get("/")
def greet_json():
    return {"Hello": "World!"}

@app.put("/system-prompt")
async def set_system_prompt(text: str):
    logger.info('post/system-prompt')
    global SYSTEM_PROMPT
    SYSTEM_PROMPT = text

@app.post("/predict")
async def predict(text: str):
    # Генерация ответа с помощью модели
    logger.info('post/predict')
    prompt = create_prompt(text)
    response = generate_response(prompt)
    return {"response": response}