File size: 3,472 Bytes
a9daaf3
a844653
3618620
56205e9
152d4a2
1cc747b
a9daaf3
 
56205e9
a9daaf3
a844653
a9daaf3
152d4a2
a844653
 
 
152d4a2
a844653
 
 
 
47af0fc
a9daaf3
12f5431
a844653
 
a9daaf3
12f5431
a844653
 
12f5431
a844653
 
12f5431
 
 
 
 
 
a844653
 
12f5431
17d2ff9
 
fb6e36a
a844653
12f5431
 
 
 
 
 
a844653
 
 
 
 
 
 
 
3618620
a844653
 
 
 
 
 
 
a9daaf3
a844653
 
 
 
 
 
3618620
a844653
 
17d2ff9
a844653
12f5431
a844653
 
 
 
 
 
3618620
a844653
 
 
 
 
 
12f5431
a844653
 
152d4a2
12f5431
2bf7b09
a9daaf3
 
f04979d
a9daaf3
 
f19d93a
a9daaf3
 
 
 
 
 
 
 
 
 
 
f19d93a
a9daaf3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import asyncio
import logging
from typing import Union, Optional, SupportsIndex
from fastapi import FastAPI
from llama_cpp import Llama

from bot import start_bot

app = FastAPI()

CHAT_TEMPLATE = '<|system|> {system_prompt}<|end|><|user|> {prompt}<|end|><|assistant|>'.strip()
SYSTEM_PROMPT = ''

logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

logger.info("Запускаемся... 🥳🥳🥳")

REPO_ID = "Vikhrmodels/Vikhr-Qwen-2.5-1.5B-Instruct-GGUF"
FILE_NAME = "Vikhr-Qwen-2.5-1.5b-Instruct-Q8_0.gguf",

# Инициализация модели
try:
    logger.info(f"Загрузка модели {FILE_NAME}...")

    # загрузка модели для локального хранилища
    # llm = Llama(
    #     model_path=f"./models/{model_name}.gguf",
    #     verbose=False,
    #     n_gpu_layers=-1,
    #     n_ctx=1512,
    #     temperature=0.3,
    #     num_return_sequences=1,
    #     no_repeat_ngram_size=2,
    #     top_k=50,
    #     top_p=0.95,
    # )

    # if not llm:
    LLM = Llama.from_pretrained(
        repo_id=REPO_ID,
        filename='Vikhr-Qwen-2.5-1.5b-Instruct-Q8_0.gguf',
        n_gpu_layers=-1,
        n_ctx=1512,
        temperature=0.3,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        top_k=50,
        top_p=0.95,
    )

except Exception as e:
    logger.error(f"Ошибка загрузки модели: {str(e)}")
    raise


# составление промта для модели
def create_prompt(text: str) -> Union[str, None]:
    try:
        user_input = text
        logger.info(f"Получено сообщение: {user_input}")


        # Генерация шаблона
        return CHAT_TEMPLATE.format(
            system_prompt=SYSTEM_PROMPT or 'Ответ должен быть точным, кратким и с юмором.',
            prompt=user_input,
        )
    except Exception as e:
        logger.error(e)


def generate_response(prompt: str) -> Optional[str]:
    try:
        # Обработка текстового сообщения
        output = LLM(
            prompt,
            max_tokens=64,
            stop=["<|end|>"],
        )

        logger.info('Output:')
        logger.info(output)

        response = output['choices'][0]['text']

        # Отправка ответа
        if response:
            return response

        return 'Произошла ошибка при обработке запроса'

    except Exception as e:
        logger.error(f"Ошибка обработки сообщения: {str(e)}")


@app.get("/")
def greet_json():
    return {"Hello": "World!"}

@app.put("/system-prompt")
async def set_system_prompt(text: str):
    # Генерация ответа с помощью модели
    logger.info('post/system-prompt')
    global SYSTEM_PROMPT
    SYSTEM_PROMPT = text


@app.post("/predict")
async def predict(text: str):
    # Генерация ответа с помощью модели
    logger.info('post/predict')
    prompt = create_prompt(text)
    response = generate_response(prompt)
    return {"response": response}

# Запуск Telegram-бота при старте приложения
@app.on_event("startup")
async def startup_event():
    asyncio.create_task(start_bot())  # Запускаем бота в фоновом режиме