Xolkin commited on
Commit
ca1d8ee
·
verified ·
1 Parent(s): 2192ebd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -33
app.py CHANGED
@@ -2,19 +2,11 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
  # Подключаем модель и токенизатор
5
- model_name = "distilgpt2" # Используем distilgpt2 как более легкую модель
6
  model = AutoModelForCausalLM.from_pretrained(model_name)
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
 
9
- def respond(
10
- message,
11
- history: list[tuple[str, str]],
12
- system_message,
13
- max_tokens,
14
- temperature,
15
- top_p,
16
- ):
17
- # Создаем входные данные
18
  messages = [{"role": "system", "content": system_message}]
19
 
20
  for val in history:
@@ -22,17 +14,13 @@ def respond(
22
  messages.append({"role": "user", "content": val[0]})
23
  if val[1]:
24
  messages.append({"role": "assistant", "content": val[1]})
25
-
26
- # Добавляем последнее сообщение пользователя
27
  messages.append({"role": "user", "content": message})
28
 
29
- # Объединяем все сообщения в один текст
30
  input_text = "\n".join([msg["content"] for msg in messages])
31
 
32
- # Токенизация текста
33
  inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
34
 
35
- # Генерация ответа моделью
36
  outputs = model.generate(
37
  inputs["input_ids"],
38
  max_length=max_tokens,
@@ -41,31 +29,20 @@ def respond(
41
  do_sample=True,
42
  )
43
 
44
- # Декодируем результат в строку
45
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
46
-
47
- # Добавляем подпись
48
  response += "\nСделано больницей EMS штата Alta!"
49
-
50
  return response
51
 
52
- # Запускаем Gradio интерфейс
53
  demo = gr.Interface(
54
  fn=respond,
55
  inputs=[
56
- gr.Textbox(value="Здравствуйте. Отвечай кратко(не пиши вступление, умозаключения итп) и сразу пиши начинай с этого ответ: Предварительный диагноз:(диагноз), Операция: (Если требуется, только название, не надо писать хирургическое вмешательство или подобное, а четкое медицинское название операции), Лечение: (Кратко, очень). Не пиши воду. Только по факту на 3 пункта отвечай. Не отходи от этого шаблона", label="System message"),
57
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max Tokens"),
58
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
59
- gr.Slider(
60
- minimum=0.1,
61
- maximum=1.0,
62
- value=0.95,
63
- step=0.05,
64
- label="Top-p (nucleus sampling)",
65
- ),
66
  ],
67
- css="styles.css", # Ссылка на внешний CSS файл
68
  )
69
 
70
- if __name__ == "__main__":
71
- demo.launch()
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
  # Подключаем модель и токенизатор
5
+ model_name = "distilgpt2"
6
  model = AutoModelForCausalLM.from_pretrained(model_name)
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
 
9
+ def respond(message, history, system_message, max_tokens, temperature, top_p):
 
 
 
 
 
 
 
 
10
  messages = [{"role": "system", "content": system_message}]
11
 
12
  for val in history:
 
14
  messages.append({"role": "user", "content": val[0]})
15
  if val[1]:
16
  messages.append({"role": "assistant", "content": val[1]})
17
+
 
18
  messages.append({"role": "user", "content": message})
19
 
 
20
  input_text = "\n".join([msg["content"] for msg in messages])
21
 
 
22
  inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
23
 
 
24
  outputs = model.generate(
25
  inputs["input_ids"],
26
  max_length=max_tokens,
 
29
  do_sample=True,
30
  )
31
 
 
32
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
33
  response += "\nСделано больницей EMS штата Alta!"
 
34
  return response
35
 
36
+ # Интерфейс Gradio
37
  demo = gr.Interface(
38
  fn=respond,
39
  inputs=[
40
+ gr.Textbox(value="Здравствуйте. Отвечай кратко...", label="System message"),
41
+ gr.Slider(minimum=1, maximum=2048, value=512, label="Max Tokens"),
42
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, label="Temperature"),
43
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top-p"),
 
 
 
 
 
 
44
  ],
45
+ outputs="text",
46
  )
47
 
48
+ demo.launch()