merve (HF staff) committed
Commit de4762a · verified · 1 Parent(s): 18c7142

Update app.py

Files changed (1)
  1. app.py +46 -59
app.py CHANGED
@@ -1,73 +1,63 @@
  import gradio as gr
  from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
- from transformers.models.smolvlm.video_processing_smolvlm import load_smolvlm_video
- from transformers.image_utils import load_image
  from threading import Thread
  import re
  import time
  import torch
- #import spaces
+ import spaces
  #import subprocess
  #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
  
  from io import BytesIO
- from transformers.image_utils import load_image
  
- processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
- model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+ processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-500M-Instruct")
+ model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM2-500M-Instruct",
                                               _attn_implementation="flash_attention_2",
-                                              torch_dtype=torch.bfloat16, device_map="auto")
+                                              torch_dtype=torch.bfloat16).to("cuda:0")
  
  
- #@spaces.GPU
+ @spaces.GPU
  def model_inference(
-     input_dict, history
+     input_dict, history, max_tokens
  ):
      text = input_dict["text"]
-     # first turn input_dict {'text': 'What', 'files': ['/tmp/gradio/0350274350a64a5737e1a5732f014aee2f28bb7344bbad5105c0d0b7e7334375/cats_2.mp4', '/tmp/gradio/2dd39f382fcf5444a1a2ac57ed6f9acafa775dd855248cf273034e8ce18aeff4/IMG_2201.JPG']}
-     # first turn history []
-     print("input_dict", input_dict)
-     print("history", history)
-     print("model.device", model.device)
      images = []
      # first conv turn
      if history == []:
          text = input_dict["text"]
-         resulting_messages = [{"role": "user", "content": [{"type": "text"}, {"type": "text", "text": text}]}]
+         resulting_messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
          for file in input_dict["files"]:
              if file.endswith(".mp4"):
-                 resulting_messages[0]["content"].append({"type": "video"})
-                 frames, timestamps, duration_sec = load_smolvlm_video(
-                     file, sampling_fps=1, max_frames=64
-                 )
-                 print("frames", frames)
-                 images.append(frames)
+                 resulting_messages[0]["content"].append({"type": "video", "path": file})
+ 
              elif file.endswith(".jpg") or file.endswith(".jpeg") or file.endswith(".png"):
-                 resulting_messages[0]["content"].append({"type": "image"})
-                 images.append(load_image(file))
-         print("images", images)
- 
-     # second turn input_dict {'text': 'what', 'files': ['/tmp/gradio/7bafdcc4722c4b9902a4936439b3bb694927abd72106a946d773a15cc1c630d7/IMG_2198.JPG']}
-     # second turn history [[('/tmp/gradio/7bafdcc4722c4b9902a4936439b3bb694927abd72106a946d773a15cc1c630d7/IMG_2198.JPG',), None],
-     #                      [('/tmp/gradio/5b105e97e4876912b4e763902144540bd3ab00d9fd4016491337ee4f4c36f320/football.mp4',), None], ['what', None]]
- 
-     # later conv turn
+                 resulting_messages[0]["content"].append({"type": "image", "path": file})
+ 
      elif len(history) > 0:
-         for hist in history:
-             if isinstance(hist[0], tuple):
-                 if hist[0][0].endswith(".mp4"):
-                     resulting_messages.append({"role": "user", "content": [{"type": "video"}, {"type": "text", "text": hist[0][0]}]})
-                     frames, timestamps, duration_sec = load_smolvlm_video(
-                         file, sampling_fps=1, max_frames=64
-                     )
-                     images.append(frames)
-                 else:
-                     resulting_messages.append({"role": "user", "content": [{"type": "image"}, {"type": "text", "text": hist[0][0]}]})
-                     images.append(load_image(hist[0][0]))
-             elif isinstance(hist[0], str):
-                 resulting_messages.append({"role": "user", "content": [{"type": "text"}, {"type": "text", "text": hist[0]}]})
-             if isinstance(hist[1], str):
-                 resulting_messages.append({"role": "user", "content": [{"type": "text"}, {"type": "text", "text": hist[0]}]})
+         resulting_messages = []
+         for entry in history:
+             if entry["role"] == "user":
+                 user_content = []
+                 if isinstance(entry["content"], tuple):
+                     file_name = entry["content"][0]
+                     if file_name.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
+                         user_content.append({"type": "image", "path": file_name})
+                     elif file_name.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
+                         user_content.append({"type": "video", "path": file_name})
+                 elif isinstance(entry["content"], str):
+                     user_content.insert(0, {"type": "text", "text": entry["content"]})
+ 
+             elif entry["role"] == "assistant":
+                 resulting_messages.append({
+                     "role": "user",
+                     "content": user_content
+                 })
+                 resulting_messages.append({
+                     "role": "assistant",
+                     "content": [{"type": "text", "text": entry["content"]}]
+                 })
+                 user_content = []
+ 
  
  
  
@@ -75,26 +65,22 @@ def model_inference(
          gr.Error("Please input a query and optionally image(s).")
  
      if text == "" and images:
-         gr.Error("Please input a text query along the image(s).")
+         gr.Error("Please input a text query along the images(s).")
  
-     print("resulting_messages", resulting_messages)
-     prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
+     inputs = processor.apply_chat_template(
+         resulting_messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt",
+     )
  
-     inputs = processor(text=prompt, images=[images], padding=True, return_tensors="pt")
      inputs = inputs.to(model.device)
-     generation_args = {
-         "input_ids": inputs.input_ids,
-         "pixel_values": inputs.pixel_values,
-         "attention_mask": inputs.attention_mask,
-         "num_return_sequences": 1,
-         "no_repeat_ngram_size": 2,
-         "max_new_tokens": 500,
-         "min_new_tokens": 10,
-     }
+ 
  
      # Generate
      streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_args = dict(inputs, streamer=streamer, max_new_tokens=500)
+     generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
      generated_text = ""
  
      thread = Thread(target=model.generate, kwargs=generation_args)
@@ -127,6 +113,7 @@ demo = gr.ChatInterface(fn=model_inference, title="SmolVLM2: The Smollest Video
      examples=examples,
      textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
      cache_examples=False,
+     additional_inputs=[gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")],
      type="messages"
      )
  
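
Editor's sketch (not part of this commit): the updated app builds chat messages in which image and video entries carry a "path", then hands them to processor.apply_chat_template, which tokenizes and returns tensors directly instead of the old load_smolvlm_video plus processor(...) two-step. A minimal standalone version of that call path, assuming a transformers release with SmolVLM2 support and a CUDA device; the prompt text and "sample.mp4" are placeholders, not files from the Space:

import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

# Same checkpoint the commit switches to.
model_id = "HuggingFaceTB/SmolVLM2-500M-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda:0")

# Message format built by the updated app: media entries carry a "path" and the
# processor samples video frames itself (no manual load_smolvlm_video call).
messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this clip."},  # placeholder prompt
        {"type": "video", "path": "sample.mp4"},          # placeholder local file
    ],
}]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)

generated_ids = model.generate(**inputs, max_new_tokens=200)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])

The Space wraps the same generate call in a Thread with a TextIteratorStreamer to stream tokens into the Gradio chat; the sketch above simply decodes the full output at once.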