Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -20,6 +20,7 @@ from transformers import (
|
|
20 |
TextIteratorStreamer,
|
21 |
Qwen2VLForConditionalGeneration,
|
22 |
AutoProcessor,
|
|
|
23 |
)
|
24 |
from transformers.image_utils import load_image
|
25 |
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
|
@@ -208,6 +209,15 @@ def save_image(img: Image.Image) -> str:
|
|
208 |
img.save(unique_name)
|
209 |
return unique_name
|
210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
# -----------------------
|
212 |
# MAIN GENERATION FUNCTION
|
213 |
# -----------------------
|
@@ -225,7 +235,8 @@ def generate(
|
|
225 |
files = input_dict.get("files", [])
|
226 |
|
227 |
lower_text = text.lower().strip()
|
228 |
-
|
|
|
229 |
if (lower_text.startswith("@lightningv5") or
|
230 |
lower_text.startswith("@lightningv4") or
|
231 |
lower_text.startswith("@turbov3")):
|
@@ -277,6 +288,52 @@ def generate(
|
|
277 |
yield gr.Image(image_path)
|
278 |
return
|
279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
# Otherwise, handle text/chat (and TTS) generation.
|
281 |
tts_prefix = "@tts"
|
282 |
is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
|
@@ -391,7 +448,7 @@ demo = gr.ChatInterface(
|
|
391 |
description=DESCRIPTION,
|
392 |
css=css,
|
393 |
fill_height=True,
|
394 |
-
textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="use the tags @lightningv5 @lightningv4 @turbov3 for
|
395 |
stop_btn="Stop Generation",
|
396 |
multimodal=True,
|
397 |
)
|
|
|
20 |
TextIteratorStreamer,
|
21 |
Qwen2VLForConditionalGeneration,
|
22 |
AutoProcessor,
|
23 |
+
Gemma3ForConditionalGeneration, # New import for Gemma3-4B
|
24 |
)
|
25 |
from transformers.image_utils import load_image
|
26 |
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
|
|
|
209 |
img.save(unique_name)
|
210 |
return unique_name
|
211 |
|
212 |
+
# -----------------------
|
213 |
+
# GEMMA3-4B MULTIMODAL MODEL
|
214 |
+
# -----------------------
|
215 |
+
gemma3_model_id = "google/gemma-3-4b-it"
|
216 |
+
gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
|
217 |
+
gemma3_model_id, device_map="auto"
|
218 |
+
).eval()
|
219 |
+
gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
|
220 |
+
|
221 |
# -----------------------
|
222 |
# MAIN GENERATION FUNCTION
|
223 |
# -----------------------
|
|
|
235 |
files = input_dict.get("files", [])
|
236 |
|
237 |
lower_text = text.lower().strip()
|
238 |
+
|
239 |
+
# Image Generation Branch (Stable Diffusion models)
|
240 |
if (lower_text.startswith("@lightningv5") or
|
241 |
lower_text.startswith("@lightningv4") or
|
242 |
lower_text.startswith("@turbov3")):
|
|
|
288 |
yield gr.Image(image_path)
|
289 |
return
|
290 |
|
291 |
+
# GEMMA3-4B Branch for Multimodal/Text Generation with Streaming
|
292 |
+
if lower_text.startswith("@gemma3-4b"):
|
293 |
+
# Remove the gemma3 flag from the prompt.
|
294 |
+
prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
|
295 |
+
if files:
|
296 |
+
# If image files are provided, load them.
|
297 |
+
images = [load_image(f) for f in files]
|
298 |
+
messages = [{
|
299 |
+
"role": "user",
|
300 |
+
"content": [
|
301 |
+
*[{"type": "image", "image": image} for image in images],
|
302 |
+
{"type": "text", "text": prompt_clean},
|
303 |
+
]
|
304 |
+
}]
|
305 |
+
else:
|
306 |
+
messages = [
|
307 |
+
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
|
308 |
+
{"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
|
309 |
+
]
|
310 |
+
inputs = gemma3_processor.apply_chat_template(
|
311 |
+
messages, add_generation_prompt=True, tokenize=True,
|
312 |
+
return_dict=True, return_tensors="pt"
|
313 |
+
).to(gemma3_model.device, dtype=torch.bfloat16)
|
314 |
+
streamer = TextIteratorStreamer(
|
315 |
+
gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
|
316 |
+
)
|
317 |
+
generation_kwargs = {
|
318 |
+
**inputs,
|
319 |
+
"streamer": streamer,
|
320 |
+
"max_new_tokens": max_new_tokens,
|
321 |
+
"do_sample": True,
|
322 |
+
"temperature": temperature,
|
323 |
+
"top_p": top_p,
|
324 |
+
"top_k": top_k,
|
325 |
+
"repetition_penalty": repetition_penalty,
|
326 |
+
}
|
327 |
+
thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
|
328 |
+
thread.start()
|
329 |
+
buffer = ""
|
330 |
+
yield progress_bar_html("Processing with Gemma3-4b")
|
331 |
+
for new_text in streamer:
|
332 |
+
buffer += new_text
|
333 |
+
time.sleep(0.01)
|
334 |
+
yield buffer
|
335 |
+
return
|
336 |
+
|
337 |
# Otherwise, handle text/chat (and TTS) generation.
|
338 |
tts_prefix = "@tts"
|
339 |
is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
|
|
|
448 |
description=DESCRIPTION,
|
449 |
css=css,
|
450 |
fill_height=True,
|
451 |
+
textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="use the tags @lightningv5 @lightningv4 @turbov3 or @gemma3-4b for multimodal gen !"),
|
452 |
stop_btn="Stop Generation",
|
453 |
multimodal=True,
|
454 |
)
|