prithivMLmods committed
Commit a3c0180 · verified · 1 Parent(s): 51559fa

Update app.py

Files changed (1)
  1. app.py  +50 -52
app.py CHANGED
@@ -33,9 +33,8 @@ MAX_SEED = np.iinfo(np.int32).max
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# -----------------------
 # PROGRESS BAR HELPER
-# -----------------------
+
 def progress_bar_html(label: str) -> str:
     """
     Returns an HTML snippet for a thin progress bar with a label.
@@ -56,9 +55,8 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
-# -----------------------
 # TEXT & TTS MODELS
-# -----------------------
+
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -73,9 +71,8 @@ TTS_VOICES = [
     "en-US-GuyNeural", # @tts2
 ]
 
-# -----------------------
 # MULTIMODAL (OCR) MODELS
-# -----------------------
+
 MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -84,15 +81,6 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
-# -----------------------
-# GEMMA3-4B MODEL SETUP (NEW FEATURE)
-# -----------------------
-gemma3_model_id = "google/gemma-3-4b-it"
-gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
-    gemma3_model_id, device_map="auto"
-).eval()
-gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
-
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
@@ -130,9 +118,9 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
-# -----------------------
+
 # STABLE DIFFUSION IMAGE GENERATION MODELS
-# -----------------------
+
 if torch.cuda.is_available():
     # Lightning 5 model
     pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -218,9 +206,18 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
-# -----------------------
+
+# GEMMA3-4B MULTIMODAL MODEL
+
+gemma3_model_id = "google/gemma-3-4b-it"
+gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
+    gemma3_model_id, device_map="auto"
+).eval()
+gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
+
+
 # MAIN GENERATION FUNCTION
-# -----------------------
+
 @spaces.GPU
 def generate(
     input_dict: dict,
@@ -235,8 +232,8 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
-
-    # 1. IMAGE GENERATION COMMANDS (@lightningv5, @lightningv4, @turbov3)
+
+    # Image Generation Branch (Stable Diffusion models)
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
         lower_text.startswith("@turbov3")):
@@ -288,52 +285,53 @@ def generate(
         yield gr.Image(image_path)
         return
 
-    # 2. GEMMA3-4B MULTIMODAL GENERATION (NEW FEATURE)
+    # GEMMA3-4B Branch for Multimodal/Text Generation with Streaming
     if lower_text.startswith("@gemma3-4b"):
-        # Remove the flag from the text prompt.
+        # Remove the gemma3 flag from the prompt.
         prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
-        # Build messages: include a system message and user message.
-        messages = []
-        messages.append({
-            "role": "system",
-            "content": [{"type": "text", "text": "You are a helpful assistant."}]
-        })
-        user_content = []
         if files:
-            # If images are uploaded, load them and add them to the message.
-            images = [load_image(image) for image in files]
-            for img in images:
-                user_content.append({"type": "image", "image": img})
-        # Add the text part.
-        user_content.append({"type": "text", "text": prompt_clean})
-        messages.append({
-            "role": "user",
-            "content": user_content
-        })
-
-        # Prepare inputs using Gemma3's processor.
+            # If image files are provided, load them.
+            images = [load_image(f) for f in files]
+            messages = [{
+                "role": "user",
+                "content": [
+                    *[{"type": "image", "image": image} for image in images],
+                    {"type": "text", "text": prompt_clean},
+                ]
+            }]
+        else:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+            ]
         inputs = gemma3_processor.apply_chat_template(
             messages, add_generation_prompt=True, tokenize=True,
             return_dict=True, return_tensors="pt"
         ).to(gemma3_model.device, dtype=torch.bfloat16)
-
-        input_len = inputs["input_ids"].shape[-1]
-        # Create a text streamer for incremental generation.
-        streamer = TextIteratorStreamer(gemma3_processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+        streamer = TextIteratorStreamer(
+            gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+        )
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
         thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
         thread.start()
-
         buffer = ""
-        yield progress_bar_html("Processing with Gemma3-4B")
+        yield progress_bar_html("Processing with Gemma3-4b")
         for new_text in streamer:
             buffer += new_text
+            time.sleep(0.01)
             yield buffer
-        final_response = buffer
-        yield final_response
         return
 
-    # 3. TEXT & TTS GENERATION
+    # Otherwise, handle text/chat (and TTS) generation.
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
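
For reference, the streaming pattern the updated @gemma3-4b branch relies on can be reproduced outside the Space. The sketch below is not part of the commit; it assumes the same google/gemma-3-4b-it checkpoint and the transformers APIs the diff already imports (Gemma3ForConditionalGeneration, AutoProcessor, TextIteratorStreamer), and uses a text-only prompt for simplicity.

# Minimal sketch (not from app.py): stream a Gemma3-4B reply the same way the
# new @gemma3-4b branch does, i.e. model.generate() running in a worker thread
# feeding a TextIteratorStreamer that the caller iterates over.
from threading import Thread

from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer

model_id = "google/gemma-3-4b-it"
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto").eval()

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [{"type": "text", "text": "Summarize what a token streamer does."}]},
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt",
).to(model.device)

# generate() runs in the background; the main thread receives decoded text
# chunks as they are produced and can yield partial output to a UI.
streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 256}).start()

buffer = ""
for new_text in streamer:
    buffer += new_text
    print(new_text, end="", flush=True)

In the Space itself, this path is triggered by prefixing a chat message with @gemma3-4b (optionally with attached images), just as @lightningv5, @lightningv4, and @turbov3 trigger image generation and @tts1 / @tts2 trigger speech output.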