ysharma HF staff commited on
Commit
d2eb8e2
·
1 Parent(s): b519057

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -20
app.py CHANGED
@@ -18,22 +18,84 @@ PLACEHOLDER = """
18
  </div>
19
  """
20
 
21
- model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
 
22
 
23
- processor = AutoProcessor.from_pretrained(model_id)
 
24
 
25
- model = LlavaForConditionalGeneration.from_pretrained(
26
- model_id,
27
  torch_dtype=torch.float16,
28
  low_cpu_mem_usage=True,
29
  )
 
 
30
 
31
- model.to("cuda:0")
32
- model.generation_config.eos_token_id = 128009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  @spaces.GPU
36
- def bot_streaming(message, history):
37
  print(message)
38
  if message["files"]:
39
  # message["files"][-1] is a Dict or just a string
@@ -63,7 +125,7 @@ def bot_streaming(message, history):
63
  streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": False, "skip_prompt": True})
64
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
65
 
66
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
67
  thread.start()
68
 
69
  text_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
@@ -85,20 +147,55 @@ def bot_streaming(message, history):
85
  yield generated_text_without_prompt
86
 
87
 
88
- chatbot=gr.Chatbot(placeholder=PLACEHOLDER,scale=1)
89
- chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
 
90
  with gr.Blocks(fill_height=True, ) as demo:
