prithivMLmods committed
Commit 4bfb84e · verified · 1 Parent(s): 46ce972

Update app.py

Files changed (1): app.py (+59 -2)
app.py CHANGED
@@ -20,6 +20,7 @@ from transformers import (
     TextIteratorStreamer,
     Qwen2VLForConditionalGeneration,
     AutoProcessor,
+    Gemma3ForConditionalGeneration,  # New import for Gemma3-4B
 )
 from transformers.image_utils import load_image
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
@@ -208,6 +209,15 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
+# -----------------------
+# GEMMA3-4B MULTIMODAL MODEL
+# -----------------------
+gemma3_model_id = "google/gemma-3-4b-it"
+gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
+    gemma3_model_id, device_map="auto"
+).eval()
+gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
+
 # -----------------------
 # MAIN GENERATION FUNCTION
 # -----------------------
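Reviewer note: for reference, a minimal standalone sketch of the load-and-generate pattern this new block relies on. It assumes a transformers release with Gemma 3 support (4.50+) and access to the gated google/gemma-3-4b-it checkpoint; the prompt is illustrative, not from the app.

# Minimal sketch: load Gemma 3 and run one non-streaming generation.
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
).eval()
processor = AutoProcessor.from_pretrained(model_id)

messages = [{"role": "user", "content": [{"type": "text", "text": "Explain bfloat16 in one sentence."}]}]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device)

with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=64)
# Decode only the newly generated tokens, not the echoed prompt.
print(processor.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))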
@@ -225,7 +235,8 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
-    # If the prompt is an image generation command (using model flags)
+
+    # Image Generation Branch (Stable Diffusion models)
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
         lower_text.startswith("@turbov3")):
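Reviewer note: the branches share one flag convention, lower-case the input, test the leading @tag, then strip it case-insensitively from the raw text. A self-contained sketch of that convention (route is a hypothetical helper, not part of the app):

import re

TAGS = ("@lightningv5", "@lightningv4", "@turbov3", "@gemma3-4b")

def route(text: str):
    # Dispatch on a leading @tag and return (tag, cleaned prompt).
    lower = text.lower().strip()
    for tag in TAGS:
        if lower.startswith(tag):
            return tag, re.sub(tag, "", text, flags=re.IGNORECASE).strip().strip('"')
    return None, text.strip()

print(route('@gemma3-4b "describe the attached photo"'))
# -> ('@gemma3-4b', 'describe the attached photo')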
@@ -277,6 +288,52 @@ def generate(
         yield gr.Image(image_path)
         return
 
+    # GEMMA3-4B Branch for Multimodal/Text Generation with Streaming
+    if lower_text.startswith("@gemma3-4b"):
+        # Remove the gemma3 flag from the prompt.
+        prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
+        if files:
+            # If image files are provided, load them.
+            images = [load_image(f) for f in files]
+            messages = [{
+                "role": "user",
+                "content": [
+                    *[{"type": "image", "image": image} for image in images],
+                    {"type": "text", "text": prompt_clean},
+                ]
+            }]
+        else:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+            ]
+        inputs = gemma3_processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=True,
+            return_dict=True, return_tensors="pt"
+        ).to(gemma3_model.device, dtype=torch.bfloat16)
+        streamer = TextIteratorStreamer(
+            gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+        )
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+        thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing with Gemma3-4b")
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
+        return
+
     # Otherwise, handle text/chat (and TTS) generation.
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
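Reviewer note: the streaming pattern added here, model.generate on a worker thread while the caller iterates TextIteratorStreamer as decoded chunks arrive, in isolation. A small ungated model stands in for Gemma 3 so the sketch runs anywhere:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs on a background thread while the
# streamer yields decoded text chunks to the consumer below.
thread = Thread(target=model.generate,
                kwargs={**inputs, "streamer": streamer, "max_new_tokens": 40})
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text  # the app re-yields this growing buffer to the Gradio UI
print(buffer)
thread.join()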
@@ -391,7 +448,7 @@ demo = gr.ChatInterface(
     description=DESCRIPTION,
     css=css,
     fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="use the tags @lightningv5 @lightningv4 @turbov3 for image gen !"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="use the tags @lightningv5 @lightningv4 @turbov3 or @gemma3-4b for multimodal gen !"),
     stop_btn="Stop Generation",
     multimodal=True,
 )
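Reviewer note: for context, a minimal reproduction of the interface wiring this hunk touches, assuming gradio 4.x; echo is a stand-in for the app's generate().

import gradio as gr

def echo(input_dict, history):
    # With multimodal=True the handler receives {"text": ..., "files": [...]}.
    text = input_dict.get("text", "")
    files = input_dict.get("files", [])
    return f"text={text!r}, files={len(files)}"

demo = gr.ChatInterface(
    fn=echo,
    multimodal=True,
    textbox=gr.MultimodalTextbox(
        label="Query Input", file_types=["image"], file_count="multiple",
        placeholder="use the tags @lightningv5 @lightningv4 @turbov3 or @gemma3-4b",
    ),
    stop_btn="Stop Generation",
)

if __name__ == "__main__":
    demo.launch()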
 