prithivMLmods committed on
Commit
9e55e35
·
verified ·
1 Parent(s): c27c463

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -17
app.py CHANGED
@@ -23,15 +23,24 @@ from transformers.image_utils import load_image
23
  # Constants for text generation
24
  MAX_MAX_NEW_TOKENS = 2048
25
  DEFAULT_MAX_NEW_TOKENS = 1024
26
- # Increase or disable input truncation to avoid token mismatches
27
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
28
 
29
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
30
 
31
- MODEL_ID = "nvidia/Cosmos-Reason1-7B"
32
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 
33
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
34
- MODEL_ID,
 
 
 
 
 
 
 
 
 
35
  trust_remote_code=True,
36
  torch_dtype=torch.float16
37
  ).to("cuda").eval()
@@ -45,13 +54,12 @@ def downsample_video(video_path):
45
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
46
  fps = vidcap.get(cv2.CAP_PROP_FPS)
47
  frames = []
48
- # Sample 10 evenly spaced frames.
49
  frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
50
  for i in frame_indices:
51
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
52
  success, image = vidcap.read()
53
  if success:
54
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
55
  pil_image = Image.fromarray(image)
56
  timestamp = round(i / fps, 2)
57
  frames.append((pil_image, timestamp))
@@ -59,15 +67,25 @@ def downsample_video(video_path):
59
  return frames
60
 
61
  @spaces.GPU
62
- def generate_image(text: str, image: Image.Image,
63
  max_new_tokens: int = 1024,
64
  temperature: float = 0.6,
65
  top_p: float = 0.9,
66
  top_k: int = 50,
67
  repetition_penalty: float = 1.2):
68
  """
69
- Generates responses using the Cosmos-Reason1 model for image input.
70
  """
 
 
 
 
 
 
 
 
 
 
71
  if image is None:
72
  yield "Please upload an image."
73
  return
@@ -90,7 +108,7 @@ def generate_image(text: str, image: Image.Image,
90
  ).to("cuda")
91
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
92
  generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
93
- thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
94
  thread.start()
95
  buffer = ""
96
  for new_text in streamer:
