prithivMLmods committed
Commit 7508a03 · verified · 1 Parent(s): 05081bc

Update app.py

Files changed (1):
  1. app.py +56 -3
app.py CHANGED
@@ -37,6 +37,7 @@ from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
 from diffusers.utils import export_to_ply
 
 os.system('pip install backoff')
+
 # Global constants and helper functions
 
 MAX_SEED = np.iinfo(np.int32).max
 
@@ -323,6 +324,14 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
+# ------------------------------------------------------------------------------
+# New Gemma3-4b Multimodal Feature (Image & Text)
+# ------------------------------------------------------------------------------
+from transformers import AutoProcessor as Gemma3AutoProcessor, Gemma3ForConditionalGeneration
+gemma3_model_id = "google/gemma-3-4b-it"
+gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(gemma3_model_id, device_map="auto").eval()
+gemma3_processor = Gemma3AutoProcessor.from_pretrained(gemma3_model_id)
+
 # Asynchronous text-to-speech
 
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
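
A note on the loading added above: `from_pretrained` is called without a `torch_dtype`, so the weights stay in float32, while the generation branch added further down casts the processed inputs to bfloat16. If that mismatch raises a dtype error on `pixel_values`, loading the checkpoint in bfloat16 up front is the usual fix. A minimal sketch of that alternative (not part of this commit; identifiers mirror the ones above):

```python
# Sketch only: load the Gemma3 weights in bfloat16 so they match the dtype
# the new generate() branch casts its inputs to. Not part of this commit.
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

gemma3_model_id = "google/gemma-3-4b-it"
gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
    gemma3_model_id,
    torch_dtype=torch.bfloat16,  # match the downstream input cast
    device_map="auto",
).eval()
gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
```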
 
@@ -464,7 +473,7 @@ def detect_objects(image: np.ndarray):
 
     return Image.fromarray(annotated_image)
 
-# Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, @yolo, and now @phi4 commands
+# Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, @yolo, @phi4, and now @gemma3-4b commands
 
 @spaces.GPU
 def generate(
 
@@ -484,7 +493,8 @@ def generate(
     - "@web": triggers a web search or webpage visit.
     - "@rAgent": initiates a reasoning chain using Llama mode.
     - "@yolo": triggers object detection using YOLO.
-    - **"@phi4": triggers multimodal (image/audio) processing using the Phi-4 model.**
+    - "@phi4": triggers multimodal (image/audio) processing using the Phi-4 model.
+    - **"@gemma3-4b": triggers multimodal (image/text) processing using the Gemma3-4b model.**
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
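
For readers skimming the docstring above: routing is plain case-insensitive prefix matching on the input text, as the `@gemma3-4b` branch added in the next hunk shows. A standalone sketch of that dispatch (the `route` helper is hypothetical and not in app.py):

```python
# Hypothetical helper illustrating the prefix dispatch generate() performs.
def route(text: str) -> str:
    prefixes = ["@tts1", "@tts2", "@image", "@3d", "@web",
                "@ragent", "@yolo", "@phi4", "@gemma3-4b"]
    lowered = text.strip().lower()
    for prefix in prefixes:
        if lowered.startswith(prefix):
            return prefix
    return "default"

assert route("@gemma3-4b Describe this image.") == "@gemma3-4b"
```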
 
@@ -644,6 +654,48 @@ def generate(
             yield buffer
         return
 
+    # --- Gemma3-4b Multimodal branch (Image/Text) with Streaming ---
+    if text.strip().lower().startswith("@gemma3-4b"):
+        question = text[len("@gemma3-4b"):].strip()
+        messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant."}]
+            },
+            {
+                "role": "user",
+                "content": []
+            }
+        ]
+        if files:
+            try:
+                # If file is already a PIL Image, use it; otherwise try opening it.
+                if isinstance(files[0], Image.Image):
+                    image = files[0]
+                else:
+                    image = Image.open(files[0])
+                messages[1]["content"].append({"type": "image", "image": image})
+            except Exception as e:
+                yield f"Error processing image: {str(e)}"
+                return
+        messages[1]["content"].append({"type": "text", "text": question})
+        inputs = gemma3_processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=True,
+            return_dict=True, return_tensors="pt"
+        ).to(gemma3_model.device, dtype=torch.bfloat16)
+        input_len = inputs["input_ids"].shape[-1]
+        streamer = TextIteratorStreamer(gemma3_processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": False}
+        thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing Gemma3-4b Multimodal")
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
+        return
+
     # --- Text and TTS branch ---
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
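
Two small observations on the branch above: `input_len` is assigned but never read (the streamer's `skip_prompt=True` already drops the prompt tokens), and `TextIteratorStreamer` is documented to take a tokenizer, so handing it the processor relies on the processor forwarding `decode` to its tokenizer. A hedged variant of the streaming setup that sidesteps both points (assumes the same `inputs`, model, and processor objects as the commit):

```python
# Sketch: pass the processor's underlying tokenizer, which is the documented
# TextIteratorStreamer input; the unused input_len is simply dropped.
from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(
    gemma3_processor.tokenizer, skip_prompt=True, skip_special_tokens=True
)
generation_kwargs = {**inputs, "streamer": streamer,
                     "max_new_tokens": max_new_tokens, "do_sample": False}
Thread(target=gemma3_model.generate, kwargs=generation_kwargs).start()
```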
 
@@ -744,6 +796,7 @@ demo = gr.ChatInterface(
         ["@rAgent Explain how a binary search algorithm works."],
         ["@web Is Grok-3 Beats DeepSeek-R1 at Reasoning ?"],
         ["@tts1 Explain Tower of Hanoi"],
+        ["@gemma3-4b Describe this image in detail."]
     ],
     cache_examples=False,
     type="messages",
 
@@ -754,7 +807,7 @@ demo = gr.ChatInterface(
         label="Query Input",
         file_types=["image", "audio"],
         file_count="multiple",
-        placeholder="‎ @tts1, @tts2, @image, @3d, @phi4 [image, audio], @rAgent, @web, @yolo, default [plain text]"
+        placeholder="‎ @tts1, @tts2, @image, @3d, @phi4 [image, audio], @gemma3-4b, @rAgent, @web, @yolo, default [plain text]"
     ),
     stop_btn="Stop Generation",
     multimodal=True,
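
For context on how the new command reaches `generate()`: with `multimodal=True`, Gradio's `gr.MultimodalTextbox` delivers each turn as a dict with `text` and `files` keys, which is exactly the `input_dict` the function unpacks. An illustrative value (the file path is hypothetical):

```python
# Hypothetical payload the chat interface hands to generate():
input_dict = {
    "text": "@gemma3-4b Describe this image in detail.",
    "files": ["/tmp/gradio/uploaded_image.png"],  # hypothetical upload path
}
```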