prithivMLmods committed · verified
Commit 55a7e0e · 1 Parent(s): ed406dc

Update app.py

Files changed (1):
  1. app.py +110 -118

app.py CHANGED
@@ -15,12 +15,10 @@ from PIL import Image
 import cv2
 
 from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    TextIteratorStreamer,
-    Qwen2VLForConditionalGeneration,
     AutoProcessor,
     Gemma3ForConditionalGeneration,
+    Qwen2VLForConditionalGeneration,
+    TextIteratorStreamer,
 )
 from transformers.image_utils import load_image
 
@@ -38,7 +36,7 @@ def progress_bar_html(label: str) -> str:
     <div style="display: flex; align-items: center;">
         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
         <div style="width: 110px; height: 5px; background-color: #F0FFF0; border-radius: 2px; overflow: hidden;">
-            <div style="width: 100%; height: 100%; background-color: #00FF00 ; animation: loading 1.5s linear infinite;"></div>
+            <div style="width: 100%; height: 100%; background-color: #00FF00; animation: loading 1.5s linear infinite;"></div>
         </div>
     </div>
     <style>
@@ -49,18 +47,7 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
-# TEXT MODEL
-
-model_id = "prithivMLmods/FastThink-0.5B-Tiny"
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-)
-model.eval()
-
-# MULTIMODAL (OCR) MODELS
+# Qwen2-VL (for optional image inference)
 
 MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
@@ -102,7 +89,8 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
-# GEMMA3-4B MULTIMODAL MODEL
+
+# Gemma3 Model (default for text, image, & video inference)
 
 gemma3_model_id = "google/gemma-3-4b-it" # alternative: google/gemma-3-12b-it
 gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
@@ -111,6 +99,7 @@ gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
 gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
 
 # VIDEO PROCESSING HELPER
+
 def downsample_video(video_path):
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -144,15 +133,12 @@ def generate(
 ):
     text = input_dict["text"]
     files = input_dict.get("files", [])
-
     lower_text = text.lower().strip()
 
-    # GEMMA3-4B TEXT & MULTIMODAL (image) Branch
-    if lower_text.startswith("@gemma3"):
-        # Remove the gemma3 flag from the prompt.
-        prompt_clean = re.sub(r"@gemma3", "", text, flags=re.IGNORECASE).strip().strip('"')
+    # ----- Qwen2-VL branch (triggered with @qwen2-vl) -----
+    if lower_text.startswith("@qwen2-vl"):
+        prompt_clean = re.sub(r"@qwen2-vl", "", text, flags=re.IGNORECASE).strip().strip('"')
         if files:
-            # If image files are provided, load them.
             images = [load_image(f) for f in files]
             messages = [{
                 "role": "user",
@@ -161,18 +147,18 @@ def generate(
                     {"type": "text", "text": prompt_clean},
                 ]
             }]
+            prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
         else:
             messages = [
                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
                 {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
            ]
-        inputs = gemma3_processor.apply_chat_template(
-            messages, add_generation_prompt=True, tokenize=True,
-            return_dict=True, return_tensors="pt"
-        ).to(gemma3_model.device, dtype=torch.bfloat16)
-        streamer = TextIteratorStreamer(
-            gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
-        )
+            inputs = processor.apply_chat_template(
+                messages, add_generation_prompt=True, tokenize=True,
+                return_dict=True, return_tensors="pt"
+            ).to("cuda", dtype=torch.float16)
+        streamer = TextIteratorStreamer(processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = {
             **inputs,
             "streamer": streamer,
@@ -183,47 +169,107 @@ def generate(
             "top_p": top_p,
             "top_k": top_k,
             "repetition_penalty": repetition_penalty,
         }
-        thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
         thread.start()
         buffer = ""
-        yield progress_bar_html("Processing with Gemma3")
+        yield progress_bar_html("Processing with Qwen2VL")
         for new_text in streamer:
             buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
             time.sleep(0.01)
             yield buffer
         return
 
-    # GEMMA3-4B VIDEO Branch
-    if lower_text.startswith("@video-infer"):
-        # Remove the video flag from the prompt.
-        prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
-        if files:
-            # Assume the first file is a video.
+    # ----- Default branch: Gemma3 (for text, image, & video inference) -----
+    if files:
+        # Check if any provided file is a video based on extension.
+        video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm")
+        if any(str(f).lower().endswith(video_extensions) for f in files):
+            # Video inference branch.
+            prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
             video_path = files[0]
             frames = downsample_video(video_path)
             messages = [
                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
                 {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
             ]
-            # Append each frame as an image with a timestamp label.
+            # Append each frame (with its timestamp) to the conversation.
             for frame in frames:
                 image, timestamp = frame
                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
                 image.save(image_path)
                 messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
                 messages[1]["content"].append({"type": "image", "url": image_path})
+            inputs = gemma3_processor.apply_chat_template(
+                messages, add_generation_prompt=True, tokenize=True,
+                return_dict=True, return_tensors="pt"
+            ).to(gemma3_model.device, dtype=torch.bfloat16)
+            streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+            generation_kwargs = {
+                **inputs,
+                "streamer": streamer,
+                "max_new_tokens": max_new_tokens,
+                "do_sample": True,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": repetition_penalty,
+            }
+            thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+            thread.start()
+            buffer = ""
+            yield progress_bar_html("Processing video with Gemma3")
+            for new_text in streamer:
+                buffer += new_text
+                time.sleep(0.01)
+                yield buffer
+            return
         else:
-            messages = [
-                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
-            ]
+            # Image inference branch.
+            prompt_clean = re.sub(r"@gemma3", "", text, flags=re.IGNORECASE).strip().strip('"')
+            images = [load_image(f) for f in files]
+            messages = [{
+                "role": "user",
+                "content": [
+                    *[{"type": "image", "image": image} for image in images],
+                    {"type": "text", "text": prompt_clean},
+                ]
+            }]
+            inputs = gemma3_processor.apply_chat_template(
+                messages, tokenize=True, add_generation_prompt=True,
+                return_dict=True, return_tensors="pt"
+            ).to(gemma3_model.device, dtype=torch.bfloat16)
+            streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+            generation_kwargs = {
+                **inputs,
+                "streamer": streamer,
+                "max_new_tokens": max_new_tokens,
+                "do_sample": True,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": repetition_penalty,
+            }
+            thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+            thread.start()
+            buffer = ""
+            yield progress_bar_html("Processing with Gemma3")
+            for new_text in streamer:
+                buffer += new_text
+                time.sleep(0.01)
+                yield buffer
+            return
+    else:
+        # Text-only inference branch.
+        messages = [
+            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+            {"role": "user", "content": [{"type": "text", "text": text}]}
+        ]
         inputs = gemma3_processor.apply_chat_template(
             messages, add_generation_prompt=True, tokenize=True,
             return_dict=True, return_tensors="pt"
         ).to(gemma3_model.device, dtype=torch.bfloat16)
-        streamer = TextIteratorStreamer(
-            gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
-        )
+        streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = {
             **inputs,
             "streamer": streamer,
@@ -236,70 +282,16 @@ def generate(
         }
         thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
         thread.start()
-        buffer = ""
-        yield progress_bar_html("Processing video with Gemma3")
-        for new_text in streamer:
-            buffer += new_text
-            time.sleep(0.01)
-            yield buffer
-        return
-
-    # Otherwise, handle text/chat generation.
-    conversation = clean_chat_history(chat_history)
-    conversation.append({"role": "user", "content": text})
-
-    if files:
-        images = [load_image(image) for image in files] if len(files) > 1 else [load_image(files[0])]
-        messages = [{
-            "role": "user",
-            "content": [
-                *[{"type": "image", "image": image} for image in images],
-                {"type": "text", "text": text},
-            ]
-        }]
-        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
-        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-        thread.start()
-
-        buffer = ""
-        yield progress_bar_html("Processing with Qwen2VL OCR")
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-    else:
-        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-        input_ids = input_ids.to(model.device)
-        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            "input_ids": input_ids,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "top_p": top_p,
-            "top_k": top_k,
-            "temperature": temperature,
-            "num_beams": 1,
-            "repetition_penalty": repetition_penalty,
-        }
-        t = Thread(target=model.generate, kwargs=generation_kwargs)
-        t.start()
-
     outputs = []
     for new_text in streamer:
         outputs.append(new_text)
         yield "".join(outputs)
-
     final_response = "".join(outputs)
     yield final_response
 
+
+# Gradio Interface
+
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
@@ -312,7 +304,7 @@ demo = gr.ChatInterface(
     examples=[
        [
            {
-                "text": "@gemma3 Create a short story based on the images.",
+                "text": "Create a short story based on the images.",
                "files": [
                    "examples/1111.jpg",
                    "examples/2222.jpg",
@@ -320,24 +312,24 @@ demo = gr.ChatInterface(
                ],
            }
        ],
-        [{"text": "@gemma3 Explain the Image", "files": ["examples/3.jpg"]}],
-        [{"text": "@video-infer Explain the content of the Advertisement", "files": ["examples/videoplayback.mp4"]}],
-        [{"text": "@gemma3 Which movie character is this?", "files": ["examples/9999.jpg"]}],
-        ["@gemma3 Explain Critical Temperature of Substance"],
-        [{"text": "@gemma3 Transcription of the letter", "files": ["examples/222.png"]}],
-        [{"text": "@video-infer Explain the content of the video in detail", "files": ["examples/breakfast.mp4"]}],
-        [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
-        [{"text": "@video-infer Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
-        [{"text": "@video-infer Summarize the events in this video", "files": ["examples/sky.mp4"]}],
-        [{"text": "@video-infer What is in the video ?", "files": ["examples/redlight.mp4"]}],
+        [{"text": "Explain the Image", "files": ["examples/3.jpg"]}],
+        [{"text": "Explain the content of the Advertisement", "files": ["examples/videoplayback.mp4"]}],
+        [{"text": "Which movie character is this?", "files": ["examples/9999.jpg"]}],
+        ["Explain Critical Temperature of Substance"],
+        [{"text": "@qwen2-vl Transcription of the letter", "files": ["examples/222.png"]}],
+        [{"text": "Explain the content of the video in detail", "files": ["examples/breakfast.mp4"]}],
+        [{"text": "Describe the video", "files": ["examples/Missing.mp4"]}],
+        [{"text": "Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
+        [{"text": "Summarize the events in this video", "files": ["examples/sky.mp4"]}],
+        [{"text": "What is in the video ?", "files": ["examples/redlight.mp4"]}],
        ["Python Program for Array Rotation"],
-        ["@gemma3 Explain Critical Temperature of Substance"]
+        ["Explain Critical Temperature of Substance"]
    ],
    cache_examples=False,
    type="messages",
-    description="# **Gemma 3 `@gemma3, @video-infer for video understanding`**",
+    description="# **Gemma 3 Multimodal** \n`Use @qwen2-vl to switch to Qwen2-VL OCR for image inference and @video-infer for video input`",
    fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="Tag--> @gemma3 for multimodal, @video-infer for video !"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="Tag with @qwen2-vl for Qwen2-VL inference if needed."),
    stop_btn="Stop Generation",
    multimodal=True,
 )
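
For reference, a minimal sketch of the routing behaviour this commit introduces in generate(): the @gemma3 and @video-infer tags are no longer required, Gemma3 is the default for text, image, and video input (videos are detected by file extension), and @qwen2-vl switches image inference to the Qwen2-VL OCR model. The route() helper below is illustrative only and not part of app.py.

import re

# Illustrative sketch only: mirrors the dispatch order generate() uses after this commit.
VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mkv", ".webm")

def route(text, files):
    """Return (branch, cleaned_prompt) the way generate() would pick a model path."""
    lower_text = text.lower().strip()
    if lower_text.startswith("@qwen2-vl"):
        # Optional Qwen2-VL OCR branch, selected explicitly by the @qwen2-vl tag.
        return "qwen2-vl", re.sub(r"@qwen2-vl", "", text, flags=re.IGNORECASE).strip().strip('"')
    if files:
        if any(str(f).lower().endswith(VIDEO_EXTENSIONS) for f in files):
            # Videos are detected by extension; @video-infer is only stripped if present.
            return "gemma3-video", re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
        return "gemma3-image", re.sub(r"@gemma3", "", text, flags=re.IGNORECASE).strip().strip('"')
    return "gemma3-text", text

print(route("Describe the video", ["examples/sky.mp4"]))                     # ('gemma3-video', 'Describe the video')
print(route("@qwen2-vl Transcription of the letter", ["examples/222.png"]))  # ('qwen2-vl', 'Transcription of the letter')
print(route("Explain Critical Temperature of Substance", []))                # ('gemma3-text', 'Explain Critical Temperature of Substance')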