91
- gr.ChatInterface(
92
- fn=bot_streaming,
93
- title="LLaVA Llama-3-8B",
94
- examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
95
- {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
96
- description="Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
97
- stop_btn="Stop Generation",
98
- multimodal=True,
99
- textbox=chat_input,
100
- chatbot=chatbot,
 
 
101
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  demo.queue(api_open=False)
104
  demo.launch(show_api=False, share=False)
 
18
  </div>
19
  """
20
 
21
+ model_id_llama3 = "xtuner/llava-llama-3-8b-v1_1-transformers"
22
+ model_id_phi3 = "xtuner/llava-llama-3-8b-v1_1-transformers"
23
 
24
+ processor = AutoProcessor.from_pretrained(model_id_llama3)
25
+ processor = AutoProcessor.from_pretrained(model_id_phi3)
26
 
27
+ model_llama3 = LlavaForConditionalGeneration.from_pretrained(
28
+ model_id_llama3,
29
  torch_dtype=torch.float16,
30
  low_cpu_mem_usage=True,
31
  )
32
+ model_llama3.to("cuda:0")
33
+ model_llama3.generation_config.eos_token_id = 128009
34
 
35
+ model_phi3 = LlavaForConditionalGeneration.from_pretrained(
36
+ model_id_phi3,
37
+ torch_dtype=torch.float16,
38
+ low_cpu_mem_usage=True,
39
+ )
40
+ model_phi3.to("cuda:0")
41
+ model_phi3.generation_config.eos_token_id = 128009
42
+
43
+
44
+ @spaces.GPU
45
+ def bot_streaming_llama3(message, history):
46
+ print(message)
47
+ if message["files"]:
48
+ # message["files"][-1] is a Dict or just a string
49
+ if type(message["files"][-1]) == dict:
50
+ image = message["files"][-1]["path"]
51
+ else:
52
+ image = message["files"][-1]
53
+ else:
54
+ # if there's no image uploaded for this turn, look for images in the past turns
55
+ # kept inside tuples, take the last one
56
+ for hist in history:
57
+ if type(hist[0]) == tuple:
58
+ image = hist[0][0]
59
+ try:
60
+ if image is None:
61
+ # Handle the case where image is None
62
+ gr.Error("You need to upload an image for LLaVA to work.")
63
+ except NameError:
64
+ # Handle the case where 'image' is not defined at all
65
+ gr.Error("You need to upload an image for LLaVA to work.")
66
+
67
+ prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
68
+ # print(f"prompt: {prompt}")
69
+ image = Image.open(image)
70
+ inputs = processor(prompt, image, return_tensors='pt').to(0, torch.float16)
71
+
72
+ streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": False, "skip_prompt": True})
73
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
74
+
75
+ thread = Thread(target=model_llama3.generate, kwargs=generation_kwargs)
76
+ thread.start()
77
+
78
+ text_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
79
+ # print(f"text_prompt: {text_prompt}")
80
+
81
+ buffer = ""
82
+ time.sleep(0.5)
83
+ for new_text in streamer:
84
+ # find <|eot_id|> and remove it from the new_text
85
+ if "<|eot_id|>" in new_text:
86
+ new_text = new_text.split("<|eot_id|>")[0]
87
+ buffer += new_text
88
+
89
+ # generated_text_without_prompt = buffer[len(text_prompt):]
90
+ generated_text_without_prompt = buffer
91
+ # print(generated_text_without_prompt)
92
+ time.sleep(0.06)
93
+ # print(f"new_text: {generated_text_without_prompt}")
94
+ yield generated_text_without_prompt
95
 
96
 
97
  @spaces.GPU
98
+ def bot_streaming_phi3(message, history):
99
  print(message)
100
  if message["files"]:
101
  # message["files"][-1] is a Dict or just a string
 
125
  streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": False, "skip_prompt": True})
126
  generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
127
 
128
+ thread = Thread(target=model_phi3.generate, kwargs=generation_kwargs)
129
  thread.start()
130
 
131
  text_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
 
147
  yield generated_text_without_prompt
148
 
149
 
150
+ #chatbot=gr.Chatbot(placeholder=PLACEHOLDER,scale=1)
151
+ #chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
152
+
153
  with gr.Blocks(fill_height=True, ) as demo:
154
+ with gr.Row():
155
+ chatbot1 = gr.Chatbot(
156
+ [],
157
+ elem_id="llama3",
158
+ bubble_full_width=False,
159
+ label='LLaVa-Llama3'
160
+ )
161
+ chatbot2 = gr.Chatbot(
162
+ [],
163
+ elem_id="phi3",
164
+ bubble_full_width=False,
165
+ label='LLaVa-Phi3'
166
  )
167
+
168
+ chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
169
+
170
+ gr.Examples(examples=[[{"text": "What is on the flower?", "files": ["./bee.png"]}],],
171
+ {"text": "How to make this pastry?", "files": ["./baklava.png"]},],
172
+ inputs=chat_input)
173
+
174
+ #chat_input.submit(lambda: gr.MultimodalTextbox(interactive=False), None, [chat_input]).then(bot_streaming_llama3, [chat_input, chatbot1,], [chatbot1,])
175
+
176
+ chat_msg1 = chat_input.submit(bot_streaming_llama3, [chat_input, chatbot1,], [chatbot1,])
177
+ chat_msg2 = chat_input.submit(bot_streaming_phi3, [chat_input, chatbot2,], [chatbot2,])
178
+
179
+ #bot_msg1 = chat_msg1.then(bot, chatbot1, chatbot1, api_name="bot_response1")
180
+ #chat_msg1.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
181
+ #bot_msg2 = chat_msg2.then(bot, chatbot2, chatbot2, api_name="bot_response2")
182
+ #bot_msg2.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
183
+
184
+ chatbot1.like(print_like_dislike, None, None)
185
+ chatbot2.like(print_like_dislike, None, None)
186
+
187
+
188
+ #gr.ChatInterface(
189
+ #fn=bot_streaming_llama3,
190
+ #title="LLaVA Llama-3-8B",
191
+ #examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
192
+ # {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
193
+ #description="Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
194
+ #stop_btn="Stop Generation",
195
+ #multimodal=True,
196
+ #textbox=chat_input,
197
+ #chatbot=chatbot,
198
+ #)
199
 
200
  demo.queue(api_open=False)
201
  demo.launch(show_api=False, share=False)