prithivMLmods committed · verified
Commit 51559fa · 1 Parent(s): dc7620f

Update app.py

Files changed (1): app.py +42 -43
app.py CHANGED
@@ -84,6 +84,15 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
+# -----------------------
+# GEMMA3-4B MODEL SETUP (NEW FEATURE)
+# -----------------------
+gemma3_model_id = "google/gemma-3-4b-it"
+gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
+    gemma3_model_id, device_map="auto"
+).eval()
+gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
+
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
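The relocated setup block can be exercised on its own. A minimal sketch, assuming a transformers build with Gemma 3 support (roughly v4.50+), torch, and enough accelerator memory; the prompt string is illustrative, not taken from app.py:

# Standalone sketch of the Gemma3-4B setup this hunk moves above text_to_speech.
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

gemma3_model_id = "google/gemma-3-4b-it"
gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
    gemma3_model_id, device_map="auto"  # place/shard layers automatically
).eval()
gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)

# One non-streaming generation, mirroring the chat-template call used later in app.py.
messages = [{"role": "user", "content": [{"type": "text", "text": "Say hello."}]}]
inputs = gemma3_processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(gemma3_model.device, dtype=torch.bfloat16)
with torch.inference_mode():
    output_ids = gemma3_model.generate(**inputs, max_new_tokens=64)
# Strip the prompt tokens before decoding, so only the reply is printed.
reply = gemma3_processor.decode(
    output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
)
print(reply)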
@@ -209,15 +218,6 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
-# -----------------------
-# GEMMA3-4B MULTIMODAL MODEL
-# -----------------------
-gemma3_model_id = "google/gemma-3-4b-it"
-gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
-    gemma3_model_id, device_map="auto"
-).eval()
-gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
-
 # -----------------------
 # MAIN GENERATION FUNCTION
 # -----------------------
@@ -235,8 +235,8 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
-
-    # Image Generation Branch (Stable Diffusion models)
+
+    # 1. IMAGE GENERATION COMMANDS (@lightningv5, @lightningv4, @turbov3)
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
         lower_text.startswith("@turbov3")):
@@ -288,53 +288,52 @@ def generate(
         yield gr.Image(image_path)
         return
 
-    # GEMMA3-4B Branch for Multimodal/Text Generation with Streaming
+    # 2. GEMMA3-4B MULTIMODAL GENERATION (NEW FEATURE)
     if lower_text.startswith("@gemma3-4b"):
-        # Remove the gemma3 flag from the prompt.
+        # Remove the flag from the text prompt.
         prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
+        # Build messages: include a system message and user message.
+        messages = []
+        messages.append({
+            "role": "system",
+            "content": [{"type": "text", "text": "You are a helpful assistant."}]
+        })
+        user_content = []
         if files:
-            # If image files are provided, load them.
-            images = [load_image(f) for f in files]
-            messages = [{
-                "role": "user",
-                "content": [
-                    *[{"type": "image", "image": image} for image in images],
-                    {"type": "text", "text": prompt_clean},
-                ]
-            }]
-        else:
-            messages = [
-                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
-            ]
+            # If images are uploaded, load them and add them to the message.
+            images = [load_image(image) for image in files]
+            for img in images:
+                user_content.append({"type": "image", "image": img})
+        # Add the text part.
+        user_content.append({"type": "text", "text": prompt_clean})
+        messages.append({
+            "role": "user",
+            "content": user_content
+        })
+
+        # Prepare inputs using Gemma3's processor.
         inputs = gemma3_processor.apply_chat_template(
             messages, add_generation_prompt=True, tokenize=True,
             return_dict=True, return_tensors="pt"
         ).to(gemma3_model.device, dtype=torch.bfloat16)
-        streamer = TextIteratorStreamer(
-            gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
-        )
-        generation_kwargs = {
-            **inputs,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-        }
+
+        input_len = inputs["input_ids"].shape[-1]
+        # Create a text streamer for incremental generation.
+        streamer = TextIteratorStreamer(gemma3_processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
         thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
         thread.start()
+
         buffer = ""
-        yield progress_bar_html("Processing with Gemma3-4b")
+        yield progress_bar_html("Processing with Gemma3-4B")
         for new_text in streamer:
             buffer += new_text
-            time.sleep(0.01)
             yield buffer
+        final_response = buffer
+        yield final_response
         return
 
-    # Otherwise, handle text/chat (and TTS) generation.
+    # 3. TEXT & TTS GENERATION
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
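Both sides of this hunk rely on the same streaming pattern: model.generate blocks, so it runs in a worker Thread while the caller drains a TextIteratorStreamer on the main thread. A minimal sketch of the pattern, with gpt2 standing in for Gemma 3 purely to keep the example small:

# Sketch of the Thread + TextIteratorStreamer pattern used in this hunk.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": 32}

# generate() runs in the worker thread; the streamer yields decoded
# text chunks to the consuming loop as tokens are produced.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text
    print(buffer)  # in app.py this is `yield buffer` into the Gradio UI
thread.join()

Note the rewritten branch passes gemma3_processor rather than gemma3_processor.tokenizer to TextIteratorStreamer; this appears to work because transformers processors generally forward decode to their underlying tokenizer.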
 
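For completeness, the text_to_speech helper shown as context in the first hunk is a small edge_tts coroutine. A hedged sketch of driving it from synchronous code; the voice name is one example of many (see `edge-tts --list-voices`), and the return statement is added here for the caller's benefit, so app.py may differ:

# Sketch: calling the async edge-tts helper from synchronous code.
import asyncio
import edge_tts

async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file  # added in this sketch; the snippet above is truncated

print(asyncio.run(text_to_speech("Hello from the demo.", "en-US-JennyNeural")))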