DoctorAI / app.py
import gradio as gr
from transformers import pipeline
import torch
# Load the model via a pipeline (runs locally, weights pulled from the Hugging Face Hub)
model_name = "distilgpt2"
try:
    generator = pipeline(
        "text-generation",
        model=model_name,
        device=-1,  # -1 means CPU, which fits the free Spaces hardware
        framework="pt",
        max_length=512,
        truncation=True
    )
except Exception as e:
    print(f"Failed to load the model: {e}")
    exit(1)
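
# respond() rebuilds the prompt from the accumulated (user, assistant) turns,
# asks the model for a continuation, and returns the formatted reply together
# with the updated history.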
def respond(message, history, max_tokens=256, temperature=0.7, top_p=0.9):
    history = history or []
    # Build the input text from the conversation history
    input_text = ""
    for user_msg, bot_msg in history:
        input_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    input_text += f"User: {message}\nAssistant:"
    # Generate a reply with the pipeline
    try:
        outputs = generator(
            input_text,
            max_new_tokens=max_tokens,  # cap only the newly generated tokens, not the whole prompt
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            no_repeat_ngram_size=2,
            pad_token_id=generator.tokenizer.eos_token_id,
            num_return_sequences=1
        )
        # The pipeline echoes the prompt, so strip it to keep only the new text
        response = outputs[0]["generated_text"][len(input_text):].strip()
    except Exception as e:
        return f"Failed to generate a reply: {e}", history
    # Format the reply
    formatted_response = format_response(response)
    history.append((message, formatted_response))
    return formatted_response, history
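
# The helpers below turn the raw model output into a fixed report layout
# (diagnosis / surgery / treatment) using simple sentence-splitting heuristics.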
def format_response(response):
    diagnosis = extract_diagnosis(response)
    operation = extract_operation(response)
    treatment = extract_treatment(response)
    return f"Preliminary diagnosis: {diagnosis}\nSurgery: {operation}\nTreatment: {treatment}"
def extract_diagnosis(response):
    # Treat the first sentence as the diagnosis
    sentences = response.split(".")
    return sentences[0].strip() if sentences else response.strip()

def extract_operation(response):
    return "Not required"

def extract_treatment(response):
    # Treat the last sentence as the treatment recommendation
    sentences = response.split(".")
    return sentences[-1].strip() if len(sentences) > 1 else "Not specified"
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Medical chatbot based on DistilGPT-2")
    chatbot = gr.Chatbot(label="Chat", height=400)
    with gr.Row():
        msg = gr.Textbox(
            label="Your message",
            placeholder="Describe your symptoms (e.g. 'I have a sore throat')...",
            lines=2,
            show_label=True
        )
        submit_btn = gr.Button("Send", variant="primary")
    with gr.Row():
        max_tokens = gr.Slider(minimum=50, maximum=512, value=256, step=10, label="Max tokens")
        temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p")
        clear_btn = gr.Button("Clear chat", variant="secondary")
    state = gr.State(value=[])  # holds the (user, assistant) history between turns
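
    # UI callbacks: submit_message passes the textbox contents and slider values to respond()
    # and pushes the updated history to both the Chatbot display and the State component.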
    def submit_message(message, history, max_tokens, temperature, top_p):
        if not message.strip():
            return history, history, "Please enter a message."
        response, updated_history = respond(message, history, max_tokens, temperature, top_p)
        # Show the full updated history, not just the latest exchange
        return updated_history, updated_history, ""

    def clear_chat():
        return [], [], ""
    # "Send" button
    submit_btn.click(
        fn=submit_message,
        inputs=[msg, state, max_tokens, temperature, top_p],
        outputs=[chatbot, state, msg],
        queue=True
    )
    # Submit on Enter
    msg.submit(
        fn=submit_message,
        inputs=[msg, state, max_tokens, temperature, top_p],
        outputs=[chatbot, state, msg],
        queue=True
    )
    # "Clear chat" button
    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, state, msg]
    )
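
# Bind to all interfaces on port 7860, the port Hugging Face Spaces expects by default.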
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)