Update app.py
--- a/app.py
+++ b/app.py
@@ -84,6 +84,15 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
+# -----------------------
+# GEMMA3-4B MODEL SETUP (NEW FEATURE)
+# -----------------------
+gemma3_model_id = "google/gemma-3-4b-it"
+gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
+    gemma3_model_id, device_map="auto"
+).eval()
+gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
+
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
@@ -209,15 +218,6 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
-# -----------------------
-# GEMMA3-4B MULTIMODAL MODEL
-# -----------------------
-gemma3_model_id = "google/gemma-3-4b-it"
-gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
-    gemma3_model_id, device_map="auto"
-).eval()
-gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
-
 # -----------------------
 # MAIN GENERATION FUNCTION
 # -----------------------
@@ -235,8 +235,8 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
-
-    # …
+
+    # 1. IMAGE GENERATION COMMANDS (@lightningv5, @lightningv4, @turbov3)
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
         lower_text.startswith("@turbov3")):
@@ -288,53 +288,52 @@ def generate(
         yield gr.Image(image_path)
         return
 
-    # GEMMA3-4B
+    # 2. GEMMA3-4B MULTIMODAL GENERATION (NEW FEATURE)
     if lower_text.startswith("@gemma3-4b"):
-        # Remove the …
+        # Remove the flag from the text prompt.
         prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
+        # Build messages: include a system message and user message.
+        messages = []
+        messages.append({
+            "role": "system",
+            "content": [{"type": "text", "text": "You are a helpful assistant."}]
+        })
+        user_content = []
         if files:
-            # If …
-            images = [load_image(…
-            …
-            {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
-        ]
+            # If images are uploaded, load them and add them to the message.
+            images = [load_image(image) for image in files]
+            for img in images:
+                user_content.append({"type": "image", "image": img})
+        # Add the text part.
+        user_content.append({"type": "text", "text": prompt_clean})
+        messages.append({
+            "role": "user",
+            "content": user_content
+        })
+
+        # Prepare inputs using Gemma3's processor.
         inputs = gemma3_processor.apply_chat_template(
             messages, add_generation_prompt=True, tokenize=True,
             return_dict=True, return_tensors="pt"
         ).to(gemma3_model.device, dtype=torch.bfloat16)
-        …
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-        }
+
+        input_len = inputs["input_ids"].shape[-1]
+        # Create a text streamer for incremental generation.
+        streamer = TextIteratorStreamer(gemma3_processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
         thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
         thread.start()
+
         buffer = ""
-        yield progress_bar_html("Processing with Gemma3-…
+        yield progress_bar_html("Processing with Gemma3-4B")
         for new_text in streamer:
            buffer += new_text
-            time.sleep(0.01)
             yield buffer
+        final_response = buffer
+        yield final_response
         return
 
-    # …
+    # 3. TEXT & TTS GENERATION
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
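The heart of the new @gemma3-4b branch is the standard transformers streaming pattern: model.generate() blocks, so the commit runs it on a worker Thread while a TextIteratorStreamer hands decoded text back to the Gradio generator as tokens arrive. Below is a minimal, standalone sketch of that same pattern using the Gemma 3 APIs the commit relies on; the prompt text and the fixed max_new_tokens value are illustrative, not taken from the app.

# Minimal standalone sketch of the streaming pattern added above.
# Assumes a transformers release with Gemma 3 support; prompt and
# max_new_tokens here are illustrative only.
from threading import Thread

import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer

model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto").eval()
processor = AutoProcessor.from_pretrained(model_id)

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [{"type": "text", "text": "Explain what a text streamer does."}]},
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)

# generate() blocks, so it runs on a worker thread; the streamer yields
# decoded text chunks to the main thread as soon as tokens are produced.
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 256}).start()

buffer = ""
for new_text in streamer:
    buffer += new_text  # accumulate the partial response
    print(new_text, end="", flush=True)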
84 |
torch_dtype=torch.float16
|
85 |
).to("cuda").eval()
|
86 |
|
87 |
+
# -----------------------
|
88 |
+
# GEMMA3-4B MODEL SETUP (NEW FEATURE)
|
89 |
+
# -----------------------
|
90 |
+
gemma3_model_id = "google/gemma-3-4b-it"
|
91 |
+
gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
|
92 |
+
gemma3_model_id, device_map="auto"
|
93 |
+
).eval()
|
94 |
+
gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
|
95 |
+
|
96 |
async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
|
97 |
communicate = edge_tts.Communicate(text, voice)
|
98 |
await communicate.save(output_file)
|
|
|
218 |
img.save(unique_name)
|
219 |
return unique_name
|
220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
# -----------------------
|
222 |
# MAIN GENERATION FUNCTION
|
223 |
# -----------------------
|
|
|
235 |
files = input_dict.get("files", [])
|
236 |
|
237 |
lower_text = text.lower().strip()
|
238 |
+
|
239 |
+
# 1. IMAGE GENERATION COMMANDS (@lightningv5, @lightningv4, @turbov3)
|
240 |
if (lower_text.startswith("@lightningv5") or
|
241 |
lower_text.startswith("@lightningv4") or
|
242 |
lower_text.startswith("@turbov3")):
|
|
|
288 |
yield gr.Image(image_path)
|
289 |
return
|
290 |
|
291 |
+
# 2. GEMMA3-4B MULTIMODAL GENERATION (NEW FEATURE)
|
292 |
if lower_text.startswith("@gemma3-4b"):
|
293 |
+
# Remove the flag from the text prompt.
|
294 |
prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
|
295 |
+
# Build messages: include a system message and user message.
|
296 |
+
messages = []
|
297 |
+
messages.append({
|
298 |
+
"role": "system",
|
299 |
+
"content": [{"type": "text", "text": "You are a helpful assistant."}]
|
300 |
+
})
|
301 |
+
user_content = []
|
302 |
if files:
|
303 |
+
# If images are uploaded, load them and add them to the message.
|
304 |
+
images = [load_image(image) for image in files]
|
305 |
+
for img in images:
|
306 |
+
user_content.append({"type": "image", "image": img})
|
307 |
+
# Add the text part.
|
308 |
+
user_content.append({"type": "text", "text": prompt_clean})
|
309 |
+
messages.append({
|
310 |
+
"role": "user",
|
311 |
+
"content": user_content
|
312 |
+
})
|
313 |
+
|
314 |
+
# Prepare inputs using Gemma3's processor.
|
|
|
|
|
315 |
inputs = gemma3_processor.apply_chat_template(
|
316 |
messages, add_generation_prompt=True, tokenize=True,
|
317 |
return_dict=True, return_tensors="pt"
|
318 |
).to(gemma3_model.device, dtype=torch.bfloat16)
|
319 |
+
|
320 |
+
input_len = inputs["input_ids"].shape[-1]
|
321 |
+
# Create a text streamer for incremental generation.
|
322 |
+
streamer = TextIteratorStreamer(gemma3_processor, skip_prompt=True, skip_special_tokens=True)
|
323 |
+
generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
324 |
thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
|
325 |
thread.start()
|
326 |
+
|
327 |
buffer = ""
|
328 |
+
yield progress_bar_html("Processing with Gemma3-4B")
|
329 |
for new_text in streamer:
|
330 |
buffer += new_text
|
|
|
331 |
yield buffer
|
332 |
+
final_response = buffer
|
333 |
+
yield final_response
|
334 |
return
|
335 |
|
336 |
+
# 3. TEXT & TTS GENERATION
|
337 |
tts_prefix = "@tts"
|
338 |
is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
|
339 |
voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
|
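With this change, generate() routes on message prefixes: @lightningv5, @lightningv4, and @turbov3 trigger image generation; @gemma3-4b (optionally with attached image files) triggers the new Gemma 3 multimodal branch; and @tts1/@tts2 select an edge-tts voice for a spoken reply. For example, sending @gemma3-4b "What is in this picture?" with an uploaded image streams Gemma 3's description into the chat token by token.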