import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
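
# Minimal Gradio chat demo that runs GPT-2 locally as a toy "medical
# chatbot". The model's free-form output is split by simple heuristics
# into diagnosis / operation / treatment fields for display.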

# Load the GPT-2 model locally
model_name = "gpt2"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
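    # First run downloads the gpt2 weights (roughly 500 MB) from the
    # Hugging Face Hub; later runs load them from the local cache.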
    # Set a pad_token if the tokenizer does not define one
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
except Exception as e:
    print(f"Ошибка загрузки модели: {e}")
    exit(1)

def respond(message, history, max_tokens=512, temperature=0.7, top_p=0.95):
    history = history or []
    # Build the prompt from prior turns plus the new message, ending with
    # an "Assistant:" cue so the model continues as the assistant
    input_text = "\n".join(
        [f"User: {msg[0]}\nAssistant: {msg[1]}" for msg in history]
        + [f"User: {message}\nAssistant:"]
    )

    # Tokenize the prompt, truncating to fit the model's context window
    try:
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
    except Exception as e:
        error = f"Tokenization error: {e}"
        history.append((message, error))
        return error, history

    # Generate a continuation
    try:
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=2
        )
        # Decode only the newly generated tokens, not the echoed prompt
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        )
    except Exception as e:
        error = f"Generation error: {e}"
        history.append((message, error))
        return error, history

    # Format the reply into the three display fields
    formatted_response = format_response(response)
    history.append((message, formatted_response))

    return formatted_response, history

def format_response(response):
    diagnosis = extract_diagnosis(response)
    operation = extract_operation(response)
    treatment = extract_treatment(response)
    return f"Preliminary diagnosis: {diagnosis}\nOperation: {operation}\nTreatment: {treatment}"

def extract_diagnosis(response):
    # Naive heuristic: treat the first sentence as the diagnosis
    return response.split(".")[0].strip() if "." in response else response.strip()

def extract_operation(response):
    # Placeholder: the demo never recommends surgery
    return "Not required"

def extract_treatment(response):
    # Naive heuristic: treat the last sentence as the treatment
    return response.split(".")[-1].strip() if "." in response else "Not specified"

# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## GPT-2 Medical Chatbot")
    chatbot = gr.Chatbot(label="Chat")
    msg = gr.Textbox(label="Your message", placeholder="Describe your symptoms...")
    max_tokens = gr.Slider(minimum=50, maximum=1024, value=512, step=1, label="Maximum response length")
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p")
    clear = gr.Button("Clear chat")
    state = gr.State(value=[])

    def submit_message(message, history, max_tokens, temperature, top_p):
        _, updated_history = respond(message, history, max_tokens, temperature, top_p)
        # Show the full accumulated history, not just the latest exchange
        return updated_history, updated_history, ""

    def clear_chat():
        return [], [], ""

    msg.submit(submit_message, [msg, state, max_tokens, temperature, top_p], [chatbot, state, msg])
    clear.click(clear_chat, outputs=[chatbot, state, msg])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
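
# To try this locally (assumed setup, not part of the file): install the
# imports above with `pip install gradio transformers torch`, run this
# file with Python, and open http://localhost:7860 in a browser.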