Alignment-Lab-AI committed on
Commit 422ac12 · verified · 1 Parent(s): 052bd8b

Update app.py

Files changed (1)
  1. app.py +31 -157
app.py CHANGED
@@ -1,42 +1,5 @@
- import os
- from threading import Thread
- from typing import Iterator
-
  import gradio as gr
- import spaces
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
- MAX_MAX_NEW_TOKENS = 1024
- DEFAULT_MAX_NEW_TOKENS = 256
- MAX_INPUT_TOKEN_LENGTH = 512
-
- DESCRIPTION = """\
- # Buzz-3B-Small
- This Space demonstrates Buzz-3b-small-v0.6.3.
- """
-
- LICENSE = """
- <p/>
- ---
- This demo uses Buzz-3b-small-v0.6.3. Please check the model card for details.
- """
-
- if not torch.cuda.is_available():
-     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo works better on GPU.</p>"
-
- model_id = "H-D-T/Buzz-3b-small-v0.6.3"
-
- if torch.cuda.is_available():
-     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True)
- else:
-     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", trust_remote_code=True, low_cpu_mem_usage=True)
-
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- if tokenizer.pad_token == None:
-     tokenizer.pad_token = tokenizer.eos_token
-     tokenizer.pad_token_id = tokenizer.eos_token_id
-     model.config.pad_token_id = tokenizer.eos_token_id

  # Define the special tokens
  bos_token = "<|begin_of_text|>"
@@ -44,132 +7,43 @@ eos_token = "<|eot_id|>"
  start_header_id = "<|start_header_id|>"
  end_header_id = "<|end_header_id|>"

- def format_chat_history(chat_history: list[tuple[str, str]], add_generation_prompt=False) -> str:
-     """
-     Formats the chat history according to the model's chat template.
-     """
-     chat_template = f"""
-     {{% if not add_generation_prompt is defined %}}{{% set add_generation_prompt = false %}}{{% endif %}}
-     {{% set loop_messages = messages %}}
-     {{% for message in loop_messages %}}
-     {{% set content = '{start_header_id}' + message['role'] + '{end_header_id}\\n\\n' + message['content'].strip() + '{eos_token}' %}}
-     {{% if loop.index0 == 0 %}}{{% set content = bos_token + content %}}{{% endif %}}
-     {{ content }}
-     {{% endfor %}}
-     {{% if add_generation_prompt %}}{{ '{start_header_id}assistant{end_header_id}\\n\\n' }}{{% else %}}{{ eos_token }}{{% endif %}}
-     """
-     chat_context = ""
      for i, (user, assistant) in enumerate(chat_history):
-         user_msg = start_header_id + "user" + end_header_id + "\n\n" + user.strip() + eos_token
-         assistant_msg = start_header_id + "assistant" + end_header_id + "\n\n" + assistant.strip() + eos_token
          if i == 0:
              user_msg = bos_token + user_msg
-         chat_context += user_msg + assistant_msg
-
-     if add_generation_prompt:
-         chat_context += start_header_id + "assistant" + end_header_id + "\n\n"
-     else:
-         chat_context += eos_token
-
-     return chat_context

- @spaces.GPU
- def generate(
-     message: str,
-     chat_history: list[tuple[str, str]],
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.4,
- ) -> Iterator[str]:
-
      chat_history.append(("user", message))
-     chat_context = format_chat_history(chat_history, add_generation_prompt=True)
-     input_ids = tokenizer([chat_context], return_tensors="pt").input_ids
-     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-     input_ids = input_ids.to(model.device)
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-     generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
-         pad_token_id = tokenizer.eos_token_id,
-         repetition_penalty=repetition_penalty,
-         no_repeat_ngram_size=5,
-         early_stopping=False,
-     )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         yield "".join(outputs)
-
-
- chat_interface = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         ),
-         gr.Slider(
-             label="Temperature",
-             minimum=0.1,
-             maximum=4.0,
-             step=0.1,
-             value=0.6,
-         ),
-         gr.Slider(
-             label="Top-p (nucleus sampling)",
-             minimum=0.05,
-             maximum=1.0,
-             step=0.05,
-             value=0.9,
-         ),
-         gr.Slider(
-             label="Top-k",
-             minimum=1,
-             maximum=1000,
-             step=1,
-             value=50,
-         ),
-         gr.Slider(
-             label="Repetition penalty",
-             minimum=1.0,
-             maximum=2.0,
-             step=0.05,
-             value=1.4,
-         ),
-     ],
-     stop_btn=None,
-     examples=[
-         ["A recipe for a chocolate cake:"],
-         ["Can you explain briefly to me what is the Python programming language?"],
-         ["Explain the plot of Cinderella in a sentence."],
-         ["Question: What is the capital of France?\nAnswer:"],
-         ["Question: I am very tired, what should I do?\nAnswer:"],
-     ],
- )

  with gr.Blocks(css="style.css") as demo:
-     gr.Markdown(DESCRIPTION)
-     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
-     chat_interface.render()
-     gr.Markdown(LICENSE)

  if __name__ == "__main__":
      demo.queue(max_size=20).launch()
app.py after the change:

  import gradio as gr
+ from transformers import pipeline, Conversation, AutoTokenizer

  # Define the special tokens
  bos_token = "<|begin_of_text|>"
  eos_token = "<|eot_id|>"
  start_header_id = "<|start_header_id|>"
  end_header_id = "<|end_header_id|>"

+ # Load the conversational pipeline and tokenizer
+ model_id = "H-D-T/Buzz-3b-small-v0.6.3"
+ chatbot = pipeline("conversational", model=model_id)
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+ def format_conversation(chat_history):
+     formatted_history = ""
      for i, (user, assistant) in enumerate(chat_history):
+         user_msg = f"{start_header_id}user{end_header_id}\n\n{user.strip()}{eos_token}"
+         assistant_msg = f"{start_header_id}assistant{end_header_id}\n\n{assistant.strip()}{eos_token}"
          if i == 0:
              user_msg = bos_token + user_msg
+         formatted_history += user_msg + assistant_msg
+     return formatted_history

+ def predict(message, chat_history):
      chat_history.append(("user", message))
+     formatted_history = format_conversation(chat_history)
+     conversation = Conversation(formatted_history)
+     conversation = chatbot(conversation)
+     response = conversation.generated_responses[-1]
+     chat_history.append(("assistant", response))
+     return "", chat_history

  with gr.Blocks(css="style.css") as demo:
+     gr.Markdown("# Buzz-3B-Small Conversational Demo")
+     with gr.Chatbot() as chatbot_ui:
+         chatbot_ui.append({"role": "assistant", "content": "Hi, how can I help you today?"})
+     with gr.Row():
+         with gr.Column():
+             textbox = gr.Textbox(label="Your message:")
+         with gr.Column():
+             submit_btn = gr.Button("Send")
+
+     chat_history = gr.State([])
+
+     submit_btn.click(predict, inputs=[textbox, chat_history], outputs=[textbox, chat_history])

  if __name__ == "__main__":
      demo.queue(max_size=20).launch()
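
For reference, the hand-rolled formatting in format_conversation mirrors the Llama-3-style header layout implied by the special tokens above. If the checkpoint ships a chat template, roughly the same prompt string can be produced with the tokenizer alone; a minimal sketch, assuming H-D-T/Buzz-3b-small-v0.6.3 defines such a chat_template (the example message is taken from the removed examples list):

    # Sketch only: assumes the tokenizer defines a Llama-3-style chat_template.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("H-D-T/Buzz-3b-small-v0.6.3")
    messages = [{"role": "user", "content": "What is the capital of France?"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Expected shape of the string (roughly):
    # <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n...<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n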
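
The new code runs inference through pipeline("conversational") and Conversation, which work on older transformers releases but have since been deprecated in favour of passing chat-formatted messages to the text-generation pipeline. A hedged sketch of that alternative, not the committed code (the model id comes from the diff; max_new_tokens=256 and the output indexing are illustrative and assume a recent transformers version):

    # Sketch only: chat messages fed to the text-generation pipeline.
    from transformers import pipeline

    generator = pipeline("text-generation", model="H-D-T/Buzz-3b-small-v0.6.3")
    messages = [{"role": "user", "content": "Explain the plot of Cinderella in a sentence."}]
    out = generator(messages, max_new_tokens=256)
    # With chat-style input, generated_text holds the full message list;
    # the last entry is the assistant reply.
    response = out[0]["generated_text"][-1]["content"]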
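
On the Gradio side, the committed Blocks layout keeps the history in a gr.State and routes predict's output back to the textbox and that state. For comparison, a minimal sketch of a wiring in which the gr.Chatbot component itself is the event output, so replies are rendered in the chat window (component names are illustrative and the model call is left as a placeholder):

    # Hedged sketch of a minimal Blocks chat wiring; not the committed code.
    import gradio as gr

    def respond(message, history):
        reply = "..."  # placeholder: call the model here
        history = history + [(message, reply)]  # tuple-style (user, assistant) history
        return "", history

    with gr.Blocks() as demo:
        chatbot_ui = gr.Chatbot(value=[])
        textbox = gr.Textbox(label="Your message:")
        submit_btn = gr.Button("Send")
        submit_btn.click(respond, inputs=[textbox, chatbot_ui], outputs=[textbox, chatbot_ui])

    if __name__ == "__main__":
        demo.launch()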