Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -1,73 +1,63 @@
 import gradio as gr
 from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
-from transformers.models.smolvlm.video_processing_smolvlm import load_smolvlm_video
-from transformers.image_utils import load_image
 from threading import Thread
 import re
 import time
 import torch
-
+import spaces
 #import subprocess
 #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

 from io import BytesIO
-from transformers.image_utils import load_image

-processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-
-model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM2-
+processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-500M-Instruct")
+model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM2-500M-Instruct",
     _attn_implementation="flash_attention_2",
-    torch_dtype=torch.bfloat16
+    torch_dtype=torch.bfloat16).to("cuda:0")


-
+@spaces.GPU
 def model_inference(
-    input_dict, history
+    input_dict, history, max_tokens
 ):
     text = input_dict["text"]
-    # first turn input_dict {'text': 'What', 'files': ['/tmp/gradio/0350274350a64a5737e1a5732f014aee2f28bb7344bbad5105c0d0b7e7334375/cats_2.mp4', '/tmp/gradio/2dd39f382fcf5444a1a2ac57ed6f9acafa775dd855248cf273034e8ce18aeff4/IMG_2201.JPG']}
-    # first turn history []
-    print("input_dict", input_dict)
-    print("history", history)
-    print("model.device", model.device)
     images = []
     # first conv turn
     if history == []:
         text = input_dict["text"]
-        resulting_messages = [{"role": "user", "content": [{"type": "text"
+        resulting_messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
         for file in input_dict["files"]:
             if file.endswith(".mp4"):
-                resulting_messages[0]["content"].append({"type": "video"})
-                frames = load_smolvlm_video(
-                    file, sampling_fps=1, max_frames=64
-                )
-                print("frames", frames)
-                images.append(frames)
+                resulting_messages[0]["content"].append({"type": "video", "path": file})
+
             elif file.endswith(".jpg") or file.endswith(".jpeg") or file.endswith(".png"):
-                resulting_messages[0]["content"].append({"type": "image"})
-                images.append(load_image(file))
-                print("images", images)
-
-        # second turn input_dict {'text': 'what', 'files': ['/tmp/gradio/7bafdcc4722c4b9902a4936439b3bb694927abd72106a946d773a15cc1c630d7/IMG_2198.JPG']}
-        # second turn history [[('/tmp/gradio/7bafdcc4722c4b9902a4936439b3bb694927abd72106a946d773a15cc1c630d7/IMG_2198.JPG',), None],
-        # [('/tmp/gradio/5b105e97e4876912b4e763902144540bd3ab00d9fd4016491337ee4f4c36f320/football.mp4',), None], ['what', None]]
-
-        # later conv turn
     elif len(history) > 0:
+        resulting_messages = []
+        for entry in history:
+            if entry["role"] == "user":
+                user_content = []
+                if isinstance(entry["content"], tuple):
+                    file_name = entry["content"][0]
+                    if file_name.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
+                        user_content.append({"type": "image", "path": file_name})
+                    elif file_name.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
+                        user_content.append({"type": "video", "path": file_name})
+                elif isinstance(entry["content"], str):
+                    user_content.insert(0, {"type": "text", "text": entry["content"]})
+
+            elif entry["role"] == "assistant":
+                resulting_messages.append({
+                    "role": "user",
+                    "content": user_content
+                })
+                resulting_messages.append({
+                    "role": "assistant",
+                    "content": [{"type": "text", "text": entry["content"]}]
+                })
+                user_content = []
+


@@ -75,26 +65,22 @@ def model_inference(
         gr.Error("Please input a query and optionally image(s).")

     if text == "" and images:
-        gr.Error("Please input a text query along the
+        gr.Error("Please input a text query along the images(s).")

+    inputs = processor.apply_chat_template(
+        resulting_messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    )

-    inputs = processor(text=prompt, images=[images], padding=True, return_tensors="pt")
     inputs = inputs.to(model.device)
-    generation_args = {
-        "input_ids": inputs.input_ids,
-        "pixel_values": inputs.pixel_values,
-        "attention_mask": inputs.attention_mask,
-        "num_return_sequences": 1,
-        "no_repeat_ngram_size": 2,
-        "max_new_tokens": 500,
-        "min_new_tokens": 10,
-    }
+

     # Generate
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_args = dict(inputs, streamer=streamer, max_new_tokens=
+    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
     generated_text = ""

     thread = Thread(target=model.generate, kwargs=generation_args)
@@ -127,6 +113,7 @@ demo = gr.ChatInterface(fn=model_inference, title="SmolVLM2: The Smollest Video
     examples=examples,
     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
     cache_examples=False,
+    additional_inputs=[gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")],
     type="messages"
 )

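The core change is the input path: instead of building a prompt string and calling `processor(text=prompt, images=[images], ...)`, the updated code describes each turn as a chat message whose media entries carry a `path`, and lets `processor.apply_chat_template` handle image/video loading and tensorization. A minimal sketch of that flow, where "example.jpg" and the question are hypothetical placeholders rather than files from this Space:

    # Sketch of the message schema the updated app builds per user turn.
    messages = [
        {"role": "user", "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image", "path": "example.jpg"},  # hypothetical local file
        ]}
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,        # return token ids rather than a prompt string
        return_dict=True,     # yields input_ids, attention_mask, pixel_values, ...
        return_tensors="pt",
    ).to(model.device)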
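Generation still streams: `model.generate` runs on a background thread and `TextIteratorStreamer` yields decoded text as it is produced, now with `max_new_tokens` wired to the UI (gr.ChatInterface passes each `additional_inputs` component's value as an extra argument to `fn`, which is how the Max Tokens slider reaches the new `max_tokens` parameter). The consumer loop sits below the lines shown in this diff; a sketch of the usual pattern, where `stream_reply` is a hypothetical stand-in for the body of `model_inference`:

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_reply(inputs, max_tokens):
        # Producer: run generate on a worker thread, pushing chunks into the streamer.
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
        thread = Thread(target=model.generate, kwargs=generation_args)
        thread.start()
        # Consumer: each iteration blocks until the next decoded chunk arrives.
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text  # Gradio re-renders the growing partial reply
        thread.join()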
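The `import spaces` / `@spaces.GPU` pair is what makes this Space run on ZeroGPU: a GPU is attached only while the decorated function executes, which is why the model is moved to `cuda:0` once at startup and all inference stays inside `model_inference`. For long video inputs the decorator can also take a duration hint; a sketch, assuming the default time slot might be too short for 64-frame videos:

    import spaces

    @spaces.GPU(duration=120)  # assumption: a 120 s slot; the commit uses the default
    def model_inference(input_dict, history, max_tokens):
        ...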