Xolkin commited on
Commit
2edf1f0
·
verified ·
1 Parent(s): 27623dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -58
app.py CHANGED
@@ -1,81 +1,103 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
3
 
4
- # Используем модель GPT-2 для генерации текста
5
- model_name = "gpt2" # Модель GPT-2
6
- model = AutoModelForCausalLM.from_pretrained(model_name)
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
8
 
9
- def respond(message, history=None, system_message=None, max_tokens=512, temperature=0.7, top_p=0.95):
10
- if history is None:
11
- history = [] # Инициализируем пустой список, если history не передан
 
12
 
13
- # Объединяем сообщения в историю
14
- input_text = "\n".join([msg[1] for msg in history] + [message])
15
-
16
- # Токенизация текста
17
- inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
18
 
19
  # Генерация ответа
20
- outputs = model.generate(
21
- inputs["input_ids"],
22
- max_length=max_tokens,
23
- temperature=temperature,
24
- top_p=top_p,
25
- do_sample=True,
26
- )
27
-
28
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
29
-
30
- # Формируем ответ согласно шаблону
31
- response = format_response(response)
32
-
33
- # Добавляем текущий запрос и ответ в историю
34
- history.append((message, response))
35
-
36
- return response, history # Возвращаем ответ и обновленную историю
 
 
 
 
37
 
38
  def format_response(response):
39
- # Форматируем ответ в соответствии с шаблоном
40
  diagnosis = extract_diagnosis(response)
41
  operation = extract_operation(response)
42
  treatment = extract_treatment(response)
43
 
44
- formatted_response = f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"
45
-
46
- return formatted_response
47
 
48
  def extract_diagnosis(response):
49
- # Простой способ извлечь диагноз (можно улучшить с помощью NLP)
50
- diagnosis = response.split(".")[0] # Пример: диагноз - первая часть ответа
51
- return diagnosis.strip()
52
 
53
  def extract_operation(response):
54
- # Извлекаем название операции из ответа
55
- operation = "Не требуется" # Пример, что операция не требуется
56
- return operation.strip()
57
 
58
  def extract_treatment(response):
59
- # Извлекаем лечение (например, лечение как последняя часть ответа)
60
- treatment = response.split(".")[-1] # Пример: лечение - последняя часть
61
- return treatment.strip()
62
 
63
  # Интерфейс Gradio
64
- demo = gr.Interface(
65
- fn=respond,
66
- inputs=[
67
- gr.Textbox(value="Здравствуйте. Отвечай кратко...", label="System message"),
68
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max Tokens"),
69
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, label="Temperature"),
70
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p"),
71
- gr.State() # Вход для состояния (history)
72
- ],
73
- outputs=[
74
- gr.Textbox(label="Response"), # Ответ от модели
75
- gr.State() # Выход для состояния (обновленная история)
76
- ],
77
- live=True # Обновляем интерфейс в реальном времени
78
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  if __name__ == "__main__":
81
- demo.launch()
 
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
+ # Загружаем модель GPT-2
6
+ model_name = "gpt2"
7
+ try:
8
+ model = AutoModelForCausalLM.from_pretrained(model_name)
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ except Exception as e:
11
+ print(f"Ошибка загрузки модели: {e}")
12
+ exit(1)
13
 
14
+ def respond(message, history, max_tokens=512, temperature=0.7, top_p=0.95):
15
+ # Формируем историю чата
16
+ history = history or []
17
+ input_text = "\n".join([f"User: {msg[0]}\nAssistant: {msg[1]}" for msg in history] + [f"User: {message}"])
18
 
19
+ # Токенизация
20
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
 
 
 
21
 
22
  # Генерация ответа
23
+ try:
24
+ outputs = model.generate(
25
+ inputs["input_ids"],
26
+ max_length=max_tokens,
27
+ temperature=temperature,
28
+ top_p=top_p,
29
+ do_sample=True,
30
+ pad_token_id=tokenizer.eos_token_id,
31
+ no_repeat_ngram_size=2
32
+ )
33
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
34
+ except Exception as e:
35
+ return f"Ошибка генерации ответа: {e}", history
36
+
37
+ # Форматируем ответ
38
+ formatted_response = format_response(response)
39
+
40
+ # Обновляем историю
41
+ history.append((message, formatted_response))
42
+
43
+ return formatted_response, history
44
 
45
  def format_response(response):
46
+ # Упрощенное форматирование ответа
47
  diagnosis = extract_diagnosis(response)
48
  operation = extract_operation(response)
49
  treatment = extract_treatment(response)
50
 
51
+ return f"Предварительный диагноз: {diagnosis}\nОперация: {operation}\nЛечение: {treatment}"
 
 
52
 
53
  def extract_diagnosis(response):
54
+ # Извлечение диагноза (упрощенно)
55
+ return response.split(".")[0].strip() if "." in response else response.strip()
 
56
 
57
  def extract_operation(response):
58
+ # Упрощенная логика для операции
59
+ return "Не требуется"
 
60
 
61
  def extract_treatment(response):
62
+ # Извлечение лечения
63
+ return response.split(".")[-1].strip() if "." in response else "Не указано"
 
64
 
65
  # Интерфейс Gradio
66
+ def create_interface():
67
+ with gr.Blocks() as demo:
68
+ gr.Markdown("## Медицинский чат-бот на базе GPT-2")
69
+ chatbot = gr.Chatbot(label="Чат")
70
+ msg = gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы...")
71
+ max_tokens = gr.Slider(minimum=50, maximum=1024, value=512, step=1, label="Максимальная длина ответа")
72
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Температура")
73
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p")
74
+ clear = gr.Button("Очистить чат")
75
+
76
+ # Состояние для истории
77
+ state = gr.State(value=[])
78
+
79
+ def submit_message(message, history, max_tokens, temperature, top_p):
80
+ response, updated_history = respond(message, history, max_tokens, temperature, top_p)
81
+ return response, updated_history, gr.update(value="")
82
+
83
+ def clear_chat():
84
+ return [], [], ""
85
+
86
+ # Обработка отправки сообщения
87
+ msg.submit(
88
+ submit_message,
89
+ inputs=[msg, state, max_tokens, temperature, top_p],
90
+ outputs=[chatbot, state, msg]
91
+ )
92
+
93
+ # Очистка чата
94
+ clear.click(
95
+ clear_chat,
96
+ outputs=[chatbot, state, msg]
97
+ )
98
+
99
+ return demo
100
 
101
  if __name__ == "__main__":
102
+ demo = create_interface()
103
+ demo.launch(server_name="0.0.0.0", server_port=7860)