Update app.py

app.py CHANGED
@@ -33,9 +33,8 @@ MAX_SEED = np.iinfo(np.int32).max
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# -----------------------
 # PROGRESS BAR HELPER
-# -----------------------
+
 def progress_bar_html(label: str) -> str:
     """
     Returns an HTML snippet for a thin progress bar with a label.
@@ -56,9 +55,8 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
-# -----------------------
 # TEXT & TTS MODELS
-# -----------------------
+
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -73,9 +71,8 @@ TTS_VOICES = [
     "en-US-GuyNeural", # @tts2
 ]
 
-# -----------------------
 # MULTIMODAL (OCR) MODELS
-# -----------------------
+
 MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -84,15 +81,6 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
-# -----------------------
-# GEMMA3-4B MODEL SETUP (NEW FEATURE)
-# -----------------------
-gemma3_model_id = "google/gemma-3-4b-it"
-gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
-    gemma3_model_id, device_map="auto"
-).eval()
-gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
-
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
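Because text_to_speech is a coroutine, a synchronous caller such as a Gradio handler has to drive it through an event loop. A minimal sketch of that, assuming only the edge_tts package used above; the sample text, voice choice, and the returned path are illustrative additions, not part of the Space's code:

import asyncio

import edge_tts


async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    # Mirrors the helper in app.py: synthesize `text` with the given
    # edge-tts voice and write an MP3 to output_file. Returning the path
    # is an addition for this sketch.
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file


# "en-US-GuyNeural" is the voice the Space maps to the @tts2 prefix.
print(asyncio.run(text_to_speech("Hello from the Space!", "en-US-GuyNeural")))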
@@ -130,9 +118,9 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
-# -----------------------
+
 # STABLE DIFFUSION IMAGE GENERATION MODELS
-# -----------------------
+
 if torch.cuda.is_available():
     # Lightning 5 model
     pipe = StableDiffusionXLPipeline.from_pretrained(
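For context, the branch this hunk leads into builds SDXL pipelines on GPU using the device/dtype logic above. A rough sketch of that pattern; the checkpoint id below is a stand-in, since the Lightning v5/v4 and Turbo v3 repos the Space actually loads are outside this hunk:

import torch
from diffusers import StableDiffusionXLPipeline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Placeholder checkpoint, not necessarily what app.py uses.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
).to(device)

image = pipe("a lighthouse at dusk, golden hour").images[0]
image.save("sample.png")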
@@ -218,9 +206,18 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
-# -----------------------
+
+# GEMMA3-4B MULTIMODAL MODEL
+
+gemma3_model_id = "google/gemma-3-4b-it"
+gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
+    gemma3_model_id, device_map="auto"
+).eval()
+gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
+
+
 # MAIN GENERATION FUNCTION
-# -----------------------
+
 @spaces.GPU
 def generate(
     input_dict: dict,
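Since the commit relocates the Gemma3 setup to this spot, a quick non-streaming smoke test can confirm the moved objects work before the streaming path in generate() exercises them. A minimal sketch assuming only the identifiers in this hunk and a transformers version recent enough to ship Gemma3ForConditionalGeneration; the prompt is illustrative:

import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

gemma3_model_id = "google/gemma-3-4b-it"
gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
    gemma3_model_id, device_map="auto"
).eval()
gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)

messages = [{"role": "user", "content": [{"type": "text", "text": "Say hello in one sentence."}]}]
inputs = gemma3_processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt",
).to(gemma3_model.device)

with torch.inference_mode():
    out = gemma3_model.generate(**inputs, max_new_tokens=32)

# Decode only the newly generated tokens, skipping the prompt.
reply = gemma3_processor.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(reply)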
@@ -235,8 +232,8 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
-
-    #
+
+    # Image Generation Branch (Stable Diffusion models)
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
         lower_text.startswith("@turbov3")):
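The @-prefix checks above implement a simple command router: the first token of the message selects a backend, and generate() falls through to the chat path when no tag matches. A standalone sketch of that dispatch, with hypothetical backend labels standing in for the real pipeline calls:

import re

# Tags mirror the ones checked in generate(); the labels are hypothetical.
TAGS = {
    "@lightningv5": "sdxl-lightning-v5",
    "@lightningv4": "sdxl-lightning-v4",
    "@turbov3": "sdxl-turbo-v3",
    "@gemma3-4b": "gemma3-4b",
}


def route(text: str):
    lower_text = text.lower().strip()
    for tag, backend in TAGS.items():
        if lower_text.startswith(tag):
            # Strip the tag the same way app.py does for @gemma3-4b.
            prompt = re.sub(re.escape(tag), "", text, flags=re.IGNORECASE).strip().strip('"')
            return backend, prompt
    return "chat", text.strip()


print(route('@gemma3-4b "Describe the attached image"'))
print(route("just a plain chat message"))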
@@ -288,52 +285,53 @@ def generate(
         yield gr.Image(image_path)
         return
 
-    #
+    # GEMMA3-4B Branch for Multimodal/Text Generation with Streaming
     if lower_text.startswith("@gemma3-4b"):
-        # Remove the flag from the
+        # Remove the gemma3 flag from the prompt.
         prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
-        # Build messages: include a system message and user message.
-        messages = []
-        messages.append({
-            "role": "system",
-            "content": [{"type": "text", "text": "You are a helpful assistant."}]
-        })
-        user_content = []
         if files:
-            # If
-            images = [load_image(
-            …
+            # If image files are provided, load them.
+            images = [load_image(f) for f in files]
+            messages = [{
+                "role": "user",
+                "content": [
+                    *[{"type": "image", "image": image} for image in images],
+                    {"type": "text", "text": prompt_clean},
+                ]
+            }]
+        else:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+            ]
         inputs = gemma3_processor.apply_chat_template(
             messages, add_generation_prompt=True, tokenize=True,
             return_dict=True, return_tensors="pt"
         ).to(gemma3_model.device, dtype=torch.bfloat16)
-        …
+        streamer = TextIteratorStreamer(
+            gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+        )
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
         thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
         thread.start()
-
         buffer = ""
-        yield progress_bar_html("Processing with Gemma3-
+        yield progress_bar_html("Processing with Gemma3-4b")
         for new_text in streamer:
             buffer += new_text
+            time.sleep(0.01)
             yield buffer
-        final_response = buffer
-        yield final_response
         return
 
-    #
+    # Otherwise, handle text/chat (and TTS) generation.
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
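The streaming pattern this hunk adds, a worker Thread feeding a TextIteratorStreamer, is self-contained enough to demo in isolation. A sketch with a tiny stand-in checkpoint so it runs on CPU; in the Space, gemma3_model and gemma3_processor.tokenizer take these roles and the buffer is yielded into the Gradio chat instead of printed:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Tiny stand-in model so the example runs anywhere.
tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tok("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs on a worker thread
# while the main thread drains the streamer token by token.
thread = Thread(target=lm.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 20})
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text
    print(buffer)

thread.join()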