Xolkin commited on
Commit
92bf9aa
·
verified ·
1 Parent(s): 2f225f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -30
app.py CHANGED
@@ -2,43 +2,57 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
- # Загружаем модель GPT-2 локально
6
- model_name = "gpt2"
7
  try:
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(model_name)
10
- # Устанавливаем pad_token, если он не задан
11
  if tokenizer.pad_token is None:
12
  tokenizer.pad_token = tokenizer.eos_token
 
13
  except Exception as e:
14
  print(f"Ошибка загрузки модели: {e}")
15
  exit(1)
16
 
17
- def respond(message, history, max_tokens=512, temperature=0.7, top_p=0.95):
18
  history = history or []
19
- # Формируем историю чата
20
- input_text = "\n".join([f"User: {msg[0]}\nAssistant: {msg[1]}" for msg in history] + [f"User: {message}"])
 
 
 
21
 
22
  # Токенизация
23
  try:
24
- inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
 
 
 
 
 
 
25
  except Exception as e:
26
  return f"Ошибка токенизации: {e}", history
27
 
28
  # Генерация ответа
29
  try:
30
- outputs = model.generate(
31
- inputs["input_ids"],
32
- max_length=max_tokens,
33
- temperature=temperature,
34
- top_p=top_p,
35
- do_sample=True,
36
- pad_token_id=tokenizer.eos_token_id,
37
- no_repeat_ngram_size=2
38
- )
 
 
39
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
40
  except Exception as e:
41
- return f"Ошибка генерации: {e}", history
42
 
43
  # Форматируем ответ
44
  formatted_response = format_response(response)
@@ -53,34 +67,70 @@ def format_response(response):
53
  return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"
54
 
55
  def extract_diagnosis(response):
56
- return response.split(".")[0].strip() if "." in response else response.strip()
 
 
57
 
58
  def extract_operation(response):
 
59
  return "Не требуется"
60
 
61
  def extract_treatment(response):
62
- return response.split(".")[-1].strip() if "." in response else "Не указано"
 
 
63
 
64
- # Создаем Gradio интерфейс
65
- with gr.Blocks() as demo:
66
- gr.Markdown("## Медицинский чат-бот на базе GPT-2")
67
- chatbot = gr.Chatbot(label="Чат")
68
- msg = gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы...")
69
- max_tokens = gr.Slider(minimum=50, maximum=1024, value=512, step=1, label="Максимальная длина ответа")
70
- temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Температура")
71
- top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  clear = gr.Button("Очистить чат")
73
  state = gr.State(value=[])
74
 
75
  def submit_message(message, history, max_tokens, temperature, top_p):
 
 
76
  response, updated_history = respond(message, history, max_tokens, temperature, top_p)
77
  return [(message, response)], updated_history, ""
78
 
79
  def clear_chat():
80
  return [], [], ""
81
 
82
- msg.submit(submit_message, [msg, state, max_tokens, temperature, top_p], [chatbot, state, msg])
83
- clear.click(clear_chat, outputs=[chatbot, state, msg])
 
 
 
 
 
 
 
 
84
 
85
  if __name__ == "__main__":
86
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
+ # Загружаем локальную модель distilgpt2 (более легкая, чем GPT-2)
6
+ model_name = "distilgpt2"
7
  try:
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
9
  model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ # Устанавливаем pad_token, если не задан
11
  if tokenizer.pad_token is None:
12
  tokenizer.pad_token = tokenizer.eos_token
13
+ model.eval() # Режим оценки для оптимизации
14
  except Exception as e:
15
  print(f"Ошибка загрузки модели: {e}")
16
  exit(1)
17
 
18
+ def respond(message, history, max_tokens=256, temperature=0.7, top_p=0.9):
19
  history = history or []
20
+ # Формируем входной текст с историей
21
+ input_text = ""
22
+ for user_msg, bot_msg in history:
23
+ input_text += f"User: {user_msg}\nAssistant: {bot_msg}\n"
24
+ input_text += f"User: {message}"
25
 
26
  # Токенизация
27
  try:
28
+ inputs = tokenizer(
29
+ input_text,
30
+ return_tensors="pt",
31
+ truncation=True,
32
+ max_length=512,
33
+ padding=True
34
+ )
35
  except Exception as e:
36
  return f"Ошибка токенизации: {e}", history
37
 
38
  # Генерация ответа
39
  try:
40
+ with torch.no_grad(): # Отключаем градиенты для экономии памяти
41
+ outputs = model.generate(
42
+ inputs["input_ids"],
43
+ max_length=max_tokens,
44
+ temperature=temperature,
45
+ top_p=top_p,
46
+ do_sample=True,
47
+ pad_token_id=tokenizer.eos_token_id,
48
+ no_repeat_ngram_size=2,
49
+ num_beams=2 # Добавляем beam search для лучшего качества
50
+ )
51
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
52
+ # Удаляем входной текст из ответа
53
+ response = response[len(input_text):].strip()
54
  except Exception as e:
55
+ return f"Ошибка генерации ответа: {e}", history
56
 
57
  # Форматируем ответ
58
  formatted_response = format_response(response)
 
67
  return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"
68
 
69
  def extract_diagnosis(response):
70
+ # Простое извлечение диагноза
71
+ sentences = response.split(".")
72
+ return sentences[0].strip() if sentences else response.strip()
73
 
74
  def extract_operation(response):
75
+ # Упрощенная логика: операция не требуется
76
  return "Не требуется"
77
 
78
  def extract_treatment(response):
79
+ # Извлечение лечения
80
+ sentences = response.split(".")
81
+ return sentences[-1].strip() if len(sentences) > 1 else "Не указано"
82
 
83
+ # Gradio интерфейс
84
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
85
+ gr.Markdown("## Медицинский чат-бот (на базе DistilGPT-2)")
86
+ chatbot = gr.Chatbot(label="История чата", height=400)
87
+ msg = gr.Textbox(
88
+ label="Ваше со��бщение",
89
+ placeholder="Опишите симптомы (например, 'Болит голова и температура')...",
90
+ lines=2
91
+ )
92
+ with gr.Row():
93
+ max_tokens = gr.Slider(
94
+ minimum=50,
95
+ maximum=512,
96
+ value=256,
97
+ step=10,
98
+ label="Макс. токенов"
99
+ )
100
+ temperature = gr.Slider(
101
+ minimum=0.1,
102
+ maximum=1.5,
103
+ value=0.7,
104
+ label="Температура"
105
+ )
106
+ top_p = gr.Slider(
107
+ minimum=0.1,
108
+ maximum=1.0,
109
+ value=0.9,
110
+ label="Top-p"
111
+ )
112
  clear = gr.Button("Очистить чат")
113
  state = gr.State(value=[])
114
 
115
  def submit_message(message, history, max_tokens, temperature, top_p):
116
+ if not message.strip():
117
+ return [], history, "Пожалуйста, введите сообщение."
118
  response, updated_history = respond(message, history, max_tokens, temperature, top_p)
119
  return [(message, response)], updated_history, ""
120
 
121
  def clear_chat():
122
  return [], [], ""
123
 
124
+ msg.submit(
125
+ fn=submit_message,
126
+ inputs=[msg, state, max_tokens, temperature, top_p],
127
+ outputs=[chatbot, state, msg],
128
+ queue=True
129
+ )
130
+ clear.click(
131
+ fn=clear_chat,
132
+ outputs=[chatbot, state, msg]
133
+ )
134
 
135
  if __name__ == "__main__":
136
  demo.launch(server_name="0.0.0.0", server_port=7860)