prithivMLmods committed on
Commit 2d99b82 · verified · 1 Parent(s): eed6cef

Update app.py

Files changed (1)
  1. app.py +48 -155
app.py CHANGED
@@ -12,7 +12,6 @@ import spaces
 import torch
 import numpy as np
 from PIL import Image
-import edge_tts
 import cv2
 
 from transformers import (
@@ -24,7 +23,6 @@ from transformers import (
     Gemma3ForConditionalGeneration,
 )
 from transformers.image_utils import load_image
-from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
 # Constants
 MAX_MAX_NEW_TOKENS = 2048
@@ -51,7 +49,7 @@ def progress_bar_html(label: str) -> str:
 </style>
 '''
 
-# TEXT & TTS MODELS
+# TEXT MODEL
 
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -62,11 +60,6 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-TTS_VOICES = [
-    "en-US-JennyNeural", # @tts1
-    "en-US-GuyNeural", # @tts2
-]
-
 # MULTIMODAL (OCR) MODELS
 
 MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
@@ -77,11 +70,6 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
-async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
-    communicate = edge_tts.Communicate(text, voice)
-    await communicate.save(output_file)
-    return output_file
-
 def clean_chat_history(chat_history):
     cleaned = []
     for msg in chat_history:
@@ -114,46 +102,9 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
-# STABLE DIFFUSION IMAGE GENERATION MODEL (Lightning 5 only)
-
-if torch.cuda.is_available():
-    pipe = StableDiffusionXLPipeline.from_pretrained(
-        "SG161222/RealVisXL_V5.0_Lightning",
-        torch_dtype=dtype,
-        use_safetensors=True,
-        add_watermarker=False
-    ).to(device)
-    pipe.text_encoder = pipe.text_encoder.half()
-    if ENABLE_CPU_OFFLOAD:
-        pipe.enable_model_cpu_offload()
-    else:
-        pipe.to(device)
-        print("Loaded RealVisXL_V5.0_Lightning on Device!")
-    if USE_TORCH_COMPILE:
-        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-        print("Model RealVisXL_V5.0_Lightning Compiled!")
-else:
-    pipe = StableDiffusionXLPipeline.from_pretrained(
-        "SG161222/RealVisXL_V5.0_Lightning",
-        torch_dtype=dtype,
-        use_safetensors=True,
-        add_watermarker=False
-    ).to(device)
-    print("Running on CPU; model loaded in float32.")
-
-DEFAULT_MODEL = "Lightning 5"
-models = {
-    "Lightning 5": pipe
-}
-
-def save_image(img: Image.Image) -> str:
-    unique_name = str(uuid.uuid4()) + ".png"
-    img.save(unique_name)
-    return unique_name
-
 # GEMMA3-4B MULTIMODAL MODEL
 
-gemma3_model_id = "google/gemma-3-4b-it" #alter google/gemma-3-12b-it
+gemma3_model_id = "google/gemma-3-4b-it" # alternative: google/gemma-3-12b-it
 gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
     gemma3_model_id, device_map="auto"
 ).eval()
@@ -196,91 +147,51 @@ def generate(
 
     lower_text = text.lower().strip()
 
-    # IMAGE GENERATION BRANCH (Stable Diffusion model using @lightningv5)
-    if lower_text.startswith("@lightningv5"):
-        # Remove the model flag from the prompt.
-        prompt_clean = re.sub(r"@lightningv5", "", text, flags=re.IGNORECASE).strip().strip('"')
-
-        # Default parameters for single image generation.
-        width = 1024
-        height = 1024
-        guidance_scale = 6.0
-        seed_val = 0
-        randomize_seed_flag = True
-
-        seed_val = int(randomize_seed_fn(seed_val, randomize_seed_flag))
-        generator = torch.Generator(device=device).manual_seed(seed_val)
-
-        options = {
-            "prompt": prompt_clean,
-            "negative_prompt": default_negative,
-            "width": width,
-            "height": height,
-            "guidance_scale": guidance_scale,
-            "num_inference_steps": 30,
-            "generator": generator,
-            "num_images_per_prompt": 1,
-            "use_resolution_binning": True,
-            "output_type": "pil",
-        }
-        if device.type == "cuda":
-            torch.cuda.empty_cache()
-
-        yield progress_bar_html("Processing Image Generation")
-        images = models["Lightning 5"](**options).images
-        image_path = save_image(images[0])
-        yield gr.Image(image_path)
-        return
-
     # GEMMA3-4B TEXT & MULTIMODAL (image) Branch
     if lower_text.startswith("@gemma3"):
-        # If it is video, let the dedicated branch handle it.
-        if lower_text.startswith("@video-infer"):
-            pass # video branch is handled below.
-        else:
-            # Remove the gemma3 flag from the prompt.
-            prompt_clean = re.sub(r"@gemma3", "", text, flags=re.IGNORECASE).strip().strip('"')
-            if files:
-                # If image files are provided, load them.
-                images = [load_image(f) for f in files]
-                messages = [{
-                    "role": "user",
-                    "content": [
-                        *[{"type": "image", "image": image} for image in images],
-                        {"type": "text", "text": prompt_clean},
-                    ]
-                }]
-            else:
-                messages = [
-                    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                    {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+        # Remove the gemma3 flag from the prompt.
+        prompt_clean = re.sub(r"@gemma3", "", text, flags=re.IGNORECASE).strip().strip('"')
+        if files:
+            # If image files are provided, load them.
+            images = [load_image(f) for f in files]
+            messages = [{
+                "role": "user",
+                "content": [
+                    *[{"type": "image", "image": image} for image in images],
+                    {"type": "text", "text": prompt_clean},
                 ]
-            inputs = gemma3_processor.apply_chat_template(
-                messages, add_generation_prompt=True, tokenize=True,
-                return_dict=True, return_tensors="pt"
-            ).to(gemma3_model.device, dtype=torch.bfloat16)
-            streamer = TextIteratorStreamer(
-                gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
-            )
-            generation_kwargs = {
-                **inputs,
-                "streamer": streamer,
-                "max_new_tokens": max_new_tokens,
-                "do_sample": True,
-                "temperature": temperature,
-                "top_p": top_p,
-                "top_k": top_k,
-                "repetition_penalty": repetition_penalty,
-            }
-            thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
-            thread.start()
-            buffer = ""
-            yield progress_bar_html("Processing with Gemma3")
-            for new_text in streamer:
-                buffer += new_text
-                time.sleep(0.01)
-                yield buffer
-            return
+            }]
+        else:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+            ]
+        inputs = gemma3_processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=True,
+            return_dict=True, return_tensors="pt"
+        ).to(gemma3_model.device, dtype=torch.bfloat16)
+        streamer = TextIteratorStreamer(
+            gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+        )
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+        thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing with Gemma3")
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
+        return
 
     # GEMMA3-4B VIDEO Branch
     if lower_text.startswith("@video-infer"):
@@ -333,20 +244,9 @@ def generate(
             yield buffer
         return
 
-    # Otherwise, handle text/chat (and TTS) generation.
-    tts_prefix = "@tts"
-    is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
-    voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
-
-    if is_tts and voice_index:
-        voice = TTS_VOICES[voice_index - 1]
-        text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-        conversation = [{"role": "user", "content": text}]
-    else:
-        voice = None
-        text = text.replace(tts_prefix, "").strip()
-        conversation = clean_chat_history(chat_history)
-        conversation.append({"role": "user", "content": text})
+    # Otherwise, handle text/chat generation.
+    conversation = clean_chat_history(chat_history)
+    conversation.append({"role": "user", "content": text})
 
     if files:
         images = [load_image(image) for image in files] if len(files) > 1 else [load_image(files[0])]
@@ -400,10 +300,6 @@ def generate(
     final_response = "".join(outputs)
     yield final_response
 
-    if is_tts and voice:
-        output_file = asyncio.run(text_to_speech(final_response, voice))
-        yield gr.Audio(output_file, autoplay=True)
-
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
@@ -422,16 +318,13 @@ demo = gr.ChatInterface(
         [{"text": "@video-infer Explain what is happening in this video ?", "files": ["examples/oreo.mp4"]}],
         [{"text": "@video-infer Summarize the events in this video", "files": ["examples/sky.mp4"]}],
         [{"text": "@video-infer What is in the video ?", "files": ["examples/redlight.mp4"]}],
-        ['@lightningv5 Chocolate dripping from a donut'],
         ["Python Program for Array Rotation"],
-        ["@tts1 Who is Nikola Tesla, and why did he die?"],
-        ["@tts2 What causes rainbows to form?"],
     ],
     cache_examples=False,
     type="messages",
     description="# **Gemma 3 `@gemma3, @video-infer for video understanding`**",
     fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="@gemma3 for multimodal, @video-infer for video, @lightningv5 for image gen !"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="@gemma3 for multimodal, @video-infer for video !"),
     stop_btn="Stop Generation",
     multimodal=True,
 )
 