DoctorAI / app.py
Xolkin's picture
Update app.py
e143294 verified
raw
history blame
4.79 kB
import gradio as gr
from transformers import pipeline
import torch
import logging
# Настройка логирования для диагностики
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Загружаем модель через pipeline (локально из Hugging Face Hub)
model_name = "distilgpt2"
try:
logger.info(f"Попытка загрузки модели {model_name}...")
generator = pipeline(
"text-generation",
model=model_name,
device=-1, # CPU для бесплатного Spaces
framework="pt",
max_length=512,
truncation=True,
model_kwargs={"torch_dtype": torch.float32} # Указываем тип данных для совместимости
)
logger.info("Модель успешно загружена.")
except Exception as e:
logger.error(f"Ошибка загрузки модели: {e}")
exit(1)
def respond(message, history, max_tokens=256, temperature=0.7, top_p=0.9):
history = history or []
# Формируем входной текст
input_text = ""
for user_msg, bot_msg in history:
input_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"
input_text += f"User: {message}"
# Генерация ответа
try:
logger.info(f"Генерация ответа для: {message}")
outputs = generator(
input_text,
max_length=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
no_repeat_ngram_size=2,
num_return_sequences=1
)
response = outputs[0]["generated_text"][len(input_text):].strip()
logger.info(f"Ответ сгенерирован: {response}")
except Exception as e:
logger.error(f"Ошибка генерации ответа: {e}")
return f"Ошибка генерации: {e}", history
# Форматируем ответ
formatted_response = format_response(response)
history.append((message, formatted_response))
return formatted_response, history
def format_response(response):
diagnosis = extract_diagnosis(response)
operation = extract_operation(response)
treatment = extract_treatment(response)
return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"
def extract_diagnosis(response):
sentences = response.split(".")
return sentences[0].strip() if sentences else response.strip()
def extract_operation(response):
return "Не требуется"
def extract_treatment(response):
sentences = response.split(".")
return sentences[-1].strip() if len(sentences) > 1 else "Не указано"
# Gradio интерфейс
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("## Медицинский чат-бот на базе DistilGPT-2")
chatbot = gr.Chatbot(label="Чат", height=400)
with gr.Row():
msg = gr.Textbox(
label="Ваше сообщение",
placeholder="Опишите симптомы (например, 'Болит горло')...",
lines=2,
show_label=True
)
submit_btn = gr.Button("Отправить", variant="primary")
with gr.Row():
max_tokens = gr.Slider(minimum=50, maximum=512, value=256, step=10, label="Макс. токенов")
temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, label="Температура")
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p")
clear_btn = gr.Button("Очистить чат", variant="secondary")
state = gr.State(value=[])
def submit_message(message, history, max_tokens, temperature, top_p):
if not message.strip():
return [], history, "Пожалуйста, введите сообщение."
response, updated_history = respond(message, history, max_tokens, temperature, top_p)
return [(message, response)], updated_history, ""
def clear_chat():
return [], [], ""
# Кнопка "Отправить"
submit_btn.click(
fn=submit_message,
inputs=[msg, state, max_tokens, temperature, top_p],
outputs=[chatbot, state, msg],
queue=True
)
# Поддержка Enter
msg.submit(
fn=submit_message,
inputs=[msg, state, max_tokens, temperature, top_p],
outputs=[chatbot, state, msg],
queue=True
)
# Кнопка "Очистить"
clear_btn.click(
fn=clear_chat,
outputs=[chatbot, state, msg]
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)