prithivMLmods commited on
Commit
6565507
·
verified ·
1 Parent(s): 2d6715b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -7
app.py CHANGED
@@ -253,12 +253,15 @@ phi4_model = AutoModelForCausalLM.from_pretrained(
253
  _attn_implementation="eager",
254
  )
255
 
256
- # ------------------------------------------------------------------------------
257
- # Gradio UI configuration
258
- # ------------------------------------------------------------------------------
 
 
259
 
260
  DESCRIPTION = """
261
- # Agent Dino 🌍 """
 
262
 
263
  css = '''
264
  h1 {
@@ -447,7 +450,7 @@ def detect_objects(image: np.ndarray):
447
 
448
  return Image.fromarray(annotated_image)
449
 
450
- # Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, @yolo, and now @phi4 commands
451
 
452
  @spaces.GPU
453
  def generate(
@@ -467,7 +470,8 @@ def generate(
467
  - "@web": triggers a web search or webpage visit.
468
  - "@rAgent": initiates a reasoning chain using Llama mode.
469
  - "@yolo": triggers object detection using YOLO.
470
- - **"@phi4": triggers multimodal (image/audio) processing using the Phi-4 model.**
 
471
  """
472
  text = input_dict["text"]
473
  files = input_dict.get("files", [])
@@ -626,6 +630,37 @@ def generate(
626
  yield buffer
627
  return
628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
629
  # --- Text and TTS branch ---
630
  tts_prefix = "@tts"
631
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
@@ -724,6 +759,7 @@ demo = gr.ChatInterface(
724
  ["@rAgent Explain how a binary search algorithm works."],
725
  ["@web Is Grok-3 Beats DeepSeek-R1 at Reasoning ?"],
726
  ["@tts1 Explain Tower of Hanoi"],
 
727
  ],
728
  cache_examples=False,
729
  type="messages",
@@ -734,7 +770,7 @@ demo = gr.ChatInterface(
734
  label="Query Input",
735
  file_types=["image", "audio"],
736
  file_count="multiple",
737
- placeholder="@tts1, @tts2, @image, @3d, @phi4 [image, audio], @rAgent, @web, @yolo, default [plain text]"
738
  ),
739
  stop_btn="Stop Generation",
740
  multimodal=True,
 
253
  _attn_implementation="eager",
254
  )
255
 
256
+ grpo_model_name = "prithivMLmods/SmolLM2-360M-Grpo-r999"
257
+ grpo_device = "cuda" if torch.cuda.is_available() else "cpu"
258
+ grpo_tokenizer = AutoTokenizer.from_pretrained(grpo_model_name)
259
+ grpo_model = AutoModelForCausalLM.from_pretrained(grpo_model_name).to(grpo_device)
260
+
261
 
262
  DESCRIPTION = """
263
+ # Agent Dino 🌍 
264
+ """
265
 
266
  css = '''
267
  h1 {
 
450
 
451
  return Image.fromarray(annotated_image)
452
 
453
+ # Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, @yolo, @phi4, and now @grpo commands
454
 
455
  @spaces.GPU
456
  def generate(
 
470
  - "@web": triggers a web search or webpage visit.
471
  - "@rAgent": initiates a reasoning chain using Llama mode.
472
  - "@yolo": triggers object detection using YOLO.
473
+ - "@phi4": triggers multimodal (image/audio) processing using the Phi-4 model.
474
+ - **"@grpo": triggers text generation using the GRPO model with a text streamer.**
475
  """
476
  text = input_dict["text"]
477
  files = input_dict.get("files", [])
 
630
  yield buffer
631
  return
632
 
633
+ # --- GRPO Text Generation branch ---
634
+ if text.strip().lower().startswith("@grpo"):
635
+ prompt = text[len("@grpo"):].strip()
636
+ yield "📝 Generating text with @grpo..."
637
+ messages = [
638
+ {"role": "system", "content": "Please respond in this specific format ONLY:\n<thinking>\n input your reasoning behind your answer in between these reasoning tags.\n</thinking>\n<answer>\nyour answer in between these answer tags.\n</answer>\n"},
639
+ {"role": "user", "content": prompt}
640
+ ]
641
+ # Use the GRPO tokenizer's chat template if available, otherwise simply join the messages.
642
+ input_text = grpo_tokenizer.apply_chat_template(messages, tokenize=False) if hasattr(grpo_tokenizer, "apply_chat_template") else "\n".join([msg["content"] for msg in messages])
643
+ inputs = grpo_tokenizer.encode(input_text, return_tensors="pt").to(grpo_model.device)
644
+ streamer = TextIteratorStreamer(grpo_tokenizer, skip_prompt=True, skip_special_tokens=True)
645
+ generation_kwargs = {
646
+ "input_ids": inputs,
647
+ "max_new_tokens": 100,
648
+ "temperature": 0.2,
649
+ "top_p": 0.9,
650
+ "do_sample": True,
651
+ "use_cache": False,
652
+ "streamer": streamer,
653
+ }
654
+ thread = Thread(target=grpo_model.generate, kwargs=generation_kwargs)
655
+ thread.start()
656
+ buffer = ""
657
+ yield "🤔 Thinking..."
658
+ for new_text in streamer:
659
+ buffer += new_text
660
+ time.sleep(0.01)
661
+ yield buffer
662
+ return
663
+
664
  # --- Text and TTS branch ---
665
  tts_prefix = "@tts"
666
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
 
759
  ["@rAgent Explain how a binary search algorithm works."],
760
  ["@web Is Grok-3 Beats DeepSeek-R1 at Reasoning ?"],
761
  ["@tts1 Explain Tower of Hanoi"],
762
+ ["@grpo If there are 12 cookies in a dozen and you have 5 dozen, how many cookies do you have?"],
763
  ],
764
  cache_examples=False,
765
  type="messages",
 
770
  label="Query Input",
771
  file_types=["image", "audio"],
772
  file_count="multiple",
773
+ placeholder="@tts1, @tts2, @image, @3d, @phi4 [image, audio], @rAgent, @web, @yolo, @grpo, default [plain text]"
774
  ),
775
  stop_btn="Stop Generation",
776
  multimodal=True,