allekssandr commited on
Commit
6066e39
·
verified ·
1 Parent(s): c51d203

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -5
app.py CHANGED
@@ -1,7 +1,160 @@
1
- from fastapi import FastAPI
 
 
2
 
3
- app = FastAPI()
 
 
 
 
 
 
4
 
5
- @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from platform import system
4
 
5
+ from dotenv import load_dotenv
6
+ from huggingface_hub import HfApi, HfFolder
7
+ from humanfriendly.terminal import output
8
+ from telegram import Update
9
+ from telegram.ext import ApplicationBuilder, CommandHandler, MessageHandler, filters, CallbackContext
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ import torch
12
 
13
+ load_dotenv()
14
+ TOKEN = os.getenv("TELEGRAM_TOKEN")
15
+ HF_TOKEN = os.getenv("HF_TOKEN")
16
+ MAX_LENGTH_REQUEST = 1024
17
+ MAX_NEW_TOKENS = 128
18
+ MAX_LENGTH_RESPONSE = 100
19
+
20
+ # Настройка логирования
21
+ logging.basicConfig(
22
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
23
+ level=logging.INFO
24
+ )
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # Логин через токен
28
+ try:
29
+ api = HfApi()
30
+ HfFolder.save_token(HF_TOKEN)
31
+ except Exception as e:
32
+ logger.error(f"Ошибка авторизации токена: {str(e)}")
33
+ raise
34
+
35
+ rugpt3large_based_on_gpt2_model_name = "ai-forever/rugpt3large_based_on_gpt2"
36
+ rugpt3small_based_on_gpt2_model_name = "ai-forever/rugpt3small_based_on_gpt2"
37
+ sber_rugpt3small_based_on_gpt2_model_name = "sberbank-ai/rugpt3small_based_on_gpt2"
38
+
39
+ # Инициализация модели
40
+ try:
41
+ model_name = rugpt3large_based_on_gpt2_model_name # Меньшая модель
42
+ tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
43
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
44
+ logger.info("Модель успешно загружена")
45
+ except Exception as e:
46
+ logger.error(f"Ошибка загрузки модели: {str(e)}")
47
+ raise
48
+
49
+ # Настройка устройства
50
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51
+ model.to(device)
52
+ logger.info(f"Используемое устройство: {device}")
53
+
54
+ # Контекст диалога (упрощенная версия)
55
+ chat_contexts = {}
56
+
57
+
58
+ def get_chat_context(chat_id):
59
+ if chat_id not in chat_contexts:
60
+ chat_contexts[chat_id] = {"history": []}
61
+ return chat_contexts[chat_id]
62
+
63
+
64
+ MAX_HISTORY_LENGTH = 10
65
+
66
+
67
+ def add_to_chat_history(chat_id, user_input, bot_response):
68
+ context = get_chat_context(chat_id)
69
+ context["history"].append({"user": user_input, "bot": bot_response})
70
+ if len(context["history"]) > MAX_HISTORY_LENGTH:
71
+ context["history"] = context["history"][-MAX_HISTORY_LENGTH:]
72
+
73
+
74
+ async def start(update: Update, context: CallbackContext) -> None:
75
+ """Обработчик команды /start"""
76
+ await update.message.reply_text('🚀 Привет! Я РУССКИЙ! :) бот.')
77
+
78
+
79
+ async def handle_message(update: Update, context: CallbackContext) -> None:
80
+ """Обработка текстовых сообщений"""
81
+ try:
82
+ user_input = update.message.text
83
+ chat_id = update.message.chat_id
84
+ user_name = update.message.from_user.username
85
+ logger.info(f"Получено сообщение: {user_input}")
86
+
87
+ # Получаем контекст чата
88
+ context = get_chat_context(chat_id)
89
+
90
+ # Формируем входной текст с учетом истории
91
+ input_text = ""
92
+ for msg in context["history"]:
93
+ input_text += f"Пользователь: {msg['user']}\nБот: {msg['bot']}"
94
+
95
+ tokenizer.pad_token = tokenizer.eos_token
96
+ # Генерация промта
97
+ system_prompt = "Ответ должен быть точным и кратким."
98
+ # system_prompt = ""
99
+ # prompt = f"{system_prompt} Вопрос: {user_input}; Ответ: "
100
+ prompt = f"{system_prompt}\n {user_input}\n"
101
+ logger.info(f"Промт: {prompt}")
102
+
103
+ # Генерация ответа
104
+ inputs = tokenizer(
105
+ prompt,
106
+ return_tensors="pt", # Возвращает PyTorch тензоры
107
+ # truncation=True, # Обрезает текст, если он превышает max_length
108
+ # add_special_tokens=True, # Добавляет специальные токены (например, [CLS], [SEP])
109
+ ).to(device)
110
+
111
+ outputs = model.generate(
112
+ inputs.input_ids,
113
+ max_new_tokens=60,
114
+ no_repeat_ngram_size=3,
115
+ repetition_penalty=1.5,
116
+ do_sample=True,
117
+ top_k=100,
118
+ top_p=0.3,
119
+ temperature=0.4,
120
+ stop_strings=['<s>'],
121
+ tokenizer=tokenizer,
122
+ )
123
+
124
+ # Декодирование ответа
125
+ # response = list(map(tokenizer.decode, outputs))[0]
126
+ response = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
127
+ logger.info(f"Ответ: {response}")
128
+
129
+ if not response:
130
+ response = "🤔 Пока не знаю, что ответить. Можете переформулировать вопрос?"
131
+
132
+ # Отправка ответа
133
+ await update.message.reply_text(response, parse_mode=None)
134
+ add_to_chat_history(chat_id, user_input, response)
135
+
136
+ except Exception as e:
137
+ logger.error(f"Ошибка обработки сообщения: {str(e)}")
138
+ await update.message.reply_text("❌ Произошла ошибка при обработке запроса")
139
+
140
+
141
+ def main() -> None:
142
+ try:
143
+ application = ApplicationBuilder().token(TOKEN).build()
144
+ application.add_handler(CommandHandler("start", start))
145
+ application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message))
146
+ application.add_error_handler(error)
147
+
148
+ logger.info("Бот запущен")
149
+ application.run_polling()
150
+
151
+ except Exception as e:
152
+ logger.error(f"Ошибка запуска бота: {str(e)}")
153
+
154
+
155
+ async def error(update: Update, context: CallbackContext) -> None:
156
+ logger.error(f'Ошибка: {context.error}')
157
+
158
+
159
+ if __name__ == '__main__':
160
+ main()