Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -33,9 +33,8 @@ MAX_SEED = np.iinfo(np.int32).max
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# -----------------------
 # PROGRESS BAR HELPER
-# -----------------------
+
 def progress_bar_html(label: str) -> str:
     """
     Returns an HTML snippet for a thin progress bar with a label.
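Only the signature and docstring of the helper fall inside this hunk. For reference, a minimal sketch of the kind of label-plus-bar snippet such a helper could return; the markup and styling below are illustrative assumptions, not the Space's actual HTML:

```python
def progress_bar_html(label: str) -> str:
    # Illustrative sketch only; the real markup/CSS lives outside this hunk.
    return (
        '<div style="display:flex;align-items:center;gap:8px;">'
        f'<span style="font-size:0.9em;">{label}</span>'
        '<div style="flex:1;height:4px;background:#e0e0e0;border-radius:2px;">'
        '<div style="width:50%;height:100%;background:#29b6f6;border-radius:2px;"></div>'
        '</div>'
        '</div>'
    )
```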
@@ -56,9 +55,8 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
-# -----------------------
 # TEXT & TTS MODELS
-# -----------------------
+
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
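The `AutoModelForCausalLM.from_pretrained(` call is cut off at the hunk boundary. A typical completion looks like the sketch below; the `device_map` and `torch_dtype` kwargs are assumptions, not taken from app.py:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "prithivMLmods/FastThink-0.5B-Tiny"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",          # assumption: the real kwargs are outside this hunk
    torch_dtype=torch.float16,  # assumption
).eval()
```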
@@ -73,9 +71,8 @@ TTS_VOICES = [
     "en-US-GuyNeural", # @tts2
 ]
 
-# -----------------------
 # MULTIMODAL (OCR) MODELS
-# -----------------------
+
 MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
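For context, this is roughly how the OCR pair loaded above is driven for a single image. Only the model id and the loading calls come from the diff; the image path, prompt, and generation settings are illustrative assumptions:

```python
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID_VL, trust_remote_code=True, torch_dtype=torch.float16
).to("cuda").eval()

image = Image.open("receipt.png")  # placeholder image path
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Extract all readable text from this image."},
    ],
}]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to("cuda")

with torch.inference_mode():
    output_ids = model_m.generate(**inputs, max_new_tokens=256)

# Decode only the newly generated tokens, not the prompt.
trimmed = output_ids[:, inputs["input_ids"].shape[-1]:]
print(processor.batch_decode(trimmed, skip_special_tokens=True)[0])
```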
@@ -84,15 +81,6 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
-# -----------------------
-# GEMMA3-4B MODEL SETUP (NEW FEATURE)
-# -----------------------
-gemma3_model_id = "google/gemma-3-4b-it"
-gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
-    gemma3_model_id, device_map="auto"
-).eval()
-gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
-
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
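The `text_to_speech` coroutine shown as context above is self-contained and can be exercised directly. A minimal sketch; the sample sentence and output filename are assumptions, and `en-US-GuyNeural` is the `@tts2` voice listed in this hunk:

```python
import asyncio
import edge_tts

async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    # Stream synthesized speech from the Edge TTS service straight to an mp3 file.
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file  # returning the path is a convenience added for this sketch

asyncio.run(text_to_speech("Hello from the Space!", "en-US-GuyNeural", "demo.mp3"))
```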
@@ -130,9 +118,9 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
-# -----------------------
+
 # STABLE DIFFUSION IMAGE GENERATION MODELS
-# -----------------------
+
 if torch.cuda.is_available():
     # Lightning 5 model
     pipe = StableDiffusionXLPipeline.from_pretrained(
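The pipeline construction continues past the hunk boundary. As a generic reference for how an SDXL pipeline is loaded and called with diffusers; the checkpoint id, prompt, and step count below are assumptions rather than the Space's Lightning/Turbo settings:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Placeholder checkpoint for illustration; the Space wires up its own Lightning/Turbo weights.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a watercolor fox in a snowy forest", num_inference_steps=30).images[0]
image.save("fox.png")
```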
@@ -218,9 +206,18 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
-# -----------------------
+
+# GEMMA3-4B MULTIMODAL MODEL
+
+gemma3_model_id = "google/gemma-3-4b-it"
+gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
+    gemma3_model_id, device_map="auto"
+).eval()
+gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
+
+
 # MAIN GENERATION FUNCTION
-# -----------------------
+
 @spaces.GPU
 def generate(
     input_dict: dict,
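This hunk relocates the Gemma3-4B setup next to the generation code. A minimal non-streaming sketch of how that pair is used; the `apply_chat_template` call mirrors the one in the diff, while the prompt, `max_new_tokens`, and the decode step are assumptions:

```python
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

gemma3_model_id = "google/gemma-3-4b-it"
gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
    gemma3_model_id, device_map="auto"
).eval()
gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [{"type": "text", "text": "Summarize what a ZeroGPU Space is in one sentence."}]},
]
inputs = gemma3_processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(gemma3_model.device, dtype=torch.bfloat16)

with torch.inference_mode():
    output_ids = gemma3_model.generate(**inputs, max_new_tokens=128, do_sample=False)

# Strip the prompt tokens before decoding the reply.
reply = gemma3_processor.decode(
    output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
)
print(reply)
```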
@@ -235,8 +232,8 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
-
-    #
+
+    # Image Generation Branch (Stable Diffusion models)
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
         lower_text.startswith("@turbov3")):
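The branch above is the first of several `@command` prefixes that `generate()` checks with chained `startswith` calls. A hypothetical helper (not part of app.py) that captures the same routing logic:

```python
# Hypothetical helper for illustration; app.py keeps these checks inline in generate().
IMAGE_PREFIXES = ("@lightningv5", "@lightningv4", "@turbov3")

def route(text: str) -> str:
    lower_text = text.lower().strip()
    if lower_text.startswith(IMAGE_PREFIXES):   # str.startswith accepts a tuple
        return "image"
    if lower_text.startswith("@gemma3-4b"):
        return "gemma3"
    if lower_text.startswith(("@tts1", "@tts2")):
        return "tts"
    return "chat"

assert route("@LightningV5 a cozy cabin at dusk") == "image"
assert route("@gemma3-4b describe this photo") == "gemma3"
assert route("plain chat message") == "chat"
```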
@@ -288,52 +285,53 @@ def generate(
         yield gr.Image(image_path)
         return
 
-    #
+    # GEMMA3-4B Branch for Multimodal/Text Generation with Streaming
     if lower_text.startswith("@gemma3-4b"):
-        # Remove the flag from the
+        # Remove the gemma3 flag from the prompt.
         prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
-        # Build messages: include a system message and user message.
-        messages = []
-        messages.append({
-            "role": "system",
-            "content": [{"type": "text", "text": "You are a helpful assistant."}]
-        })
-        user_content = []
         if files:
-            # If
-            images = [load_image(
-
-
-
-
-
-
-
-
-
-
+            # If image files are provided, load them.
+            images = [load_image(f) for f in files]
+            messages = [{
+                "role": "user",
+                "content": [
+                    *[{"type": "image", "image": image} for image in images],
+                    {"type": "text", "text": prompt_clean},
+                ]
+            }]
+        else:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+            ]
         inputs = gemma3_processor.apply_chat_template(
             messages, add_generation_prompt=True, tokenize=True,
             return_dict=True, return_tensors="pt"
         ).to(gemma3_model.device, dtype=torch.bfloat16)
-
-
-
-
-
+        streamer = TextIteratorStreamer(
+            gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+        )
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
         thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
         thread.start()
-
         buffer = ""
-        yield progress_bar_html("Processing with Gemma3-
+        yield progress_bar_html("Processing with Gemma3-4b")
         for new_text in streamer:
             buffer += new_text
+            time.sleep(0.01)
             yield buffer
-        final_response = buffer
-        yield final_response
         return
 
-    #
+    # Otherwise, handle text/chat (and TTS) generation.
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
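The new branch streams tokens by running `generate()` on a worker thread and draining a `TextIteratorStreamer` on the main thread. A self-contained sketch of that pattern with the plain text model loaded earlier in the file; the prompt and sampling settings are assumptions:

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "prithivMLmods/FastThink-0.5B-Tiny"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Write a haiku about GPUs.", return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(
    tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
)
generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": 64, "do_sample": True}

# generate() blocks until completion, so it runs on a worker thread while the
# main thread consumes partial text from the streamer (app.py yields it to the UI).
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text
    print(buffer, end="\r")
thread.join()
```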