Alignment-Lab-AI committed on
Commit 1c24143 · verified · 1 Parent(s): 0813196

Update app.py

Files changed (1)
app.py +150 -35
app.py CHANGED
@@ -1,49 +1,164 @@
 
 
 
 
 import gradio as gr
-from transformers import pipeline, Conversation, AutoTokenizer
 
-# Define the special tokens
 bos_token = "<|begin_of_text|>"
 eos_token = "<|eot_id|>"
 start_header_id = "<|start_header_id|>"
 end_header_id = "<|end_header_id|>"
 
-# Load the conversational pipeline and tokenizer
-model_id = "H-D-T/Buzz-3b-small-v0.6.3"
-chatbot = pipeline("conversational", model=model_id)
-tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-def format_conversation(chat_history):
-    formatted_history = ""
-    for i, (user, assistant) in enumerate(chat_history):
-        user_msg = f"{start_header_id}user{end_header_id}\n\n{user.strip()}{eos_token}"
-        assistant_msg = f"{start_header_id}assistant{end_header_id}\n\n{assistant.strip()}{eos_token}"
         if i == 0:
-            user_msg = bos_token + user_msg
-        formatted_history += user_msg + assistant_msg
-    return formatted_history
-
-def predict(message, chat_history):
-    chat_history.append(("user", message))
-    formatted_history = format_conversation(chat_history)
-    conversation = Conversation(formatted_history)
-    conversation = chatbot(conversation)
-    response = conversation.generated_responses[-1]
-    chat_history.append(("assistant", response))
-    return "", chat_history
 
 with gr.Blocks(css="style.css") as demo:
-    gr.Markdown("# Buzz-3B-Small Conversational Demo")
-    with gr.Chatbot() as chatbot_ui:
-        chatbot_ui.append({"role": "assistant", "content": "Hi, how can I help you today?"})
-    with gr.Row():
-        with gr.Column():
-            textbox = gr.Textbox(label="Your message:")
-        with gr.Column():
-            submit_btn = gr.Button("Send")
-
-    chat_history = gr.State([])
-
-    submit_btn.click(predict, inputs=[textbox, chat_history], outputs=[textbox, chat_history])
 
 if __name__ == "__main__":
     demo.queue(max_size=20).launch()
 
+import os
+from threading import Thread
+from typing import Iterator, List, Dict, Any
+
 import gradio as gr
+import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, Conversation, pipeline
+
+MAX_MAX_NEW_TOKENS = 1024
+DEFAULT_MAX_NEW_TOKENS = 256
+MAX_INPUT_TOKEN_LENGTH = 512
+
+DESCRIPTION = """\
+# Buzz-3B-Small
+This Space demonstrates Buzz-3b-small-v0.6.3.
+"""
+
+LICENSE = """
+<p/>
+---
+Chat with Buzz-small!
+At only 3B parameters, this demo runs on the model's fp16 weights in safetensors format (cpp conversion coming soon).
+"""
+
+# Pick the device once and reuse it for the model and the input tensors.
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+else:
+    device = torch.device("cpu")
+
+model_id = "H-D-T/Buzz-3b-small-v0.6.3"
+chatbot = pipeline(model=model_id, device=device, task="conversational")
+model = chatbot.model  # underlying causal LM, used directly for streaming generation below
 
+tokenizer = AutoTokenizer.from_pretrained(model_id)
 bos_token = "<|begin_of_text|>"
 eos_token = "<|eot_id|>"
 start_header_id = "<|start_header_id|>"
 end_header_id = "<|end_header_id|>"
 
+# Fall back to the EOS token for padding if the tokenizer defines no pad token.
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+    model.config.pad_token_id = tokenizer.eos_token_id
 
+def format_conversation(chat_history: List[Dict[str, str]], add_generation_prompt=False) -> str:
+    """
+    Formats the chat history according to the model's chat template.
+    """
+    formatted_history = []
+    for i, message in enumerate(chat_history):
+        role, content = message["role"], message["content"]
+        formatted_message = f"{start_header_id}{role}{end_header_id}\n\n{content.strip()}{eos_token}"
         if i == 0:
+            formatted_message = bos_token + formatted_message
+        formatted_history.append(formatted_message)
+
+    if add_generation_prompt:
+        formatted_history.append(f"{start_header_id}assistant{end_header_id}\n\n")
+    else:
+        formatted_history.append(eos_token)
+
+    return "".join(formatted_history)
+
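For reference, a minimal sketch of the prompt string format_conversation() builds; the sample message below is made up, and the token layout simply follows the template defined above:

# Hypothetical one-turn history, shown only to illustrate the prompt layout.
history = [{"role": "user", "content": "Hi there"}]
print(format_conversation(history, add_generation_prompt=True))
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>
#
# Hi there<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#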
+@spaces.GPU
+def generate(
+    message: str,
+    chat_history: List[Dict[str, str]],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.4,
+) -> Iterator[str]:
+    # Append the new user turn, rebuild the prompt, and tokenize it.
+    chat_history.append({"role": "user", "content": message})
+    chat_context = format_conversation(chat_history, add_generation_prompt=True)
+    input_ids = tokenizer([chat_context], return_tensors="pt").input_ids
+
+    # Keep only the most recent tokens if the conversation grows too long.
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(device)
+
+    # Run generation in a background thread and stream tokens back as they arrive.
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        pad_token_id=tokenizer.eos_token_id,
+        repetition_penalty=repetition_penalty,
+        no_repeat_ngram_size=5,
+        early_stopping=False,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
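Outside the Gradio UI, the same generator can be driven directly; a rough usage sketch (the prompt text here is made up) showing that each yield is the cumulative reply streamed so far:

history = []  # generate() appends the new user turn to this list itself
for partial_reply in generate("Write a haiku about bees.", history, max_new_tokens=64):
    print(partial_reply)  # full assistant text generated so far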
+chat_interface = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.4,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["A recipe for a chocolate cake:"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["Question: What is the capital of France?\nAnswer:"],
+        ["Question: I am very tired, what should I do?\nAnswer:"],
+    ],
+)
 with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+    chat_interface.render()
+    gr.Markdown(LICENSE)
 
 if __name__ == "__main__":
     demo.queue(max_size=20).launch()
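To try this revision of app.py outside the Space, a rough local setup would be the following (package names are taken from the imports above; versions are not pinned here):

pip install gradio spaces torch transformers
python app.py  # launches the demo with a request queue of up to 20 jobs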