Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,55 +1,87 @@
 import gradio as gr
 from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
+from transformers.models.smolvlm.video_processing_smolvlm import load_smolvlm_video
 from transformers.image_utils import load_image
 from threading import Thread
 import re
 import time
 import torch
-import spaces
+#import spaces
 #import subprocess
 #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
+from io import BytesIO
+from transformers.image_utils import load_image
+
 
-processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct-250M")
-model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM-Instruct-250M",
-                                               torch_dtype=torch.bfloat16,
-                                               #_attn_implementation="flash_attention_2"
-                                               ).to("cuda")
+processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+                                               _attn_implementation="flash_attention_2",
+                                               torch_dtype=torch.bfloat16, device_map="auto")
 
-
+#@spaces.GPU
 def model_inference(
     input_dict, history
 ):
     text = input_dict["text"]
-
-
-
-
-
-
-
+    # first turn input_dict {'text': 'What', 'files': ['/tmp/gradio/0350274350a64a5737e1a5732f014aee2f28bb7344bbad5105c0d0b7e7334375/cats_2.mp4', '/tmp/gradio/2dd39f382fcf5444a1a2ac57ed6f9acafa775dd855248cf273034e8ce18aeff4/IMG_2201.JPG']}
+    # first turn history []
+    print("input_dict", input_dict)
+    print("history", history)
+    print("model.device", model.device)
+    images = []
+    # first conv turn
+    if history == []:
+        text = input_dict["text"]
+        resulting_messages = [{"role": "user", "content": [{"type": "text"}, {"type": "text", "text": text}]}]
+        for file in input_dict["files"]:
+            if file.endswith(".mp4"):
+                resulting_messages[0]["content"].append({"type": "video"})
+                frames, timestamps, duration_sec = load_smolvlm_video(
+                    file, sampling_fps=1, max_frames=64
+                )
+                print("frames", frames)
+                images.append(frames)
+            elif file.endswith(".jpg") or file.endswith(".jpeg") or file.endswith(".png"):
+                resulting_messages[0]["content"].append({"type": "image"})
+                images.append(load_image(file))
+        print("images", images)
+
+    # second turn input_dict {'text': 'what', 'files': ['/tmp/gradio/7bafdcc4722c4b9902a4936439b3bb694927abd72106a946d773a15cc1c630d7/IMG_2198.JPG']}
+    # second turn history [[('/tmp/gradio/7bafdcc4722c4b9902a4936439b3bb694927abd72106a946d773a15cc1c630d7/IMG_2198.JPG',), None],
+    #                      [('/tmp/gradio/5b105e97e4876912b4e763902144540bd3ab00d9fd4016491337ee4f4c36f320/football.mp4',), None], ['what', None]]
+
+    # later conv turn
+    elif len(history) > 0:
+        for hist in history:
+            if isinstance(hist[0], tuple):
+                if hist[0][0].endswith(".mp4"):
+                    resulting_messages.append({"role": "user", "content": [{"type": "video"}, {"type": "text", "text": hist[0][0]}]})
+                    frames, timestamps, duration_sec = load_smolvlm_video(
+                        file, sampling_fps=1, max_frames=64
+                    )
+                    images.append(frames)
+                else:
+                    resulting_messages.append({"role": "user", "content": [{"type": "image"}, {"type": "text", "text": hist[0][0]}]})
+                    images.append(load_image(hist[0][0]))
+            elif isinstance(hist[0], str):
+                resulting_messages.append({"role": "user", "content": [{"type": "text"}, {"type": "text", "text": hist[0]}]})
+            if isinstance(hist[1], str):
+                resulting_messages.append({"role": "user", "content": [{"type": "text"}, {"type": "text", "text": hist[0]}]})
 
 
+
     if text == "" and not images:
         gr.Error("Please input a query and optionally image(s).")
 
     if text == "" and images:
         gr.Error("Please input a text query along the image(s).")
 
-
-
-
-    resulting_messages = [
-        {
-            "role": "user",
-            "content": [{"type": "image"} for _ in range(len(images))] + [
-                {"type": "text", "text": text}
-            ]
-        }
-    ]
+    print("resulting_messages", resulting_messages)
     prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
-
-    inputs = 
+
+    inputs = processor(text=prompt, images=[images], padding=True, return_tensors="pt")
+    inputs = inputs.to(model.device)
     generation_args = {
         "input_ids": inputs.input_ids,
         "pixel_values": inputs.pixel_values,
@@ -90,11 +122,12 @@ examples=[
     [{"text": "What is the date in this document?", "files": ["example_images/document.jpg"]}],
     [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}],
 ]
-demo = gr.ChatInterface(fn=model_inference, title="
-                        description="Play with [
+demo = gr.ChatInterface(fn=model_inference, title="SmolVLM2: The Smollest Video Model Ever 📺",
+                        description="Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. To get started, upload an image and text or try one of the examples. This demo doesn't use history for the chat, so every chat you start is a new conversation.",
                         examples=examples,
-                        textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
-                        cache_examples=False
+                        textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
+                        cache_examples=False,
+                        type="messages"
 )
 
 
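For reference, the first-turn path of the updated model_inference reduces to: build a chat message with a {"type": "video"} or {"type": "image"} placeholder, apply the chat template, run the processor on the prompt plus the decoded media, and generate. The sketch below is a minimal, standalone illustration of that flow, not the Space's own code: the load_smolvlm_video import and its sampling_fps/max_frames signature are taken verbatim from the diff above and assume a transformers build that ships them, while the video path, query text, and generation settings are placeholders.

import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.models.smolvlm.video_processing_smolvlm import load_smolvlm_video  # import as used in the diff

# Same checkpoint as the updated Space.
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# One user turn: a video placeholder followed by the text query.
messages = [{"role": "user",
             "content": [{"type": "video"}, {"type": "text", "text": "Describe this clip."}]}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

# Sample frames at 1 fps, capped at 64 frames, mirroring the Space's call.
frames, timestamps, duration_sec = load_smolvlm_video("example.mp4", sampling_fps=1, max_frames=64)

# The Space wraps all media for the prompt in one outer list (a single batch element),
# so one video's frame list becomes [[frames]] here.
inputs = processor(text=prompt, images=[[frames]], padding=True, return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])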