Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -13,6 +13,7 @@ import torch
 import numpy as np
 from PIL import Image
 import edge_tts
+import cv2  # New import for video processing
 
 from transformers import (
     AutoModelForCausalLM,
@@ -113,7 +114,6 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
-
 # STABLE DIFFUSION IMAGE GENERATION MODELS
 
 if torch.cuda.is_available():
@@ -201,7 +201,6 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
-
 # GEMMA3-4B MULTIMODAL MODEL
 
 gemma3_model_id = "google/gemma-3-4b-it"
@@ -210,6 +209,25 @@ gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
 ).eval()
 gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
 
+# VIDEO PROCESSING HELPER
+def downsample_video(video_path):
+    vidcap = cv2.VideoCapture(video_path)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    frames = []
+    # Sample 10 evenly spaced frames.
+    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+    for i in frame_indices:
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            # Convert from BGR to RGB and then to PIL Image.
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+    vidcap.release()
+    return frames
 
 # MAIN GENERATION FUNCTION
 
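Note on the helper above: cv2 comes from the OpenCV package (opencv-python / opencv-python-headless), so it has to be available in the Space for the new import to resolve. The function returns up to ten (PIL.Image, timestamp) pairs, with the timestamp in seconds rounded to two decimals. A minimal sketch of how it can be exercised, assuming the helper is in scope and that one of the example clips referenced further down (examples/oreo.mp4) is present:

# Quick check of the sampling helper (downsample_video as defined above must be
# in scope; examples/oreo.mp4 is one of the clips listed in the examples below).
frames = downsample_video("examples/oreo.mp4")
print(f"sampled {len(frames)} frames")            # at most 10
for pil_image, timestamp in frames:
    print(pil_image.size, f"{timestamp}s")        # (width, height), seconds into the clip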
@@ -228,7 +246,7 @@ def generate(
 
     lower_text = text.lower().strip()
 
-    #
+    # IMAGE GENERATION BRANCH (Stable Diffusion models)
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
         lower_text.startswith("@turbov3")):
@@ -280,20 +298,76 @@ def generate(
         yield gr.Image(image_path)
         return
 
-    # GEMMA3-4B
+    # GEMMA3-4B TEXT & MULTIMODAL (image) Branch
     if lower_text.startswith("@gemma3-4b"):
-        #
-
-
-
-
-
-
-
-
-
+        # If it is video, let the dedicated branch handle it.
+        if lower_text.startswith("@gemma3-4b-video"):
+            pass  # video branch is handled below.
+        else:
+            # Remove the gemma3 flag from the prompt.
+            prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
+            if files:
+                # If image files are provided, load them.
+                images = [load_image(f) for f in files]
+                messages = [{
+                    "role": "user",
+                    "content": [
+                        *[{"type": "image", "image": image} for image in images],
+                        {"type": "text", "text": prompt_clean},
+                    ]
+                }]
+            else:
+                messages = [
+                    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                    {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
                 ]
-
+            inputs = gemma3_processor.apply_chat_template(
+                messages, add_generation_prompt=True, tokenize=True,
+                return_dict=True, return_tensors="pt"
+            ).to(gemma3_model.device, dtype=torch.bfloat16)
+            streamer = TextIteratorStreamer(
+                gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+            )
+            generation_kwargs = {
+                **inputs,
+                "streamer": streamer,
+                "max_new_tokens": max_new_tokens,
+                "do_sample": True,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": repetition_penalty,
+            }
+            thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+            thread.start()
+            buffer = ""
+            yield progress_bar_html("Processing with Gemma3-4b")
+            for new_text in streamer:
+                buffer += new_text
+                time.sleep(0.01)
+                yield buffer
+            return
+
+    # NEW: GEMMA3-4B VIDEO Branch
+    if lower_text.startswith("@gemma3-4b-video"):
+        # Remove the video flag from the prompt.
+        prompt_clean = re.sub(r"@gemma3-4b-video", "", text, flags=re.IGNORECASE).strip().strip('"')
+        if files:
+            # Assume the first file is a video.
+            video_path = files[0]
+            frames = downsample_video(video_path)
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+            ]
+            # Append each frame as an image with a timestamp label.
+            for frame in frames:
+                image, timestamp = frame
+                # Save the frame image to a temporary unique filename.
+                image_path = f"video_frame_{uuid.uuid4().hex}.png"
+                image.save(image_path)
+                messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+                messages[1]["content"].append({"type": "image", "url": image_path})
         else:
             messages = [
                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
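Two things worth flagging in the hunk above. First, the routing: a prompt tagged @gemma3-4b-video also matches the @gemma3-4b prefix, so the text/image branch checks for the video tag and deliberately passes, letting the dedicated branch below pick it up. Second, the video branch interleaves a timestamp label and a frame image per sampled frame into the user turn before calling apply_chat_template; a sketch of the resulting structure for a clip with two sampled frames (prompt text and file names are illustrative, the real frame files use uuid4 hex names):

# Illustrative shape of `messages` as built by the @gemma3-4b-video branch.
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [
        {"type": "text", "text": "Summarize the events in this video"},
        {"type": "text", "text": "Frame 0.0:"},
        {"type": "image", "url": "video_frame_0a1b2c.png"},
        {"type": "text", "text": "Frame 4.27:"},
        {"type": "image", "url": "video_frame_3d4e5f.png"},
    ]},
]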
@@ -319,7 +393,7 @@ def generate(
         thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
         thread.start()
         buffer = ""
-        yield progress_bar_html("Processing with Gemma3-4b")
+        yield progress_bar_html("Processing with Gemma3-4b Video")
         for new_text in streamer:
             buffer += new_text
             time.sleep(0.01)
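Both Gemma branches above strip their tag with the same re.sub(...).strip().strip('"') pattern before building the prompt. A worked example of what that yields for a quoted video query (the input string is illustrative):

import re

text = '@gemma3-4b-video "Summarize the events in this video"'
prompt_clean = re.sub(r"@gemma3-4b-video", "", text, flags=re.IGNORECASE).strip().strip('"')
print(prompt_clean)   # Summarize the events in this video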
@@ -408,7 +482,9 @@ demo = gr.ChatInterface(
     ],
     examples=[
         [{"text": "@gemma3-4b Explain the Image", "files": ["examples/3.jpg"]}],
-        [{"text": "@gemma3-4b
+        [{"text": "@gemma3-4b-video Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
+        [{"text": "@gemma3-4b-video Summarize the events in this video", "files": ["examples/sky.mp4"]}],
+        [{"text": "@gemma3-4b-video What is in the video ?", "files": ["examples/redlight.mp4"]}],
         [{"text": "@gemma3-4b Where do the major drought happen?", "files": ["examples/111.png"]}],
         [{"text": "@gemma3-4b Transcription of the letter", "files": ["examples/222.png"]}],
         ['@lightningv5 Chocolate dripping from a donut'],
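The three new example rows reference examples/oreo.mp4, examples/sky.mp4, and examples/redlight.mp4; those clips need to be committed to the Space's examples/ folder alongside the existing images for the rows to work. An optional sanity check using only the paths listed above:

# Optional check that the new example clips are present in the repo.
from pathlib import Path

for clip in ["examples/oreo.mp4", "examples/sky.mp4", "examples/redlight.mp4"]:
    assert Path(clip).exists(), f"missing example file: {clip}"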
@@ -420,9 +496,9 @@ demo = gr.ChatInterface(
     ],
     cache_examples=False,
     type="messages",
-    description="# **Imagineo Chat `@gemma3-4b 'prompt..', @lightningv5, etc..`**",
+    description="# **Imagineo Chat `@gemma3-4b 'prompt..', @gemma3-4b-video, @lightningv5, etc..`**",
     fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="use the tags
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="use the tags @gemma3-4b for multimodal, @gemma3-4b-video for video, @lightningv5, @lightningv4, @turbov3 for image gen !"),
     stop_btn="Stop Generation",
     multimodal=True,
 )