import sys

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


# Hugging Face Hub identifier of the causal language model used for replies.
model_name = "gpt2"

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Some tokenizers (GPT-2 among them) define no pad token; reuse EOS so
    # that anything relying on a pad token has one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
except Exception as e:
    # The app is useless without a model, so abort. sys.exit (unlike the
    # interactive-only exit() helper injected by `site`) is always available;
    # passing a string prints it to stderr and exits with status 1.
    sys.exit(f"Ошибка загрузки модели: {e}")


def respond(message, history, max_tokens=512, temperature=0.7, top_p=0.95, *,
            max_prompt_tokens=512):
    """Generate a formatted reply to *message* given the earlier chat turns.

    The conversation is rendered as a plain-text ``User:`` / ``Assistant:``
    transcript ending with an ``Assistant:`` cue. The module-level ``model``
    continues it, and the newly generated text is passed through
    ``format_response``.

    Args:
        message: The new user message.
        history: Earlier turns as ``(user_text, bot_text)`` pairs, or
            ``None``/empty for a fresh conversation. On success the new
            turn is appended to this list in place.
        max_tokens: Maximum number of newly generated tokens (the prompt
            is not counted). It is reduced if needed so that prompt plus
            reply fit the model's context window.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        max_prompt_tokens: Maximum prompt length in tokens. Older text is
            dropped from the left so the newest message is always kept.

    Returns:
        ``(text, history)``. ``text`` is the formatted reply, or an error
        message if tokenization or generation failed; in the error case
        ``history`` is returned without the new turn.
    """
    history = history or []

    # End with an "Assistant:" cue so the model answers instead of
    # continuing the user's message.
    turns = [f"User: {msg[0]}\nAssistant: {msg[1]}" for msg in history]
    turns.append(f"User: {message}\nAssistant:")
    input_text = "\n".join(turns)

    try:
        inputs = tokenizer(input_text, return_tensors="pt")
    except Exception as e:
        return f"Ошибка токенизации: {e}", history

    # Context window of the model; falls back to 1024 (GPT-2's size) when
    # the config does not expose it.
    context_size = getattr(model.config, "max_position_embeddings", None) or 1024
    prompt_budget = max(1, min(int(max_prompt_tokens), context_size - 1))

    # Truncate from the LEFT: the most recent turns and the new message
    # matter most, so the oldest text is what gets dropped.
    input_ids = inputs["input_ids"][:, -prompt_budget:]
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask[:, -prompt_budget:]
    prompt_len = input_ids.shape[-1]

    # max_new_tokens (unlike max_length) does not count the prompt; cap it
    # so prompt + reply never exceed the context window.
    max_new_tokens = max(1, min(int(max_tokens), context_size - prompt_len))

    try:
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=2,
        )
        # generate() returns prompt + continuation; decode only the new part
        # so the reply is not parsed out of the echoed prompt.
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    except Exception as e:
        return f"Ошибка генерации: {e}", history

    # The model tends to invent further "User:" turns; keep only its answer.
    response = response.split("\nUser:", 1)[0].strip()

    formatted_response = format_response(response)
    history.append((message, formatted_response))
    return formatted_response, history


def format_response(response):
    """Lay out generated text as a three-line report.

    The diagnosis, operation and treatment fields are obtained (in that
    order) from their dedicated extractor functions and placed on separate
    labelled lines.
    """
    diagnosis, operation, treatment = (
        extractor(response)
        for extractor in (extract_diagnosis, extract_operation, extract_treatment)
    )
    return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"


def extract_diagnosis(response):
    """Return the first non-empty '.'-separated sentence of *response*.

    Surrounding whitespace is stripped. Leading dots (e.g. an ellipsis at the
    start of generated text) are skipped instead of yielding an empty
    diagnosis. Returns "" when *response* contains no text.
    """
    for sentence in response.split("."):
        sentence = sentence.strip()
        if sentence:
            return sentence
    return ""


def extract_operation(response):
    """Return the "operation" field of the report.

    Placeholder implementation: *response* is ignored and the fixed string
    "Не требуется" ("not required") is always returned.
    """
    del response  # deliberately unused until real extraction is implemented
    return "Не требуется"


def extract_treatment(response):
    """Return the last non-empty sentence that follows the first '.' in *response*.

    Sentences are split on '.', and whitespace is stripped. A trailing period
    (the usual case) no longer produces an empty treatment. Returns
    "Не указано" ("not specified") when *response* has no '.' or nothing
    non-empty follows the first sentence.
    """
    _, dot, rest = response.partition(".")
    if not dot:
        return "Не указано"
    tail = [sentence.strip() for sentence in rest.split(".") if sentence.strip()]
    return tail[-1] if tail else "Не указано"


with gr.Blocks() as demo:
    gr.Markdown("## Медицинский чат-бот на базе GPT-2")
    chatbot = gr.Chatbot(label="Чат")
    msg = gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы...")
    max_tokens = gr.Slider(
        minimum=50, maximum=1024, value=512, step=1, label="Максимальная длина ответа"
    )
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Температура")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p")
    clear = gr.Button("Очистить чат")
    # Stored conversation: list of (user_message, bot_reply) pairs passed to respond().
    state = gr.State(value=[])

    def submit_message(message, history, new_tokens, temp, nucleus_p):
        """Run one chat turn and return (chat display, new state, cleared textbox)."""
        response, updated_history = respond(message, history, new_tokens, temp, nucleus_p)
        # Display the whole conversation, not only the latest exchange.
        shown = list(updated_history)
        # respond() does not store a failed turn; still show its error message.
        if not shown or tuple(shown[-1]) != (message, response):
            shown.append((message, response))
        return shown, updated_history, ""

    def clear_chat():
        """Reset the chat display, the stored history and the input box."""
        return [], [], ""

    msg.submit(
        submit_message,
        [msg, state, max_tokens, temperature, top_p],
        [chatbot, state, msg],
    )
    clear.click(clear_chat, outputs=[chatbot, state, msg])


if __name__ == "__main__":
    # Serve the UI on port 7860. "0.0.0.0" binds every network interface,
    # so the app is reachable from other hosts — NOTE(review): confirm that
    # this exposure is intended outside a container.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )