File size: 4,276 Bytes
c3dfc09
4d87bab
2edf1f0
c3dfc09
4d87bab
92bf9aa
2edf1f0
4d87bab
 
 
 
 
 
 
 
2edf1f0
 
 
c3dfc09
92bf9aa
2edf1f0
92bf9aa
 
 
 
 
ca1d8ee
4d87bab
2f225f8
4d87bab
92bf9aa
4d87bab
 
 
 
 
 
 
92bf9aa
4d87bab
2edf1f0
92bf9aa
2edf1f0
 
 
 
 
 
e1f2405
133324c
 
 
 
2edf1f0
133324c
 
92bf9aa
 
133324c
 
2edf1f0
133324c
 
92bf9aa
 
133324c
92bf9aa
 
4d87bab
 
92bf9aa
4d87bab
 
 
 
 
92bf9aa
4d87bab
 
 
 
 
 
2f225f8
2edf1f0
2f225f8
92bf9aa
 
2f225f8
 
2edf1f0
2f225f8
 
2edf1f0
4d87bab
 
 
 
 
 
 
 
92bf9aa
 
 
 
 
 
4d87bab
 
92bf9aa
 
 
c3dfc09
133324c
2edf1f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
from transformers import pipeline
import torch

# Загружаем модель через pipeline (локально, но из Hugging Face Hub)
model_name = "distilgpt2"
try:
    generator = pipeline(
        "text-generation",
        model=model_name,
        device=-1,  # -1 означает CPU, подходит для бесплатного Spaces
        framework="pt",
        max_length=512,
        truncation=True
    )
except Exception as e:
    print(f"Ошибка загрузки модели: {e}")
    exit(1)

def respond(message, history, max_tokens=256, temperature=0.7, top_p=0.9):
    history = history or []
    # Формируем входной текст с историей
    input_text = ""
    for user_msg, bot_msg in history:
        input_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    input_text += f"User: {message}"

    # Генерация ответа через pipeline
    try:
        outputs = generator(
            input_text,
            max_length=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            no_repeat_ngram_size=2,
            pad_token_id=generator.tokenizer.eos_token_id,
            num_return_sequences=1
        )
        response = outputs[0]["generated_text"][len(input_text):].strip()
    except Exception as e:
        return f"Ошибка генерации ответа: {e}", history

    # Форматируем ответ
    formatted_response = format_response(response)
    history.append((message, formatted_response))

    return formatted_response, history

def format_response(response):
    diagnosis = extract_diagnosis(response)
    operation = extract_operation(response)
    treatment = extract_treatment(response)
    return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"

def extract_diagnosis(response):
    sentences = response.split(".")
    return sentences[0].strip() if sentences else response.strip()

def extract_operation(response):
    return "Не требуется"

def extract_treatment(response):
    sentences = response.split(".")
    return sentences[-1].strip() if len(sentences) > 1 else "Не указано"

# Gradio интерфейс
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Медицинский чат-бот на базе DistilGPT-2")
    chatbot = gr.Chatbot(label="Чат", height=400)
    with gr.Row():
        msg = gr.Textbox(
            label="Ваше сообщение",
            placeholder="Опишите симптомы (например, 'Болит горло')...",
            lines=2,
            show_label=True
        )
        submit_btn = gr.Button("Отправить", variant="primary")
    with gr.Row():
        max_tokens = gr.Slider(minimum=50, maximum=512, value=256, step=10, label="Макс. токенов")
        temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, label="Температура")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p")
    clear_btn = gr.Button("Очистить чат", variant="secondary")
    state = gr.State(value=[])

    def submit_message(message, history, max_tokens, temperature, top_p):
        if not message.strip():
            return [], history, "Пожалуйста, введите сообщение."
        response, updated_history = respond(message, history, max_tokens, temperature, top_p)
        return [(message, response)], updated_history, ""

    def clear_chat():
        return [], [], ""

    # Кнопка "Отправить"
    submit_btn.click(
        fn=submit_message,
        inputs=[msg, state, max_tokens, temperature, top_p],
        outputs=[chatbot, state, msg],
        queue=True
    )
    # Поддержка Enter
    msg.submit(
        fn=submit_message,
        inputs=[msg, state, max_tokens, temperature, top_p],
        outputs=[chatbot, state, msg],
        queue=True
    )
    # Кнопка "Очистить"
    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, state, msg]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)