File size: 4,791 Bytes
c3dfc09 4d87bab 2edf1f0 e143294 c3dfc09 e143294 92bf9aa 2edf1f0 e143294 4d87bab e143294 4d87bab e143294 4d87bab e143294 2edf1f0 e143294 2edf1f0 c3dfc09 92bf9aa 2edf1f0 e143294 92bf9aa ca1d8ee e143294 2f225f8 e143294 4d87bab 92bf9aa 4d87bab 92bf9aa 4d87bab e143294 2edf1f0 e143294 2edf1f0 e1f2405 133324c 2edf1f0 133324c 92bf9aa 133324c 2edf1f0 133324c 92bf9aa 133324c 92bf9aa 4d87bab 92bf9aa 4d87bab 92bf9aa 4d87bab 2f225f8 2edf1f0 2f225f8 92bf9aa 2f225f8 2edf1f0 2f225f8 2edf1f0 4d87bab 92bf9aa 4d87bab 92bf9aa c3dfc09 133324c 2edf1f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import gradio as gr
from transformers import pipeline
import torch
import logging
# Настройка логирования для диагностики
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Загружаем модель через pipeline (локально из Hugging Face Hub)
model_name = "distilgpt2"
try:
logger.info(f"Попытка загрузки модели {model_name}...")
generator = pipeline(
"text-generation",
model=model_name,
device=-1, # CPU для бесплатного Spaces
framework="pt",
max_length=512,
truncation=True,
model_kwargs={"torch_dtype": torch.float32} # Указываем тип данных для совместимости
)
logger.info("Модель успешно загружена.")
except Exception as e:
logger.error(f"Ошибка загрузки модели: {e}")
exit(1)
def respond(message, history, max_tokens=256, temperature=0.7, top_p=0.9):
history = history or []
# Формируем входной текст
input_text = ""
for user_msg, bot_msg in history:
input_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"
input_text += f"User: {message}"
# Генерация ответа
try:
logger.info(f"Генерация ответа для: {message}")
outputs = generator(
input_text,
max_length=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
no_repeat_ngram_size=2,
num_return_sequences=1
)
response = outputs[0]["generated_text"][len(input_text):].strip()
logger.info(f"Ответ сгенерирован: {response}")
except Exception as e:
logger.error(f"Ошибка генерации ответа: {e}")
return f"Ошибка генерации: {e}", history
# Форматируем ответ
formatted_response = format_response(response)
history.append((message, formatted_response))
return formatted_response, history
def format_response(response):
diagnosis = extract_diagnosis(response)
operation = extract_operation(response)
treatment = extract_treatment(response)
return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"
def extract_diagnosis(response):
sentences = response.split(".")
return sentences[0].strip() if sentences else response.strip()
def extract_operation(response):
return "Не требуется"
def extract_treatment(response):
sentences = response.split(".")
return sentences[-1].strip() if len(sentences) > 1 else "Не указано"
# Gradio интерфейс
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("## Медицинский чат-бот на базе DistilGPT-2")
chatbot = gr.Chatbot(label="Чат", height=400)
with gr.Row():
msg = gr.Textbox(
label="Ваше сообщение",
placeholder="Опишите симптомы (например, 'Болит горло')...",
lines=2,
show_label=True
)
submit_btn = gr.Button("Отправить", variant="primary")
with gr.Row():
max_tokens = gr.Slider(minimum=50, maximum=512, value=256, step=10, label="Макс. токенов")
temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, label="Температура")
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p")
clear_btn = gr.Button("Очистить чат", variant="secondary")
state = gr.State(value=[])
def submit_message(message, history, max_tokens, temperature, top_p):
if not message.strip():
return [], history, "Пожалуйста, введите сообщение."
response, updated_history = respond(message, history, max_tokens, temperature, top_p)
return [(message, response)], updated_history, ""
def clear_chat():
return [], [], ""
# Кнопка "Отправить"
submit_btn.click(
fn=submit_message,
inputs=[msg, state, max_tokens, temperature, top_p],
outputs=[chatbot, state, msg],
queue=True
)
# Поддержка Enter
msg.submit(
fn=submit_message,
inputs=[msg, state, max_tokens, temperature, top_p],
outputs=[chatbot, state, msg],
queue=True
)
# Кнопка "Очистить"
clear_btn.click(
fn=clear_chat,
outputs=[chatbot, state, msg]
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860) |