prithivMLmods committed
Commit e780483 · verified · 1 Parent(s): 27e1a3a

Update app.py

Files changed (1):
  1. app.py +95 -19
app.py CHANGED
@@ -13,6 +13,7 @@ import torch
 import numpy as np
 from PIL import Image
 import edge_tts
+import cv2  # New import for video processing
 
 from transformers import (
     AutoModelForCausalLM,
@@ -113,7 +114,6 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
-
 # STABLE DIFFUSION IMAGE GENERATION MODELS
 
 if torch.cuda.is_available():
@@ -201,7 +201,6 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
-
 # GEMMA3-4B MULTIMODAL MODEL
 
 gemma3_model_id = "google/gemma-3-4b-it"
@@ -210,6 +209,25 @@ gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
 ).eval()
 gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
 
+# VIDEO PROCESSING HELPER
+def downsample_video(video_path):
+    vidcap = cv2.VideoCapture(video_path)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    frames = []
+    # Sample 10 evenly spaced frames.
+    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+    for i in frame_indices:
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            # Convert from BGR to RGB and then to PIL Image.
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+    vidcap.release()
+    return frames
 
 # MAIN GENERATION FUNCTION
 
@@ -228,7 +246,7 @@ def generate(
 
     lower_text = text.lower().strip()
 
-    # Image Generation Branch (Stable Diffusion models)
+    # IMAGE GENERATION BRANCH (Stable Diffusion models)
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
         lower_text.startswith("@turbov3")):
@@ -280,20 +298,76 @@
         yield gr.Image(image_path)
         return
 
-    # GEMMA3-4B Branch for Multimodal/Text Generation with Streaming
+    # GEMMA3-4B TEXT & MULTIMODAL (image) Branch
     if lower_text.startswith("@gemma3-4b"):
-        # Remove the gemma3 flag from the prompt.
-        prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
-        if files:
-            # If image files are provided, load them.
-            images = [load_image(f) for f in files]
-            messages = [{
-                "role": "user",
-                "content": [
-                    *[{"type": "image", "image": image} for image in images],
-                    {"type": "text", "text": prompt_clean},
+        # If it is video, let the dedicated branch handle it.
+        if lower_text.startswith("@gemma3-4b-video"):
+            pass  # video branch is handled below.
+        else:
+            # Remove the gemma3 flag from the prompt.
+            prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
+            if files:
+                # If image files are provided, load them.
+                images = [load_image(f) for f in files]
+                messages = [{
+                    "role": "user",
+                    "content": [
+                        *[{"type": "image", "image": image} for image in images],
+                        {"type": "text", "text": prompt_clean},
+                    ]
+                }]
+            else:
+                messages = [
+                    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                    {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
                 ]
-            }]
+            inputs = gemma3_processor.apply_chat_template(
+                messages, add_generation_prompt=True, tokenize=True,
+                return_dict=True, return_tensors="pt"
+            ).to(gemma3_model.device, dtype=torch.bfloat16)
+            streamer = TextIteratorStreamer(
+                gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+            )
+            generation_kwargs = {
+                **inputs,
+                "streamer": streamer,
+                "max_new_tokens": max_new_tokens,
+                "do_sample": True,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": repetition_penalty,
+            }
+            thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+            thread.start()
+            buffer = ""
+            yield progress_bar_html("Processing with Gemma3-4b")
+            for new_text in streamer:
+                buffer += new_text
+                time.sleep(0.01)
+                yield buffer
+            return
+
+    # NEW: GEMMA3-4B VIDEO Branch
+    if lower_text.startswith("@gemma3-4b-video"):
+        # Remove the video flag from the prompt.
+        prompt_clean = re.sub(r"@gemma3-4b-video", "", text, flags=re.IGNORECASE).strip().strip('"')
+        if files:
+            # Assume the first file is a video.
+            video_path = files[0]
+            frames = downsample_video(video_path)
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+            ]
+            # Append each frame as an image with a timestamp label.
+            for frame in frames:
+                image, timestamp = frame
+                # Save the frame image to a temporary unique filename.
+                image_path = f"video_frame_{uuid.uuid4().hex}.png"
+                image.save(image_path)
+                messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+                messages[1]["content"].append({"type": "image", "url": image_path})
         else:
             messages = [
                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
@@ -319,7 +393,7 @@ def generate(
         thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
         thread.start()
         buffer = ""
-        yield progress_bar_html("Processing with Gemma3-4b")
+        yield progress_bar_html("Processing with Gemma3-4b Video")
        for new_text in streamer:
             buffer += new_text
             time.sleep(0.01)
@@ -408,7 +482,9 @@ demo = gr.ChatInterface(
     ],
     examples=[
         [{"text": "@gemma3-4b Explain the Image", "files": ["examples/3.jpg"]}],
-        [{"text": "@gemma3-4b What's funny about this image ?", "files": ["examples/images.jpeg"]}],
+        [{"text": "@gemma3-4b-video Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
+        [{"text": "@gemma3-4b-video Summarize the events in this video", "files": ["examples/sky.mp4"]}],
+        [{"text": "@gemma3-4b-video What is in the video ?", "files": ["examples/redlight.mp4"]}],
         [{"text": "@gemma3-4b Where do the major drought happen?", "files": ["examples/111.png"]}],
         [{"text": "@gemma3-4b Transcription of the letter", "files": ["examples/222.png"]}],
         ['@lightningv5 Chocolate dripping from a donut'],
@@ -420,9 +496,9 @@ demo = gr.ChatInterface(
     ],
     cache_examples=False,
     type="messages",
-    description="# **Imagineo Chat `@gemma3-4b 'prompt..', @lightningv5, etc..`**",
+    description="# **Imagineo Chat `@gemma3-4b 'prompt..', @gemma3-4b-video, @lightningv5, etc..`**",
     fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="use the tags @gemma3-4b for multimodal, @lightningv5, @lightningv4 @turbov3 for image gen !"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="use the tags @gemma3-4b for multimodal, @gemma3-4b-video for video, @lightningv5, @lightningv4, @turbov3 for image gen !"),
     stop_btn="Stop Generation",
     multimodal=True,
 )
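
For reviewers who want to sanity-check the new frame-sampling flow outside the Space, here is a minimal standalone sketch. It mirrors the committed downsample_video() logic instead of importing app.py (importing app.py would also pull down the Stable Diffusion and Gemma-3 checkpoints) and only builds the per-frame message content without calling the model. The sample_frames name, the "sample.mp4" path, and the fps fallback are illustrative assumptions, not part of the commit.

import uuid
import cv2
import numpy as np
from PIL import Image

def sample_frames(video_path: str, num_frames: int = 10):
    """Return up to num_frames evenly spaced (PIL.Image, timestamp_in_seconds) pairs."""
    vidcap = cv2.VideoCapture(video_path)
    total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0  # guard against clips that report 0 fps
    frames = []
    for i in np.linspace(0, max(total - 1, 0), num_frames, dtype=int):
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ok, bgr = vidcap.read()
        if ok:
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            frames.append((Image.fromarray(rgb), round(float(i) / fps, 2)))
    vidcap.release()
    return frames

if __name__ == "__main__":
    # Build the same per-frame layout as the new @gemma3-4b-video branch:
    # a "Frame <timestamp>:" text label followed by a reference to the saved frame.
    user_content = [{"type": "text", "text": "Explain what is happening in this video"}]
    for image, timestamp in sample_frames("sample.mp4"):
        frame_path = f"video_frame_{uuid.uuid4().hex}.png"
        image.save(frame_path)
        user_content.append({"type": "text", "text": f"Frame {timestamp}:"})
        user_content.append({"type": "image", "url": frame_path})
    print(f"Built a user turn with {len(user_content)} content items")

As in the committed branch, each sampled frame is written to a uniquely named PNG and referenced next to a "Frame <timestamp>:" label; that list is what gets appended to the user turn before gemma3_processor.apply_chat_template() is called.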