Alignment-Lab-AI committed
Commit 052bd8b · verified · 1 Parent(s): 66cfc04

Update app.py

Files changed (1)
  1. app.py +173 -1
app.py CHANGED
@@ -1,3 +1,175 @@
- gr.load("models/H-D-T/Buzz-3b-small-v0.6.3").launch()
+ import os
+ from threading import Thread
+ from typing import Iterator
+
  import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
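+ # `spaces` supplies the @spaces.GPU decorator used below, which requests a
+ # GPU slot when the demo runs on a ZeroGPU Space (and is a no-op elsewhere).
+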
+ MAX_MAX_NEW_TOKENS = 1024
+ DEFAULT_MAX_NEW_TOKENS = 256
+ MAX_INPUT_TOKEN_LENGTH = 512
+
+ DESCRIPTION = """\
+ # Buzz-3B-Small
+ This Space demonstrates Buzz-3b-small-v0.6.3.
+ """
+
+ LICENSE = """
+ <p/>
+ ---
+ This demo uses Buzz-3b-small-v0.6.3. Please check the model card for details.
+ """
+
+ if not torch.cuda.is_available():
+     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo works better on GPU.</p>"
+
+ model_id = "H-D-T/Buzz-3b-small-v0.6.3"
+
+ if torch.cuda.is_available():
+     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True)
+ else:
+     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu", trust_remote_code=True, low_cpu_mem_usage=True)
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ if tokenizer.pad_token is None:
+     # Llama-style checkpoints ship without a pad token; reuse EOS so
+     # generation does not fail when padding is requested.
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.pad_token_id = tokenizer.eos_token_id
+     model.config.pad_token_id = tokenizer.eos_token_id
+
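+ # NB: no torch_dtype is specified above, so weights load in float32;
+ # passing torch_dtype=torch.float16 would roughly halve GPU memory use
+ # if the Space runs short of memory.
+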
+ # Special tokens for the model's Llama-3-style chat template.
+ bos_token = "<|begin_of_text|>"
+ eos_token = "<|eot_id|>"
+ start_header_id = "<|start_header_id|>"
+ end_header_id = "<|end_header_id|>"
+
+ def format_chat_history(chat_history: list[tuple[str, str]], add_generation_prompt: bool = False) -> str:
+     """
+     Formats the chat history according to the model's chat template.
+
+     `chat_history` is a list of (user, assistant) pairs; an empty assistant
+     string marks the turn that is about to be generated.
+     """
+     chat_context = bos_token
+     for user, assistant in chat_history:
+         chat_context += start_header_id + "user" + end_header_id + "\n\n" + user.strip() + eos_token
+         if assistant:
+             chat_context += start_header_id + "assistant" + end_header_id + "\n\n" + assistant.strip() + eos_token
+
+     if add_generation_prompt:
+         # Leave an open assistant header so the model completes that turn.
+         chat_context += start_header_id + "assistant" + end_header_id + "\n\n"
+
+     return chat_context
+
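+ # For example, with one completed exchange and a new prompt:
+ #   format_chat_history([("Hi", "Hello!"), ("Tell me a joke", "")], add_generation_prompt=True)
+ # returns a single string (wrapped here for readability):
+ #   <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>
+ #   <|start_header_id|>assistant<|end_header_id|>\n\nHello!<|eot_id|>
+ #   <|start_header_id|>user<|end_header_id|>\n\nTell me a joke<|eot_id|>
+ #   <|start_header_id|>assistant<|end_header_id|>\n\n
+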
+ @spaces.GPU
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.4,
+ ) -> Iterator[str]:
+     # Build the prompt from past turns plus the new user message; the empty
+     # assistant slot marks the turn to be generated. Copy rather than mutate
+     # Gradio's history in place.
+     conversation = chat_history + [(message, "")]
+     chat_context = format_chat_history(conversation, add_generation_prompt=True)
+     input_ids = tokenizer([chat_context], return_tensors="pt").input_ids
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     # timeout=10.0 makes the streamer raise if no new token arrives within
+     # 10 s, so the worker cannot hang if generation stalls.
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         pad_token_id=tokenizer.eos_token_id,
+         repetition_penalty=repetition_penalty,
+         no_repeat_ngram_size=5,
+     )
+     # Run generation on a background thread and stream partial text to the UI.
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.4,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["A recipe for a chocolate cake:"],
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["Explain the plot of Cinderella in a sentence."],
+         ["Question: What is the capital of France?\nAnswer:"],
+         ["Question: I am very tired, what should I do?\nAnswer:"],
+     ],
+ )
+
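+ # ChatInterface passes additional_inputs to generate() positionally after
+ # (message, chat_history), so the slider order above must match the
+ # signature: max_new_tokens, temperature, top_p, top_k, repetition_penalty.
+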
+ with gr.Blocks(css="style.css") as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+     chat_interface.render()
+     gr.Markdown(LICENSE)
 
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()