Xolkin commited on
Commit
4d87bab
·
verified ·
1 Parent(s): 92bf9aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -63
app.py CHANGED
@@ -1,16 +1,18 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
- # Загружаем локальную модель distilgpt2 (более легкая, чем GPT-2)
6
  model_name = "distilgpt2"
7
  try:
8
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
9
- model = AutoModelForCausalLM.from_pretrained(model_name)
10
- # Устанавливаем pad_token, если не задан
11
- if tokenizer.pad_token is None:
12
- tokenizer.pad_token = tokenizer.eos_token
13
- model.eval() # Режим оценки для оптимизации
 
 
14
  except Exception as e:
15
  print(f"Ошибка загрузки модели: {e}")
16
  exit(1)
@@ -23,34 +25,19 @@ def respond(message, history, max_tokens=256, temperature=0.7, top_p=0.9):
23
  input_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"
24
  input_text += f"User: {message}"
25
 
26
- # Токенизация
27
  try:
28
- inputs = tokenizer(
29
  input_text,
30
- return_tensors="pt",
31
- truncation=True,
32
- max_length=512,
33
- padding=True
 
 
 
34
  )
35
- except Exception as e:
36
- return f"Ошибка токенизации: {e}", history
37
-
38
- # Генерация ответа
39
- try:
40
- with torch.no_grad(): # Отключаем градиенты для экономии памяти
41
- outputs = model.generate(
42
- inputs["input_ids"],
43
- max_length=max_tokens,
44
- temperature=temperature,
45
- top_p=top_p,
46
- do_sample=True,
47
- pad_token_id=tokenizer.eos_token_id,
48
- no_repeat_ngram_size=2,
49
- num_beams=2 # Добавляем beam search для лучшего качества
50
- )
51
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
52
- # Удаляем входной текст из ответа
53
- response = response[len(input_text):].strip()
54
  except Exception as e:
55
  return f"Ошибка генерации ответа: {e}", history
56
 
@@ -67,49 +54,33 @@ def format_response(response):
67
  return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"
68
 
69
  def extract_diagnosis(response):
70
- # Простое извлечение диагноза
71
  sentences = response.split(".")
72
  return sentences[0].strip() if sentences else response.strip()
73
 
74
  def extract_operation(response):
75
- # Упрощенная логика: операция не требуется
76
  return "Не требуется"
77
 
78
  def extract_treatment(response):
79
- # Извлечение лечения
80
  sentences = response.split(".")
81
  return sentences[-1].strip() if len(sentences) > 1 else "Не указано"
82
 
83
  # Gradio интерфейс
84
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
85
- gr.Markdown("## Медицинский чат-бот (на базе DistilGPT-2)")
86
- chatbot = gr.Chatbot(label="История чата", height=400)
87
- msg = gr.Textbox(
88
- label="Ваше сообщение",
89
- placeholder="Опишите симптомы (например, 'Болит голова и температура')...",
90
- lines=2
91
- )
92
  with gr.Row():
93
- max_tokens = gr.Slider(
94
- minimum=50,
95
- maximum=512,
96
- value=256,
97
- step=10,
98
- label="Макс. токенов"
99
- )
100
- temperature = gr.Slider(
101
- minimum=0.1,
102
- maximum=1.5,
103
- value=0.7,
104
- label="Температура"
105
- )
106
- top_p = gr.Slider(
107
- minimum=0.1,
108
- maximum=1.0,
109
- value=0.9,
110
- label="Top-p"
111
  )
112
- clear = gr.Button("Очистить чат")
 
 
 
 
 
113
  state = gr.State(value=[])
114
 
115
  def submit_message(message, history, max_tokens, temperature, top_p):
@@ -121,13 +92,22 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
121
  def clear_chat():
122
  return [], [], ""
123
 
 
 
 
 
 
 
 
 
124
  msg.submit(
125
  fn=submit_message,
126
  inputs=[msg, state, max_tokens, temperature, top_p],
127
  outputs=[chatbot, state, msg],
128
  queue=True
129
  )
130
- clear.click(
 
131
  fn=clear_chat,
132
  outputs=[chatbot, state, msg]
133
  )
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
  import torch
4
 
5
+ # Загружаем модель через pipeline (локально, но из Hugging Face Hub)
6
  model_name = "distilgpt2"
7
  try:
8
+ generator = pipeline(
9
+ "text-generation",
10
+ model=model_name,
11
+ device=-1, # -1 означает CPU, подходит для бесплатного Spaces
12
+ framework="pt",
13
+ max_length=512,
14
+ truncation=True
15
+ )
16
  except Exception as e:
17
  print(f"Ошибка загрузки модели: {e}")
18
  exit(1)
 
25
  input_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"
26
  input_text += f"User: {message}"
27
 
28
+ # Генерация ответа через pipeline
29
  try:
30
+ outputs = generator(
31
  input_text,
32
+ max_length=max_tokens,
33
+ temperature=temperature,
34
+ top_p=top_p,
35
+ do_sample=True,
36
+ no_repeat_ngram_size=2,
37
+ pad_token_id=generator.tokenizer.eos_token_id,
38
+ num_return_sequences=1
39
  )
40
+ response = outputs[0]["generated_text"][len(input_text):].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  except Exception as e:
42
  return f"Ошибка генерации ответа: {e}", history
43
 
 
54
  return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"
55
 
56
  def extract_diagnosis(response):
 
57
  sentences = response.split(".")
58
  return sentences[0].strip() if sentences else response.strip()
59
 
60
  def extract_operation(response):
 
61
  return "Не требуется"
62
 
63
  def extract_treatment(response):
 
64
  sentences = response.split(".")
65
  return sentences[-1].strip() if len(sentences) > 1 else "Не указано"
66
 
67
  # Gradio интерфейс
68
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
69
+ gr.Markdown("## Медицинский чат-бот на базе DistilGPT-2")
70
+ chatbot = gr.Chatbot(label="Чат", height=400)
 
 
 
 
 
71
  with gr.Row():
72
+ msg = gr.Textbox(
73
+ label="Ваше сообщение",
74
+ placeholder="Опишите симптомы (например, 'Болит горло')...",
75
+ lines=2,
76
+ show_label=True
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  )
78
+ submit_btn = gr.Button("Отправить", variant="primary")
79
+ with gr.Row():
80
+ max_tokens = gr.Slider(minimum=50, maximum=512, value=256, step=10, label="Макс. токенов")
81
+ temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, label="Температура")
82
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p")
83
+ clear_btn = gr.Button("Очистить чат", variant="secondary")
84
  state = gr.State(value=[])
85
 
86
  def submit_message(message, history, max_tokens, temperature, top_p):
 
92
  def clear_chat():
93
  return [], [], ""
94
 
95
+ # Кнопка "Отправить"
96
+ submit_btn.click(
97
+ fn=submit_message,
98
+ inputs=[msg, state, max_tokens, temperature, top_p],
99
+ outputs=[chatbot, state, msg],
100
+ queue=True
101
+ )
102
+ # Поддержка Enter
103
  msg.submit(
104
  fn=submit_message,
105
  inputs=[msg, state, max_tokens, temperature, top_p],
106
  outputs=[chatbot, state, msg],
107
  queue=True
108
  )
109
+ # Кнопка "Очистить"
110
+ clear_btn.click(
111
  fn=clear_chat,
112
  outputs=[chatbot, state, msg]
113
  )