|
import gradio as gr |
|
from transformers import pipeline |
|
import torch |
|
import logging |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
model_name = "distilgpt2" |
|
try: |
|
logger.info(f"Попытка загрузки модели {model_name}...") |
|
generator = pipeline( |
|
"text-generation", |
|
model=model_name, |
|
device=-1, |
|
framework="pt", |
|
max_length=512, |
|
truncation=True, |
|
model_kwargs={"torch_dtype": torch.float32} |
|
) |
|
logger.info("Модель успешно загружена.") |
|
except Exception as e: |
|
logger.error(f"Ошибка загрузки модели: {e}") |
|
exit(1) |
|
|
|
def respond(message, history, max_tokens=256, temperature=0.7, top_p=0.9): |
|
history = history or [] |
|
|
|
input_text = "" |
|
for user_msg, bot_msg in history: |
|
input_text += f"User: {user_msg}\nAssistant: {bot_msg}\n" |
|
input_text += f"User: {message}" |
|
|
|
|
|
try: |
|
logger.info(f"Генерация ответа для: {message}") |
|
outputs = generator( |
|
input_text, |
|
max_length=max_tokens, |
|
temperature=temperature, |
|
top_p=top_p, |
|
do_sample=True, |
|
no_repeat_ngram_size=2, |
|
num_return_sequences=1 |
|
) |
|
response = outputs[0]["generated_text"][len(input_text):].strip() |
|
logger.info(f"Ответ сгенерирован: {response}") |
|
except Exception as e: |
|
logger.error(f"Ошибка генерации ответа: {e}") |
|
return f"Ошибка генерации: {e}", history |
|
|
|
|
|
formatted_response = format_response(response) |
|
history.append((message, formatted_response)) |
|
return formatted_response, history |
|
|
|
def format_response(response): |
|
diagnosis = extract_diagnosis(response) |
|
operation = extract_operation(response) |
|
treatment = extract_treatment(response) |
|
return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}" |
|
|
|
def extract_diagnosis(response): |
|
sentences = response.split(".") |
|
return sentences[0].strip() if sentences else response.strip() |
|
|
|
def extract_operation(response): |
|
return "Не требуется" |
|
|
|
def extract_treatment(response): |
|
sentences = response.split(".") |
|
return sentences[-1].strip() if len(sentences) > 1 else "Не указано" |
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft()) as demo: |
|
gr.Markdown("## Медицинский чат-бот на базе DistilGPT-2") |
|
chatbot = gr.Chatbot(label="Чат", height=400) |
|
with gr.Row(): |
|
msg = gr.Textbox( |
|
label="Ваше сообщение", |
|
placeholder="Опишите симптомы (например, 'Болит горло')...", |
|
lines=2, |
|
show_label=True |
|
) |
|
submit_btn = gr.Button("Отправить", variant="primary") |
|
with gr.Row(): |
|
max_tokens = gr.Slider(minimum=50, maximum=512, value=256, step=10, label="Макс. токенов") |
|
temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, label="Температура") |
|
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p") |
|
clear_btn = gr.Button("Очистить чат", variant="secondary") |
|
state = gr.State(value=[]) |
|
|
|
def submit_message(message, history, max_tokens, temperature, top_p): |
|
if not message.strip(): |
|
return [], history, "Пожалуйста, введите сообщение." |
|
response, updated_history = respond(message, history, max_tokens, temperature, top_p) |
|
return [(message, response)], updated_history, "" |
|
|
|
def clear_chat(): |
|
return [], [], "" |
|
|
|
|
|
submit_btn.click( |
|
fn=submit_message, |
|
inputs=[msg, state, max_tokens, temperature, top_p], |
|
outputs=[chatbot, state, msg], |
|
queue=True |
|
) |
|
|
|
msg.submit( |
|
fn=submit_message, |
|
inputs=[msg, state, max_tokens, temperature, top_p], |
|
outputs=[chatbot, state, msg], |
|
queue=True |
|
) |
|
|
|
clear_btn.click( |
|
fn=clear_chat, |
|
outputs=[chatbot, state, msg] |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.launch(server_name="0.0.0.0", server_port=7860) |