Spaces:
Sleeping
Sleeping
import logging | |
import os | |
from platform import system | |
from dotenv import load_dotenv | |
from huggingface_hub import HfApi, HfFolder | |
from humanfriendly.terminal import output | |
from telegram import Update | |
from telegram.ext import ApplicationBuilder, CommandHandler, MessageHandler, filters, CallbackContext | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
load_dotenv() | |
TOKEN = os.getenv("TELEGRAM_TOKEN") | |
HF_TOKEN = os.getenv("HF_TOKEN") | |
MAX_LENGTH_REQUEST = 1024 | |
MAX_NEW_TOKENS = 128 | |
MAX_LENGTH_RESPONSE = 100 | |
TEST_ENV=os.getenv("TEST_ENV") | |
# Настройка логирования | |
logging.basicConfig( | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
level=logging.INFO | |
) | |
logger = logging.getLogger(__name__) | |
logger.info(f"TEST_ENV= {TEST_ENV}") | |
# Логин через токен | |
try: | |
api = HfApi() | |
HfFolder.save_token(HF_TOKEN) | |
except Exception as e: | |
logger.error(f"Ошибка авторизации токена: {str(e)}") | |
raise | |
rugpt3large_based_on_gpt2_model_name = "ai-forever/rugpt3large_based_on_gpt2" | |
rugpt3small_based_on_gpt2_model_name = "ai-forever/rugpt3small_based_on_gpt2" | |
sber_rugpt3small_based_on_gpt2_model_name = "sberbank-ai/rugpt3small_based_on_gpt2" | |
phi_mini_instruct_GGUF_model_name = "bartowski/Phi-3.5-mini-instruct-GGUF" | |
# Инициализация модели | |
try: | |
model_name = phi_mini_instruct_GGUF_model_name # Меньшая модель | |
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") | |
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") | |
logger.info("Модель успешно загружена") | |
except Exception as e: | |
logger.error(f"Ошибка загрузки модели: {str(e)}") | |
raise | |
# Настройка устройства | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
model.to(device) | |
logger.info(f"Используемое устройство: {device}") | |
# Контекст диалога (упрощенная версия) | |
chat_contexts = {} | |
def get_chat_context(chat_id): | |
if chat_id not in chat_contexts: | |
chat_contexts[chat_id] = {"history": []} | |
return chat_contexts[chat_id] | |
MAX_HISTORY_LENGTH = 10 | |
def add_to_chat_history(chat_id, user_input, bot_response): | |
context = get_chat_context(chat_id) | |
context["history"].append({"user": user_input, "bot": bot_response}) | |
if len(context["history"]) > MAX_HISTORY_LENGTH: | |
context["history"] = context["history"][-MAX_HISTORY_LENGTH:] | |
async def start(update: Update, context: CallbackContext) -> None: | |
"""Обработчик команды /start""" | |
await update.message.reply_text('🚀 Привет! Я РУССКИЙ! :) бот.') | |
async def handle_message(update: Update, context: CallbackContext) -> None: | |
"""Обработка текстовых сообщений""" | |
try: | |
user_input = update.message.text | |
chat_id = update.message.chat_id | |
user_name = update.message.from_user.username | |
logger.info(f"Получено сообщение: {user_input}") | |
# Получаем контекст чата | |
context = get_chat_context(chat_id) | |
# Формируем входной текст с учетом истории | |
input_text = "" | |
for msg in context["history"]: | |
input_text += f"Пользователь: {msg['user']}\nБот: {msg['bot']}" | |
tokenizer.pad_token = tokenizer.eos_token | |
# Генерация промта | |
system_prompt = "Ответ должен быть точным и кратким." | |
# system_prompt = "" | |
# prompt = f"{system_prompt} Вопрос: {user_input}; Ответ: " | |
prompt = f"{system_prompt}\n {user_input}\n" | |
logger.info(f"Промт: {prompt}") | |
# Генерация ответа | |
inputs = tokenizer( | |
prompt, | |
return_tensors="pt", # Возвращает PyTorch тензоры | |
# truncation=True, # Обрезает текст, если он превышает max_length | |
# add_special_tokens=True, # Добавляет специальные токены (например, [CLS], [SEP]) | |
).to(device) | |
outputs = model.generate( | |
inputs.input_ids, | |
max_new_tokens=60, | |
no_repeat_ngram_size=3, | |
repetition_penalty=1.5, | |
do_sample=True, | |
top_k=100, | |
top_p=0.3, | |
temperature=0.4, | |
stop_strings=['<s>'], | |
tokenizer=tokenizer, | |
) | |
# Декодирование ответа | |
# response = list(map(tokenizer.decode, outputs))[0] | |
response = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0] | |
logger.info(f"Ответ: {response}") | |
if not response: | |
response = "🤔 Пока не знаю, что ответить. Можете переформулировать вопрос?" | |
# Отправка ответа | |
await update.message.reply_text(response, parse_mode=None) | |
add_to_chat_history(chat_id, user_input, response) | |
except Exception as e: | |
logger.error(f"Ошибка обработки сообщения: {str(e)}") | |
await update.message.reply_text("❌ Произошла ошибка при обработке запроса") | |
def app() -> None: | |
try: | |
application = ApplicationBuilder().token(TOKEN).build() | |
application.add_handler(CommandHandler("start", start)) | |
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message)) | |
application.add_error_handler(error) | |
logger.info("Бот запущен") | |
application.run_polling() | |
except Exception as e: | |
logger.error(f"Ошибка запуска бота: {str(e)}") | |
async def error(update: Update, context: CallbackContext) -> None: | |
logger.error(f'Ошибка: {context.error}') | |
if __name__ == '__app__': | |
app() | |