prithivMLmods committed on
Commit 736d689 · verified · 1 Parent(s): 3afc0fe

Update app.py

Files changed (1)
  1. app.py +109 -127
app.py CHANGED
@@ -15,12 +15,10 @@ from PIL import Image
 import cv2
 
 from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    TextIteratorStreamer,
-    Qwen2VLForConditionalGeneration,
     AutoProcessor,
     Gemma3ForConditionalGeneration,
+    Qwen2VLForConditionalGeneration,
+    TextIteratorStreamer,
 )
 from transformers.image_utils import load_image
 
@@ -38,7 +36,7 @@ def progress_bar_html(label: str) -> str:
     <div style="display: flex; align-items: center;">
         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
         <div style="width: 110px; height: 5px; background-color: #F0FFF0; border-radius: 2px; overflow: hidden;">
-            <div style="width: 100%; height: 100%; background-color: #00FF00 ; animation: loading 1.5s linear infinite;"></div>
+            <div style="width: 100%; height: 100%; background-color: #00FF00; animation: loading 1.5s linear infinite;"></div>
         </div>
     </div>
     <style>
@@ -49,18 +47,7 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
-# TEXT MODEL
-
-model_id = "prithivMLmods/FastThink-0.5B-Tiny"
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-)
-model.eval()
-
-# MULTIMODAL (OCR) MODELS
+# Qwen2-VL (for optional image inference)
 
 MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
@@ -102,7 +89,8 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
-# GEMMA3-4B MULTIMODAL MODEL
+
+# Gemma3 Model (default for text, image, & video inference)
 
 gemma3_model_id = "google/gemma-3-4b-it" # alternative: google/gemma-3-12b-it
 gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
@@ -111,6 +99,7 @@ gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
 gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
 
 # VIDEO PROCESSING HELPER
+
 def downsample_video(video_path):
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -144,15 +133,12 @@ def generate(
 ):
     text = input_dict["text"]
     files = input_dict.get("files", [])
-
     lower_text = text.lower().strip()
 
-    # GEMMA3-4B TEXT & MULTIMODAL (image) Branch
-    if lower_text.startswith("@gemma3"):
-        # Remove the gemma3 flag from the prompt.
-        prompt_clean = re.sub(r"@gemma3", "", text, flags=re.IGNORECASE).strip().strip('"')
+    # ----- Qwen2-VL branch (triggered with @qwen2-vl) -----
+    if lower_text.startswith("@qwen2-vl"):
+        prompt_clean = re.sub(r"@qwen2-vl", "", text, flags=re.IGNORECASE).strip().strip('"')
         if files:
-            # If image files are provided, load them.
             images = [load_image(f) for f in files]
             messages = [{
                 "role": "user",
@@ -161,18 +147,18 @@ def generate(
                     {"type": "text", "text": prompt_clean},
                 ]
             }]
+            prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
         else:
            messages = [
                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
                 {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
             ]
-        inputs = gemma3_processor.apply_chat_template(
-            messages, add_generation_prompt=True, tokenize=True,
-            return_dict=True, return_tensors="pt"
-        ).to(gemma3_model.device, dtype=torch.bfloat16)
-        streamer = TextIteratorStreamer(
-            gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
-        )
+            inputs = processor.apply_chat_template(
+                messages, add_generation_prompt=True, tokenize=True,
+                return_dict=True, return_tensors="pt"
+            ).to("cuda", dtype=torch.float16)
+        streamer = TextIteratorStreamer(processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = {
             **inputs,
             "streamer": streamer,
@@ -183,47 +169,106 @@ def generate(
             "top_k": top_k,
             "repetition_penalty": repetition_penalty,
         }
-        thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
         thread.start()
         buffer = ""
-        yield progress_bar_html("Processing with Gemma3")
+        yield progress_bar_html("Processing with Qwen2VL")
         for new_text in streamer:
             buffer += new_text
             time.sleep(0.01)
             yield buffer
         return
 
-    # GEMMA3-4B VIDEO Branch
-    if lower_text.startswith("@video-infer"):
-        # Remove the video flag from the prompt.
-        prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
-        if files:
-            # Assume the first file is a video.
+    # ----- Default branch: Gemma3 (for text, image, & video inference) -----
+    if files:
+        # Check if any provided file is a video based on extension.
+        video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm")
+        if any(str(f).lower().endswith(video_extensions) for f in files):
+            # Video inference branch.
+            prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
             video_path = files[0]
             frames = downsample_video(video_path)
             messages = [
                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
                 {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
             ]
-            # Append each frame as an image with a timestamp label.
+            # Append each frame (with its timestamp) to the conversation.
            for frame in frames:
                 image, timestamp = frame
                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
                 image.save(image_path)
                 messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
                 messages[1]["content"].append({"type": "image", "url": image_path})
+            inputs = gemma3_processor.apply_chat_template(
+                messages, add_generation_prompt=True, tokenize=True,
+                return_dict=True, return_tensors="pt"
+            ).to(gemma3_model.device, dtype=torch.bfloat16)
+            streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+            generation_kwargs = {
+                **inputs,
+                "streamer": streamer,
+                "max_new_tokens": max_new_tokens,
+                "do_sample": True,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": repetition_penalty,
+            }
+            thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+            thread.start()
+            buffer = ""
+            yield progress_bar_html("Processing video with Gemma3")
+            for new_text in streamer:
+                buffer += new_text
+                time.sleep(0.01)
+                yield buffer
+            return
         else:
-            messages = [
-                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
-            ]
+            # Image inference branch.
+            prompt_clean = re.sub(r"@gemma3", "", text, flags=re.IGNORECASE).strip().strip('"')
+            images = [load_image(f) for f in files]
+            messages = [{
+                "role": "user",
+                "content": [
+                    *[{"type": "image", "image": image} for image in images],
+                    {"type": "text", "text": prompt_clean},
+                ]
+            }]
+            inputs = gemma3_processor.apply_chat_template(
+                messages, tokenize=True, add_generation_prompt=True,
+                return_dict=True, return_tensors="pt"
+            ).to(gemma3_model.device, dtype=torch.bfloat16)
+            streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+            generation_kwargs = {
+                **inputs,
+                "streamer": streamer,
+                "max_new_tokens": max_new_tokens,
+                "do_sample": True,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": repetition_penalty,
+            }
+            thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+            thread.start()
+            buffer = ""
+            yield progress_bar_html("Processing with Gemma3")
+            for new_text in streamer:
+                buffer += new_text
+                time.sleep(0.01)
+                yield buffer
+            return
+    else:
+        # Text-only inference branch.
+        messages = [
+            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+            {"role": "user", "content": [{"type": "text", "text": text}]}
+        ]
     inputs = gemma3_processor.apply_chat_template(
         messages, add_generation_prompt=True, tokenize=True,
         return_dict=True, return_tensors="pt"
     ).to(gemma3_model.device, dtype=torch.bfloat16)
-    streamer = TextIteratorStreamer(
-        gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
-    )
+    streamer = TextIteratorStreamer(gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = {
         **inputs,
         "streamer": streamer,
@@ -236,70 +281,16 @@ def generate(
     }
     thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
     thread.start()
-    buffer = ""
-    yield progress_bar_html("Processing video with Gemma3")
-    for new_text in streamer:
-        buffer += new_text
-        time.sleep(0.01)
-        yield buffer
-    return
-
-    # Otherwise, handle text/chat generation.
-    conversation = clean_chat_history(chat_history)
-    conversation.append({"role": "user", "content": text})
-
-    if files:
-        images = [load_image(image) for image in files] if len(files) > 1 else [load_image(files[0])]
-        messages = [{
-            "role": "user",
-            "content": [
-                *[{"type": "image", "image": image} for image in images],
-                {"type": "text", "text": text},
-            ]
-        }]
-        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
-        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-        thread.start()
-
-        buffer = ""
-        yield progress_bar_html("Processing with Qwen2VL OCR")
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-    else:
-        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-        input_ids = input_ids.to(model.device)
-        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            "input_ids": input_ids,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "top_p": top_p,
-            "top_k": top_k,
-            "temperature": temperature,
-            "num_beams": 1,
-            "repetition_penalty": repetition_penalty,
-        }
-        t = Thread(target=model.generate, kwargs=generation_kwargs)
-        t.start()
-
     outputs = []
     for new_text in streamer:
         outputs.append(new_text)
         yield "".join(outputs)
-
     final_response = "".join(outputs)
     yield final_response
 
+
+# Gradio Interface
+
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
@@ -310,34 +301,25 @@ demo = gr.ChatInterface(
         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
     ],
     examples=[
-        [
-            {
-                "text": "@gemma3 Create a short story based on the images.",
-                "files": [
-                    "examples/1111.jpg",
-                    "examples/2222.jpg",
-                    "examples/3333.jpg",
-                ],
-            }
-        ],
-        [{"text": "@gemma3 Explain the Image", "files": ["examples/3.jpg"]}],
-        [{"text": "@video-infer Explain the content of the Advertisement", "files": ["examples/videoplayback.mp4"]}],
-        [{"text": "@gemma3 Which movie character is this?", "files": ["examples/9999.jpg"]}],
-        ["@gemma3 Explain Critical Temperature of Substance"],
-        [{"text": "@gemma3 Transcription of the letter", "files": ["examples/222.png"]}],
-        [{"text": "@video-infer Explain the content of the video in detail", "files": ["examples/breakfast.mp4"]}],
-        [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
-        [{"text": "@video-infer Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
-        [{"text": "@video-infer Summarize the events in this video", "files": ["examples/sky.mp4"]}],
-        [{"text": "@video-infer What is in the video ?", "files": ["examples/redlight.mp4"]}],
+        [{"text": "Create a short story based on the images.", "files": ["examples/1111.jpg", "examples/2222.jpg", "examples/3333.jpg"]}],
+        [{"text": "Explain the Image", "files": ["examples/3.jpg"]}],
+        [{"text": "Explain the content of the Advertisement", "files": ["examples/videoplayback.mp4"]}],
+        [{"text": "Which movie character is this?", "files": ["examples/9999.jpg"]}],
+        ["Explain Critical Temperature of Substance"],
+        [{"text": "Transcription of the letter", "files": ["examples/222.png"]}],
+        [{"text": "Explain the content of the video in detail", "files": ["examples/breakfast.mp4"]}],
+        [{"text": "Describe the video", "files": ["examples/Missing.mp4"]}],
+        [{"text": "Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
+        [{"text": "Summarize the events in this video", "files": ["examples/sky.mp4"]}],
+        [{"text": "What is in the video ?", "files": ["examples/redlight.mp4"]}],
         ["Python Program for Array Rotation"],
-        ["@gemma3 Explain Critical Temperature of Substance"]
+        ["Explain Critical Temperature of Substance"]
     ],
     cache_examples=False,
     type="messages",
-    description="# **Gemma 3 `@gemma3, @video-infer for video understanding`**",
+    description="# **Gemma 3 Multimodal**",
     fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="Tag--> @gemma3 for multimodal, @video-infer for video !"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="Tag with @qwen2-vl for Qwen2-VL inference if needed."),
     stop_btn="Stop Generation",
     multimodal=True,
 )
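
The net effect of this commit is a change in how generate() routes each chat turn: the @gemma3 and @video-infer tags are no longer required, Gemma3 becomes the default for text, image, and video (videos are detected by file extension), and an optional @qwen2-vl tag selects the Qwen2-VL model instead. The following is a minimal, self-contained sketch of that routing decision, not code from the Space itself; the function name route_request and the branch labels are illustrative only, while the tag regexes and the extension tuple are copied from the diff.

import re

VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mkv", ".webm")

def route_request(text, files):
    """Return (branch, cleaned_prompt) mirroring the dispatch in the updated generate()."""
    lower_text = text.lower().strip()
    if lower_text.startswith("@qwen2-vl"):
        # Strip the tag before handing the prompt to Qwen2-VL.
        prompt = re.sub(r"@qwen2-vl", "", text, flags=re.IGNORECASE).strip().strip('"')
        return "qwen2-vl", prompt
    if files and any(str(f).lower().endswith(VIDEO_EXTENSIONS) for f in files):
        return "gemma3-video", text.strip()
    if files:
        return "gemma3-image", text.strip()
    return "gemma3-text", text.strip()

# Example:
#   route_request("@qwen2-vl read the sign", ["photo.jpg"])  -> ("qwen2-vl", "read the sign")
#   route_request("Describe the clip", ["clip.mp4"])          -> ("gemma3-video", "Describe the clip")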
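The video branch depends on the downsample_video() helper, which the diff shows only partially (it opens the file with cv2.VideoCapture and reads CAP_PROP_FRAME_COUNT) and whose output the Gemma3 branch consumes as (image, timestamp) pairs. The sketch below shows one plausible implementation consistent with that usage; the choice of 10 evenly spaced frames, the FPS fallback, and the timestamp rounding are assumptions, not taken from the commit.

import cv2
import numpy as np
from PIL import Image

def downsample_video(video_path, num_frames=10):
    # Sample num_frames evenly spaced frames and return (PIL.Image, timestamp) pairs.
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0  # assumed fallback when FPS is unavailable
    indices = np.linspace(0, max(total_frames - 1, 0), num_frames).astype(int)
    frames = []
    for idx in indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ok, bgr = vidcap.read()
        if not ok:
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        timestamp = round(idx / fps, 2)
        frames.append((Image.fromarray(rgb), timestamp))
    vidcap.release()
    return frames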