merve (HF Staff) committed (verified)
Commit 18c7142 · Parent: 7869234

Update app.py

Files changed (1): app.py (+64, -31)
app.py CHANGED
@@ -1,55 +1,87 @@
 import gradio as gr
 from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
+from transformers.models.smolvlm.video_processing_smolvlm import load_smolvlm_video
 from transformers.image_utils import load_image
 from threading import Thread
 import re
 import time
 import torch
-import spaces
+#import spaces
 #import subprocess
 #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
+from io import BytesIO
+from transformers.image_utils import load_image
+
+processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+                                               _attn_implementation="flash_attention_2",
+                                               torch_dtype=torch.bfloat16, device_map="auto")
 
-processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct-250M")
-model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM-Instruct-250M",
-                                               torch_dtype=torch.bfloat16,
-                                               #_attn_implementation="flash_attention_2"
-                                               ).to("cuda")
 
-@spaces.GPU
+#@spaces.GPU
 def model_inference(
     input_dict, history
 ):
     text = input_dict["text"]
-    print(input_dict["files"])
-    if len(input_dict["files"]) > 1:
-        images = [load_image(image) for image in input_dict["files"]]
-    elif len(input_dict["files"]) == 1:
-        images = [load_image(input_dict["files"][0])]
-    else:
-        images = []
+    # first turn input_dict {'text': 'What', 'files': ['/tmp/gradio/0350274350a64a5737e1a5732f014aee2f28bb7344bbad5105c0d0b7e7334375/cats_2.mp4', '/tmp/gradio/2dd39f382fcf5444a1a2ac57ed6f9acafa775dd855248cf273034e8ce18aeff4/IMG_2201.JPG']}
+    # first turn history []
+    print("input_dict", input_dict)
+    print("history", history)
+    print("model.device", model.device)
+    images = []
+    # first conv turn
+    if history == []:
+        text = input_dict["text"]
+        resulting_messages = [{"role": "user", "content": [{"type": "text"}, {"type": "text", "text": text}]}]
+        for file in input_dict["files"]:
+            if file.endswith(".mp4"):
+                resulting_messages[0]["content"].append({"type": "video"})
+                frames, timestamps, duration_sec = load_smolvlm_video(
+                    file, sampling_fps=1, max_frames=64
+                )
+                print("frames", frames)
+                images.append(frames)
+            elif file.endswith(".jpg") or file.endswith(".jpeg") or file.endswith(".png"):
+                resulting_messages[0]["content"].append({"type": "image"})
+                images.append(load_image(file))
+        print("images", images)
+
+    # second turn input_dict {'text': 'what', 'files': ['/tmp/gradio/7bafdcc4722c4b9902a4936439b3bb694927abd72106a946d773a15cc1c630d7/IMG_2198.JPG']}
+    # second turn history [[('/tmp/gradio/7bafdcc4722c4b9902a4936439b3bb694927abd72106a946d773a15cc1c630d7/IMG_2198.JPG',), None],
+    #                      [('/tmp/gradio/5b105e97e4876912b4e763902144540bd3ab00d9fd4016491337ee4f4c36f320/football.mp4',), None], ['what', None]]
+
+    # later conv turn
+    elif len(history) > 0:
+        for hist in history:
+            if isinstance(hist[0], tuple):
+                if hist[0][0].endswith(".mp4"):
+                    resulting_messages.append({"role": "user", "content": [{"type": "video"}, {"type": "text", "text": hist[0][0]}]})
+                    frames, timestamps, duration_sec = load_smolvlm_video(
+                        file, sampling_fps=1, max_frames=64
+                    )
+                    images.append(frames)
+                else:
+                    resulting_messages.append({"role": "user", "content": [{"type": "image"}, {"type": "text", "text": hist[0][0]}]})
+                    images.append(load_image(hist[0][0]))
+            elif isinstance(hist[0], str):
+                resulting_messages.append({"role": "user", "content": [{"type": "text"}, {"type": "text", "text": hist[0]}]})
+            if isinstance(hist[1], str):
+                resulting_messages.append({"role": "user", "content": [{"type": "text"}, {"type": "text", "text": hist[0]}]})
 
 
+
     if text == "" and not images:
         gr.Error("Please input a query and optionally image(s).")
 
     if text == "" and images:
         gr.Error("Please input a text query along the image(s).")
 
-
-
-
-    resulting_messages = [
-        {
-            "role": "user",
-            "content": [{"type": "image"} for _ in range(len(images))] + [
-                {"type": "text", "text": text}
-            ]
-        }
-    ]
+    print("resulting_messages", resulting_messages)
     prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
-    inputs = processor(text=prompt, images=[images], return_tensors="pt")
-    inputs = inputs.to('cuda')
+
+    inputs = processor(text=prompt, images=[images], padding=True, return_tensors="pt")
+    inputs = inputs.to(model.device)
     generation_args = {
         "input_ids": inputs.input_ids,
         "pixel_values": inputs.pixel_values,
@@ -90,11 +122,12 @@ examples=[
     [{"text": "What is the date in this document?", "files": ["example_images/document.jpg"]}],
     [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}],
 ]
-demo = gr.ChatInterface(fn=model_inference, title="SmolVLM-256M: The Smollest VLM ever 💫",
-                        description="Play with [HuggingFaceTB/SmolVLM-Instruct-250M](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct-250M) in this demo. To get started, upload an image and text or try one of the examples. This demo doesn't use history for the chat, so every chat you start is a new conversation.",
+demo = gr.ChatInterface(fn=model_inference, title="SmolVLM2: The Smollest Video Model Ever 📺",
+                        description="Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. To get started, upload an image and text or try one of the examples. This demo doesn't use history for the chat, so every chat you start is a new conversation.",
                         examples=examples,
-                        textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
-                        cache_examples=False
+                        textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
+                        cache_examples=False,
+                        type="messages"
 )
 
 
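A minimal sketch of the video path this commit introduces, run outside the Gradio wrapper. It reuses the SmolVLM2-2.2B-Instruct checkpoint, the load_smolvlm_video helper, and the sampling_fps=1 / max_frames=64 settings from the diff; the nested images=[images] call mirrors what model_inference does rather than a documented contract, and the video path and prompt text are placeholders:

import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.models.smolvlm.video_processing_smolvlm import load_smolvlm_video

model_id = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Sample up to 64 frames at roughly 1 fps, as in the demo.
frames, timestamps, duration_sec = load_smolvlm_video(
    "my_video.mp4", sampling_fps=1, max_frames=64  # placeholder path
)

# One {"type": "video"} placeholder per sampled clip, plus the text query.
messages = [{
    "role": "user",
    "content": [{"type": "video"}, {"type": "text", "text": "Describe what happens in this clip."}],
}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

images = [frames]  # one frame list per video placeholder, as collected in model_inference
inputs = processor(text=prompt, images=[images], padding=True, return_tensors="pt").to(model.device)

generated_ids = model.generate(**inputs, max_new_tokens=500)
reply = processor.batch_decode(
    generated_ids[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True
)[0]
print(reply)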
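The first hunk is cut off just after generation_args opens, so the streaming half of model_inference is not shown here. Given the TextIteratorStreamer and Thread imports the file keeps, the rest presumably follows the usual transformers streaming pattern; the sketch below illustrates that pattern under that assumption and is not the commit's exact code (stream_reply and its defaults are illustrative):

from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(inputs, max_new_tokens=500):
    # Decode new tokens as they are produced, skipping the prompt and special tokens.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_args = {
        "input_ids": inputs.input_ids,
        "pixel_values": inputs.pixel_values,
        "attention_mask": inputs.attention_mask,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
    }
    # generate() blocks, so run it in a worker thread and drain the streamer here.
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer  # a Gradio ChatInterface re-renders the partial reply on each yield
    thread.join()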
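Not part of the commit: in the new elif len(history) > 0: branch, resulting_messages is appended to before it is initialized, load_smolvlm_video is called on file (the current upload) instead of the turn's own path, and the assistant reply re-appends hist[0] as a user message. One way to make the later-turn handling self-contained, keeping the tuple-style history the code indexes into and folding in the current turn, is sketched below; build_messages is a hypothetical helper that assumes app.py's imports:

def build_messages(input_dict, history):
    # Rebuild the chat messages and media list from tuple-style Gradio history:
    # [(file_path,), None] for uploads, [user_text, assistant_text] for text turns.
    resulting_messages, images = [], []

    for user_turn, assistant_turn in history:
        if isinstance(user_turn, tuple):  # an uploaded file from an earlier turn
            path = user_turn[0]
            if path.endswith(".mp4"):
                frames, timestamps, duration_sec = load_smolvlm_video(
                    path, sampling_fps=1, max_frames=64  # this turn's path, not the current upload
                )
                resulting_messages.append({"role": "user", "content": [{"type": "video"}]})
                images.append(frames)
            else:
                resulting_messages.append({"role": "user", "content": [{"type": "image"}]})
                images.append(load_image(path))
        elif isinstance(user_turn, str):
            resulting_messages.append(
                {"role": "user", "content": [{"type": "text", "text": user_turn}]}
            )
        if isinstance(assistant_turn, str):  # replay the model's reply, not the user text again
            resulting_messages.append(
                {"role": "assistant", "content": [{"type": "text", "text": assistant_turn}]}
            )

    # Fold in the current turn as well.
    content = [{"type": "text", "text": input_dict["text"]}]
    for path in input_dict["files"]:
        if path.endswith(".mp4"):
            content.append({"type": "video"})
            frames, timestamps, duration_sec = load_smolvlm_video(path, sampling_fps=1, max_frames=64)
            images.append(frames)
        else:
            content.append({"type": "image"})
            images.append(load_image(path))
    resulting_messages.append({"role": "user", "content": content})
    return resulting_messages, images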