MadsGalsgaard committed on
Commit 786bf8b · verified · 1 Parent(s): b82d162
Files changed (1)
  1. app.py +191 -161
app.py CHANGED
@@ -115,171 +115,201 @@
 
 ### 20aug
 
-import os
-import time
-import spaces
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
-import gradio as gr
-from threading import Thread
-
-MODEL_LIST = ["meta-llama/Meta-Llama-3.1-8B-Instruct"]
-HF_TOKEN = os.environ.get("HF_API_TOKEN", None)
-MODEL = os.environ.get("MODEL_ID")
-
-TITLE = "<h1><center>Meta-Llama3.1-8B</center></h1>"
-
-PLACEHOLDER = """
-<center>
-<p>Hi! How can I help you today?</p>
-</center>
-"""
-
-
-CSS = """
-.duplicate-button {
-    margin: auto !important;
-    color: white !important;
-    background: black !important;
-    border-radius: 100vh !important;
-}
-h3 {
-    text-align: center;
-}
-"""
-
-device = "cuda" # for GPU usage or "cpu" for CPU usage
-
-quantization_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_compute_dtype=torch.bfloat16,
-    bnb_4bit_use_double_quant=True,
-    bnb_4bit_quant_type="nf4")
-
-tokenizer = AutoTokenizer.from_pretrained(MODEL)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL,
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-    quantization_config=quantization_config)
-
-@spaces.GPU()
-def stream_chat(
-    message: str,
-    history: list,
-    system_prompt: str,
-    temperature: float = 0.8,
-    max_new_tokens: int = 1024,
-    top_p: float = 1.0,
-    top_k: int = 20,
-    penalty: float = 1.2,
-):
-    print(f'message: {message}')
-    print(f'history: {history}')
-
-    conversation = [
-        {"role": "system", "content": system_prompt}
-    ]
-    for prompt, answer in history:
-        conversation.extend([
-            {"role": "user", "content": prompt},
-            {"role": "assistant", "content": answer},
-        ])
-
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens=max_new_tokens,
-        do_sample=False if temperature == 0 else True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        repetition_penalty=penalty,
-        eos_token_id=[128001, 128008, 128009],
-        streamer=streamer,
-    )
-
-    with torch.no_grad():
-        thread = Thread(target=model.generate, kwargs=generate_kwargs)
-        thread.start()
-
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        yield buffer
-
-
-chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
-
-with gr.Blocks(css=CSS, theme="soft") as demo:
-    gr.HTML(TITLE)
-    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
-    gr.ChatInterface(
-        fn=stream_chat,
-        chatbot=chatbot,
-        fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
-        additional_inputs=[
-            gr.Textbox(
-                value="You are a helpful assistant",
-                label="System Prompt",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0,
-                maximum=1,
-                step=0.1,
-                value=0.8,
-                label="Temperature",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=128,
-                maximum=8192,
-                step=1,
-                value=1024,
-                label="Max new tokens",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=1.0,
-                step=0.1,
-                value=1.0,
-                label="top_p",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=1,
-                maximum=20,
-                step=1,
-                value=20,
-                label="top_k",
-                render=False,
-            ),
-            gr.Slider(
-                minimum=0.0,
-                maximum=2.0,
-                step=0.1,
-                value=1.2,
-                label="Repetition penalty",
-                render=False,
-            ),
-        ],
-        examples=[
-            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
-            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
-            ["Tell me a random fun fact about the Roman Empire."],
-            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
-        ],
-        cache_examples=False,
-    )
-
-
-if __name__ == "__main__":
-    demo.launch()
+# import os
+# import time
+# import spaces
+# import torch
+# from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
+# import gradio as gr
+# from threading import Thread
+
+# MODEL_LIST = ["meta-llama/Meta-Llama-3.1-8B-Instruct"]
+# HF_TOKEN = os.environ.get("HF_API_TOKEN", None)
+# MODEL = os.environ.get("MODEL_ID")
+
+# TITLE = "<h1><center>Meta-Llama3.1-8B</center></h1>"
+
+# PLACEHOLDER = """
+# <center>
+# <p>Hi! How can I help you today?</p>
+# </center>
+# """
+
+
+# CSS = """
+# .duplicate-button {
+#     margin: auto !important;
+#     color: white !important;
+#     background: black !important;
+#     border-radius: 100vh !important;
+# }
+# h3 {
+#     text-align: center;
+# }
+# """
+
+# device = "cuda" # for GPU usage or "cpu" for CPU usage
+
+# quantization_config = BitsAndBytesConfig(
+#     load_in_4bit=True,
+#     bnb_4bit_compute_dtype=torch.bfloat16,
+#     bnb_4bit_use_double_quant=True,
+#     bnb_4bit_quant_type="nf4")
+
+# tokenizer = AutoTokenizer.from_pretrained(MODEL)
+# model = AutoModelForCausalLM.from_pretrained(
+#     MODEL,
+#     torch_dtype=torch.bfloat16,
+#     device_map="auto",
+#     quantization_config=quantization_config)
+
+# @spaces.GPU()
+# def stream_chat(
+#     message: str,
+#     history: list,
+#     system_prompt: str,
+#     temperature: float = 0.8,
+#     max_new_tokens: int = 1024,
+#     top_p: float = 1.0,
+#     top_k: int = 20,
+#     penalty: float = 1.2,
+# ):
+#     print(f'message: {message}')
+#     print(f'history: {history}')
+
+#     conversation = [
+#         {"role": "system", "content": system_prompt}
+#     ]
+#     for prompt, answer in history:
+#         conversation.extend([
+#             {"role": "user", "content": prompt},
+#             {"role": "assistant", "content": answer},
+#         ])
+
+#     conversation.append({"role": "user", "content": message})
+
+#     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
+
+#     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
+#     generate_kwargs = dict(
+#         input_ids=input_ids,
+#         max_new_tokens=max_new_tokens,
+#         do_sample=False if temperature == 0 else True,
+#         top_p=top_p,
+#         top_k=top_k,
+#         temperature=temperature,
+#         repetition_penalty=penalty,
+#         eos_token_id=[128001, 128008, 128009],
+#         streamer=streamer,
+#     )
+
+#     with torch.no_grad():
+#         thread = Thread(target=model.generate, kwargs=generate_kwargs)
+#         thread.start()
+
+#     buffer = ""
+#     for new_text in streamer:
+#         buffer += new_text
+#         yield buffer
+
+
+# chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
+
+# with gr.Blocks(css=CSS, theme="soft") as demo:
+#     gr.HTML(TITLE)
+#     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+#     gr.ChatInterface(
+#         fn=stream_chat,
+#         chatbot=chatbot,
+#         fill_height=True,
+#         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+#         additional_inputs=[
+#             gr.Textbox(
+#                 value="You are a helpful assistant",
+#                 label="System Prompt",
+#                 render=False,
+#             ),
+#             gr.Slider(
+#                 minimum=0,
+#                 maximum=1,
+#                 step=0.1,
+#                 value=0.8,
+#                 label="Temperature",
+#                 render=False,
+#             ),
+#             gr.Slider(
+#                 minimum=128,
+#                 maximum=8192,
+#                 step=1,
+#                 value=1024,
+#                 label="Max new tokens",
+#                 render=False,
+#             ),
+#             gr.Slider(
+#                 minimum=0.0,
+#                 maximum=1.0,
+#                 step=0.1,
+#                 value=1.0,
+#                 label="top_p",
+#                 render=False,
+#             ),
+#             gr.Slider(
+#                 minimum=1,
+#                 maximum=20,
+#                 step=1,
+#                 value=20,
+#                 label="top_k",
+#                 render=False,
+#             ),
+#             gr.Slider(
+#                 minimum=0.0,
+#                 maximum=2.0,
+#                 step=0.1,
+#                 value=1.2,
+#                 label="Repetition penalty",
+#                 render=False,
+#             ),
+#         ],
+#         examples=[
+#             ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
+#             ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
+#             ["Tell me a random fun fact about the Roman Empire."],
+#             ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
+#         ],
+#         cache_examples=False,
+#     )
+
+
+# if __name__ == "__main__":
+#     demo.launch()
+
+import os
+import gradio as gr
+from huggingface_hub import InferenceClient
+
+
+# Your Hugging Face configuration
+model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+# token = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
+
+# Initialize the Inference Client (the model is passed on each request below)
+inference_client = InferenceClient()
+
+def chat_completion(message, history):
+    # gr.ChatInterface calls fn with (message, history); history is unused here.
+    # Pass the user input through the hosted Llama 3.1 model.
+    response = inference_client.chat_completion(
+        model=model_name,
+        messages=[{"role": "user", "content": message}],
+        max_tokens=500,
+        stream=False,
+    )
+
+    # Extract the generated text from the non-streaming response
+    response_text = response.choices[0].message.content
+
+    return response_text
+
+# Create Gradio chat interface
+chatbot = gr.ChatInterface(fn=chat_completion)
+chatbot.launch()
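
Note: the previous transformers-based version streamed tokens into the chat window, while the InferenceClient call above returns the whole reply at once. Below is a minimal sketch (not part of this commit) of how the same call could stream instead, assuming huggingface_hub's chat_completion(stream=True) API; stream_chat is a hypothetical name, and the client still needs access to the gated Llama 3.1 repo (token read from HF_TOKEN or the local login).

import gradio as gr
from huggingface_hub import InferenceClient

client = InferenceClient()

def stream_chat(message, history):
    # gr.ChatInterface passes (message, history); yielding the growing text streams the reply.
    partial = ""
    for chunk in client.chat_completion(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": message}],
        max_tokens=500,
        stream=True,
    ):
        partial += chunk.choices[0].delta.content or ""
        yield partial

gr.ChatInterface(fn=stream_chat).launch()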