File size: 4,791 Bytes
c3dfc09
4d87bab
2edf1f0
e143294
c3dfc09
e143294
 
 
 
 
92bf9aa
2edf1f0
e143294
4d87bab
 
 
e143294
4d87bab
 
e143294
 
4d87bab
e143294
2edf1f0
e143294
2edf1f0
c3dfc09
92bf9aa
2edf1f0
e143294
92bf9aa
 
 
 
ca1d8ee
e143294
2f225f8
e143294
4d87bab
92bf9aa
4d87bab
 
 
 
 
 
92bf9aa
4d87bab
e143294
2edf1f0
e143294
 
2edf1f0
 
 
 
 
e1f2405
133324c
 
 
 
2edf1f0
133324c
 
92bf9aa
 
133324c
 
2edf1f0
133324c
 
92bf9aa
 
133324c
92bf9aa
 
4d87bab
 
92bf9aa
4d87bab
 
 
 
 
92bf9aa
4d87bab
 
 
 
 
 
2f225f8
2edf1f0
2f225f8
92bf9aa
 
2f225f8
 
2edf1f0
2f225f8
 
2edf1f0
4d87bab
 
 
 
 
 
 
 
92bf9aa
 
 
 
 
 
4d87bab
 
92bf9aa
 
 
c3dfc09
133324c
2edf1f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
from transformers import pipeline
import torch
import logging

# Настройка логирования для диагностики
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Загружаем модель через pipeline (локально из Hugging Face Hub)
model_name = "distilgpt2"
try:
    logger.info(f"Попытка загрузки модели {model_name}...")
    generator = pipeline(
        "text-generation",
        model=model_name,
        device=-1,  # CPU для бесплатного Spaces
        framework="pt",
        max_length=512,
        truncation=True,
        model_kwargs={"torch_dtype": torch.float32}  # Указываем тип данных для совместимости
    )
    logger.info("Модель успешно загружена.")
except Exception as e:
    logger.error(f"Ошибка загрузки модели: {e}")
    exit(1)

def respond(message, history, max_tokens=256, temperature=0.7, top_p=0.9):
    history = history or []
    # Формируем входной текст
    input_text = ""
    for user_msg, bot_msg in history:
        input_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    input_text += f"User: {message}"

    # Генерация ответа
    try:
        logger.info(f"Генерация ответа для: {message}")
        outputs = generator(
            input_text,
            max_length=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            no_repeat_ngram_size=2,
            num_return_sequences=1
        )
        response = outputs[0]["generated_text"][len(input_text):].strip()
        logger.info(f"Ответ сгенерирован: {response}")
    except Exception as e:
        logger.error(f"Ошибка генерации ответа: {e}")
        return f"Ошибка генерации: {e}", history

    # Форматируем ответ
    formatted_response = format_response(response)
    history.append((message, formatted_response))
    return formatted_response, history

def format_response(response):
    diagnosis = extract_diagnosis(response)
    operation = extract_operation(response)
    treatment = extract_treatment(response)
    return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"

def extract_diagnosis(response):
    sentences = response.split(".")
    return sentences[0].strip() if sentences else response.strip()

def extract_operation(response):
    return "Не требуется"

def extract_treatment(response):
    sentences = response.split(".")
    return sentences[-1].strip() if len(sentences) > 1 else "Не указано"

# Gradio интерфейс
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Медицинский чат-бот на базе DistilGPT-2")
    chatbot = gr.Chatbot(label="Чат", height=400)
    with gr.Row():
        msg = gr.Textbox(
            label="Ваше сообщение",
            placeholder="Опишите симптомы (например, 'Болит горло')...",
            lines=2,
            show_label=True
        )
        submit_btn = gr.Button("Отправить", variant="primary")
    with gr.Row():
        max_tokens = gr.Slider(minimum=50, maximum=512, value=256, step=10, label="Макс. токенов")
        temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, label="Температура")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p")
    clear_btn = gr.Button("Очистить чат", variant="secondary")
    state = gr.State(value=[])

    def submit_message(message, history, max_tokens, temperature, top_p):
        if not message.strip():
            return [], history, "Пожалуйста, введите сообщение."
        response, updated_history = respond(message, history, max_tokens, temperature, top_p)
        return [(message, response)], updated_history, ""

    def clear_chat():
        return [], [], ""

    # Кнопка "Отправить"
    submit_btn.click(
        fn=submit_message,
        inputs=[msg, state, max_tokens, temperature, top_p],
        outputs=[chatbot, state, msg],
        queue=True
    )
    # Поддержка Enter
    msg.submit(
        fn=submit_message,
        inputs=[msg, state, max_tokens, temperature, top_p],
        outputs=[chatbot, state, msg],
        queue=True
    )
    # Кнопка "Очистить"
    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, state, msg]
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)