Xolkin commited on
Commit
2f225f8
·
verified ·
1 Parent(s): 2edf1f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -45
app.py CHANGED
@@ -2,22 +2,28 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
- # Загружаем модель GPT-2
6
  model_name = "gpt2"
7
  try:
8
- model = AutoModelForCausalLM.from_pretrained(model_name)
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
10
  except Exception as e:
11
  print(f"Ошибка загрузки модели: {e}")
12
  exit(1)
13
 
14
  def respond(message, history, max_tokens=512, temperature=0.7, top_p=0.95):
15
- # Формируем историю чата
16
  history = history or []
 
17
  input_text = "\n".join([f"User: {msg[0]}\nAssistant: {msg[1]}" for msg in history] + [f"User: {message}"])
18
 
19
  # Токенизация
20
- inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
 
 
 
21
 
22
  # Генерация ответа
23
  try:
@@ -32,72 +38,49 @@ def respond(message, history, max_tokens=512, temperature=0.7, top_p=0.95):
32
  )
33
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
34
  except Exception as e:
35
- return f"Ошибка генерации ответа: {e}", history
36
 
37
  # Форматируем ответ
38
  formatted_response = format_response(response)
39
-
40
- # Обновляем историю
41
  history.append((message, formatted_response))
42
 
43
  return formatted_response, history
44
 
45
  def format_response(response):
46
- # Упрощенное форматирование ответа
47
  diagnosis = extract_diagnosis(response)
48
  operation = extract_operation(response)
49
  treatment = extract_treatment(response)
50
-
51
  return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"
52
 
53
  def extract_diagnosis(response):
54
- # Извлечение диагноза (упрощенно)
55
  return response.split(".")[0].strip() if "." in response else response.strip()
56
 
57
  def extract_operation(response):
58
- # Упрощенная логика для операции
59
  return "Не требуется"
60
 
61
  def extract_treatment(response):
62
- # Извлечение лечения
63
  return response.split(".")[-1].strip() if "." in response else "Не указано"
64
 
65
- # Интерфейс Gradio
66
- def create_interface():
67
- with gr.Blocks() as demo:
68
- gr.Markdown("## Медицинский чат-бот на базе GPT-2")
69
- chatbot = gr.Chatbot(label="Чат")
70
- msg = gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы...")
71
- max_tokens = gr.Slider(minimum=50, maximum=1024, value=512, step=1, label="Максимальная длина ответа")
72
- temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Температура")
73
- top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p")
74
- clear = gr.Button("Очистить чат")
75
-
76
- # Состояние для истории
77
- state = gr.State(value=[])
78
-
79
- def submit_message(message, history, max_tokens, temperature, top_p):
80
- response, updated_history = respond(message, history, max_tokens, temperature, top_p)
81
- return response, updated_history, gr.update(value="")
82
 
83
- def clear_chat():
84
- return [], [], ""
 
85
 
86
- # Обработка отправки сообщения
87
- msg.submit(
88
- submit_message,
89
- inputs=[msg, state, max_tokens, temperature, top_p],
90
- outputs=[chatbot, state, msg]
91
- )
92
-
93
- # Очистка чата
94
- clear.click(
95
- clear_chat,
96
- outputs=[chatbot, state, msg]
97
- )
98
 
99
- return demo
 
100
 
101
  if __name__ == "__main__":
102
- demo = create_interface()
103
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
+ # Загружаем модель GPT-2 локально
6
  model_name = "gpt2"
7
  try:
 
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ # Устанавливаем pad_token, если он не задан
11
+ if tokenizer.pad_token is None:
12
+ tokenizer.pad_token = tokenizer.eos_token
13
  except Exception as e:
14
  print(f"Ошибка загрузки модели: {e}")
15
  exit(1)
16
 
17
  def respond(message, history, max_tokens=512, temperature=0.7, top_p=0.95):
 
18
  history = history or []
19
+ # Формируем историю чата
20
  input_text = "\n".join([f"User: {msg[0]}\nAssistant: {msg[1]}" for msg in history] + [f"User: {message}"])
21
 
22
  # Токенизация
23
+ try:
24
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
25
+ except Exception as e:
26
+ return f"Ошибка токенизации: {e}", history
27
 
28
  # Генерация ответа
29
  try:
 
38
  )
39
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
40
  except Exception as e:
41
+ return f"Ошибка генерации: {e}", history
42
 
43
  # Форматируем ответ
44
  formatted_response = format_response(response)
 
 
45
  history.append((message, formatted_response))
46
 
47
  return formatted_response, history
48
 
49
  def format_response(response):
 
50
  diagnosis = extract_diagnosis(response)
51
  operation = extract_operation(response)
52
  treatment = extract_treatment(response)
 
53
  return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"
54
 
55
  def extract_diagnosis(response):
 
56
  return response.split(".")[0].strip() if "." in response else response.strip()
57
 
58
  def extract_operation(response):
 
59
  return "Не требуется"
60
 
61
  def extract_treatment(response):
 
62
  return response.split(".")[-1].strip() if "." in response else "Не указано"
63
 
64
+ # Создаем Gradio интерфейс
65
+ with gr.Blocks() as demo:
66
+ gr.Markdown("## Медицинский чат-бот на базе GPT-2")
67
+ chatbot = gr.Chatbot(label="Чат")
68
+ msg = gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы...")
69
+ max_tokens = gr.Slider(minimum=50, maximum=1024, value=512, step=1, label="Максимальная длина ответа")
70
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Температура")
71
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p")
72
+ clear = gr.Button("Очистить чат")
73
+ state = gr.State(value=[])
 
 
 
 
 
 
 
74
 
75
+ def submit_message(message, history, max_tokens, temperature, top_p):
76
+ response, updated_history = respond(message, history, max_tokens, temperature, top_p)
77
+ return [(message, response)], updated_history, ""
78
 
79
+ def clear_chat():
80
+ return [], [], ""
 
 
 
 
 
 
 
 
 
 
81
 
82
+ msg.submit(submit_message, [msg, state, max_tokens, temperature, top_p], [chatbot, state, msg])
83
+ clear.click(clear_chat, outputs=[chatbot, state, msg])
84
 
85
  if __name__ == "__main__":
 
86
  demo.launch(server_name="0.0.0.0", server_port=7860)