@@ -100,15 +118,25 @@ def generate_image(text: str, image: Image.Image,
100
  yield buffer
101
 
102
  @spaces.GPU
103
- def generate_video(text: str, video_path: str,
104
  max_new_tokens: int = 1024,
105
  temperature: float = 0.6,
106
  top_p: float = 0.9,
107
  top_k: int = 50,
108
  repetition_penalty: float = 1.2):
109
  """
110
- Generates responses using the Cosmos-Reason1 model for video input.
111
  """
 
 
 
 
 
 
 
 
 
 
112
  if video_path is None:
113
  yield "Please upload a video."
114
  return
@@ -118,7 +146,6 @@ def generate_video(text: str, video_path: str,
118
  {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
119
  {"role": "user", "content": [{"type": "text", "text": text}]}
120
  ]
121
- # Append each frame with its timestamp.
122
  for frame in frames:
123
  image, timestamp = frame
124
  messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
@@ -143,7 +170,7 @@ def generate_video(text: str, video_path: str,
143
  "top_k": top_k,
144
  "repetition_penalty": repetition_penalty,
145
  }
146
- thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
147
  thread.start()
148
  buffer = ""
149
  for new_text in streamer:
@@ -163,7 +190,6 @@ video_examples = [
163
  ["Identify the main actions in the video", "videos/2.mp4"]
164
  ]
165
 
166
-
167
  css = """
168
  .submit-btn {
169
  background-color: #2980b9 !important;
@@ -176,13 +202,17 @@ css = """
176
 
177
  # Create the Gradio Interface
178
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
179
- gr.Markdown("# **Cosmos-Reason1 by [NVIDIA](https://huggingface.co/nvidia/Cosmos-Reason1-7B)**")
180
  with gr.Row():
181
  with gr.Column():
182
  with gr.Tabs():
183
  with gr.TabItem("Image Inference"):
184
  image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
185
  image_upload = gr.Image(type="pil", label="Image")
 
 
 
 
186
  image_submit = gr.Button("Submit", elem_classes="submit-btn")
187
  gr.Examples(
188
  examples=image_examples,
@@ -191,6 +221,10 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
191
  with gr.TabItem("Video Inference"):
192
  video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
193
  video_upload = gr.Video(label="Video")
 
 
 
 
194
  video_submit = gr.Button("Submit", elem_classes="submit-btn")
195
  gr.Examples(
196
  examples=video_examples,
@@ -208,12 +242,12 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
208
 
209
  image_submit.click(
210
  fn=generate_image,
211
- inputs=[image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
212
  outputs=output
213
  )
214
  video_submit.click(
215
  fn=generate_video,
216
- inputs=[video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
217
  outputs=output
218
  )
219
 
 
23
  # Constants for text generation
24
  MAX_MAX_NEW_TOKENS = 2048
25
  DEFAULT_MAX_NEW_TOKENS = 1024
 
26
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
27
 
28
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
29
 
30
+ # Load Cosmos-Reason1-7B
31
+ MODEL_ID_M = "nvidia/Cosmos-Reason1-7B"
32
+ processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
33
  model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
34
+ MODEL_ID_M,
35
+ trust_remote_code=True,
36
+ torch_dtype=torch.float16
37
+ ).to("cuda").eval()
38
+
39
+ # Load MiMo-VL-7B-RL
40
+ MODEL_ID_X = "XiaomiMiMo/MiMo-VL-7B-RL"
41
+ processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
42
+ model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
43
+ MODEL_ID_X,
44
  trust_remote_code=True,
45
  torch_dtype=torch.float16
46
  ).to("cuda").eval()
 
54
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
55
  fps = vidcap.get(cv2.CAP_PROP_FPS)
56
  frames = []
 
57
  frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
58
  for i in frame_indices:
59
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
60
  success, image = vidcap.read()
61
  if success:
62
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
63
  pil_image = Image.fromarray(image)
64
  timestamp = round(i / fps, 2)
65
  frames.append((pil_image, timestamp))
 
67
  return frames
68
 
69
  @spaces.GPU
70
+ def generate_image(model_name: str, text: str, image: Image.Image,
71
  max_new_tokens: int = 1024,
72
  temperature: float = 0.6,
73
  top_p: float = 0.9,
74
  top_k: int = 50,
75
  repetition_penalty: float = 1.2):
76
  """
77
+ Generates responses using the selected model for image input.
78
  """
79
+ if model_name == "Cosmos-Reason1-7B":
80
+ processor = processor_m
81
+ model = model_m
82
+ elif model_name == "MiMo-VL-7B-RL":
83
+ processor = processor_x
84
+ model = model_x
85
+ else:
86
+ yield "Invalid model selected."
87
+ return
88
+
89
  if image is None:
90
  yield "Please upload an image."
91
  return
 
108
  ).to("cuda")
109
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
110
  generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
111
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
112
  thread.start()
113
  buffer = ""
114
  for new_text in streamer:
 
118
  yield buffer
119
 
120
  @spaces.GPU
121
+ def generate_video(model_name: str, text: str, video_path: str,
122
  max_new_tokens: int = 1024,
123
  temperature: float = 0.6,
124
  top_p: float = 0.9,
125
  top_k: int = 50,
126
  repetition_penalty: float = 1.2):
127
  """
128
+ Generates responses using the selected model for video input.
129
  """
130
+ if model_name == "Cosmos-Reason1-7B":
131
+ processor = processor_m
132
+ model = model_m
133
+ elif model_name == "MiMo-VL-7B-RL":
134
+ processor = processor_x
135
+ model = model_x
136
+ else:
137
+ yield "Invalid model selected."
138
+ return
139
+
140
  if video_path is None:
141
  yield "Please upload a video."
142
  return
 
146
  {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
147
  {"role": "user", "content": [{"type": "text", "text": text}]}
148
  ]
 
149
  for frame in frames:
150
  image, timestamp = frame
151
  messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
 
170
  "top_k": top_k,
171
  "repetition_penalty": repetition_penalty,
172
  }
173
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
174
  thread.start()
175
  buffer = ""
176
  for new_text in streamer:
 
190
  ["Identify the main actions in the video", "videos/2.mp4"]
191
  ]
192
 
 
193
  css = """
194
  .submit-btn {
195
  background-color: #2980b9 !important;
 
202
 
203
  # Create the Gradio Interface
204
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
205
+ gr.Markdown("# **Vision-Language Model Inference**")
206
  with gr.Row():
207
  with gr.Column():
208
  with gr.Tabs():
209
  with gr.TabItem("Image Inference"):
210
  image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
211
  image_upload = gr.Image(type="pil", label="Image")
212
+ model_choice = gr.Dropdown(
213
+ choices=["Cosmos-Reason1-7B", "MiMo-VL-7B-RL"],
214
+ label="Select Model",
215
+ value="Cosmos-Reason1-7B")
216
  image_submit = gr.Button("Submit", elem_classes="submit-btn")
217
  gr.Examples(
218
  examples=image_examples,
 
221
  with gr.TabItem("Video Inference"):
222
  video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
223
  video_upload = gr.Video(label="Video")
224
+ model_choice = gr.Dropdown(
225
+ choices=["Cosmos-Reason1-7B", "MiMo-VL-7B-RL"],
226
+ label="Select Model",
227
+ value="Cosmos-Reason1-7B")
228
  video_submit = gr.Button("Submit", elem_classes="submit-btn")
229
  gr.Examples(
230
  examples=video_examples,
 
242
 
243
  image_submit.click(
244
  fn=generate_image,
245
+ inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
246
  outputs=output
247
  )
248
  video_submit.click(
249
  fn=generate_video,
250
+ inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
251
  outputs=output
252
  )
253