prithivMLmods committed
Commit 4e7ff73 · verified · parent e7c8feb

Update app.py

Files changed (1): app.py (+9, -50)
app.py CHANGED
@@ -114,10 +114,9 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
-# STABLE DIFFUSION IMAGE GENERATION MODELS
+# STABLE DIFFUSION IMAGE GENERATION MODEL (Lightning 5 only)
 
 if torch.cuda.is_available():
-    # Lightning 5 model
     pipe = StableDiffusionXLPipeline.from_pretrained(
         "SG161222/RealVisXL_V5.0_Lightning",
         torch_dtype=dtype,
@@ -133,24 +132,6 @@ if torch.cuda.is_available():
     if USE_TORCH_COMPILE:
         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
         print("Model RealVisXL_V5.0_Lightning Compiled!")
-
-    # Lightning 4 model
-    pipe2 = StableDiffusionXLPipeline.from_pretrained(
-        "SG161222/RealVisXL_V4.0_Lightning",
-        torch_dtype=dtype,
-        use_safetensors=True,
-        add_watermarker=False,
-    ).to(device)
-    pipe2.text_encoder = pipe2.text_encoder.half()
-    if ENABLE_CPU_OFFLOAD:
-        pipe2.enable_model_cpu_offload()
-    else:
-        pipe2.to(device)
-        print("Loaded RealVisXL_V4.0 on Device!")
-    if USE_TORCH_COMPILE:
-        pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
-        print("Model RealVisXL_V4.0 Compiled!")
-
 else:
     pipe = StableDiffusionXLPipeline.from_pretrained(
         "SG161222/RealVisXL_V5.0_Lightning",
@@ -158,19 +139,11 @@ else:
         use_safetensors=True,
         add_watermarker=False
     ).to(device)
-    pipe2 = StableDiffusionXLPipeline.from_pretrained(
-        "SG161222/RealVisXL_V4.0_Lightning",
-        torch_dtype=dtype,
-        use_safetensors=True,
-        add_watermarker=False,
-    ).to(device)
-    print("Running on CPU; models loaded in float32.")
+    print("Running on CPU; model loaded in float32.")
 
 DEFAULT_MODEL = "Lightning 5"
-MODEL_CHOICES = [DEFAULT_MODEL, "Lightning 4"]
 models = {
-    "Lightning 5": pipe,
-    "Lightning 4": pipe2
+    "Lightning 5": pipe
 }
 
 def save_image(img: Image.Image) -> str:
@@ -223,21 +196,10 @@ def generate(
 
     lower_text = text.lower().strip()
 
-    # IMAGE GENERATION BRANCH (Stable Diffusion models)
-    if (lower_text.startswith("@lightningv5") or
-            lower_text.startswith("@lightningv4")):
-
-        # Determine model choice based on flag.
-        model_choice = None
-        if "@lightningv5" in lower_text:
-            model_choice = "Lightning 5"
-        elif "@lightningv4" in lower_text:
-            model_choice = "Lightning 4"
-
+    # IMAGE GENERATION BRANCH (Stable Diffusion model using @lightningv5)
+    if lower_text.startswith("@lightningv5"):
         # Remove the model flag from the prompt.
-        prompt_clean = re.sub(r"@lightningv5", "", text, flags=re.IGNORECASE)
-        prompt_clean = re.sub(r"@lightningv4", "", prompt_clean, flags=re.IGNORECASE)
-        prompt_clean = prompt_clean.strip().strip('"')
+        prompt_clean = re.sub(r"@lightningv5", "", text, flags=re.IGNORECASE).strip().strip('"')
 
         # Default parameters for single image generation.
         width = 1024
@@ -264,9 +226,8 @@
         if device.type == "cuda":
             torch.cuda.empty_cache()
 
-        selected_pipe = models.get(model_choice, pipe)
         yield progress_bar_html("Processing Image Generation")
-        images = selected_pipe(**options).images
+        images = models["Lightning 5"](**options).images
         image_path = save_image(images[0])
         yield gr.Image(image_path)
         return
@@ -321,7 +282,7 @@
         yield buffer
         return
 
-    # NEW: GEMMA3-4B VIDEO Branch
+    # GEMMA3-4B VIDEO Branch
     if lower_text.startswith("@video-infer"):
         # Remove the video flag from the prompt.
         prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
@@ -336,7 +297,6 @@
         # Append each frame as an image with a timestamp label.
         for frame in frames:
             image, timestamp = frame
-            # Save the frame image to a temporary unique filename.
             image_path = f"video_frame_{uuid.uuid4().hex}.png"
             image.save(image_path)
             messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
@@ -465,14 +425,13 @@ demo = gr.ChatInterface(
         ['@lightningv5 Chocolate dripping from a donut'],
         ["Python Program for Array Rotation"],
         ["@tts1 Who is Nikola Tesla, and why did he die?"],
-        ['@lightningv4 Cat holding a sign that says hello world'],
         ["@tts2 What causes rainbows to form?"],
     ],
     cache_examples=False,
     type="messages",
     description="# **Gemma 3 `@gemma3-4b, @video-infer for video understanding`**",
     fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="@gemma3-4b for multimodal, @video-infer for video, @lightningv5, @lightningv4 for image gen !"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="@gemma3-4b for multimodal, @video-infer for video, @lightningv5 for image gen !"),
     stop_btn="Stop Generation",
     multimodal=True,
 )
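
In short, the commit drops the RealVisXL V4.0 Lightning pipeline (pipe2) and all @lightningv4 routing, leaving RealVisXL V5.0 Lightning as the only image-generation model. A minimal, self-contained sketch of the resulting single-pipeline setup, assuming torch and diffusers are installed (the CPU-offload and torch.compile toggles from app.py are omitted here):

import torch
from diffusers import StableDiffusionXLPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# float16 on GPU, float32 on CPU, as in app.py.
dtype = torch.float16 if device.type == "cuda" else torch.float32

# The single remaining image-generation model after this commit.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",
    torch_dtype=dtype,
    use_safetensors=True,
    add_watermarker=False,
).to(device)

# The registry keeps its dict shape, so downstream lookups are unchanged.
DEFAULT_MODEL = "Lightning 5"
models = {DEFAULT_MODEL: pipe}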
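
The @lightningv5 branch in generate() now reduces to the routing pattern below. This is a stand-alone sketch, not app.py itself: route is a hypothetical name, and the lambda stands in for the real pipeline call models["Lightning 5"](**options).

import re

# Stand-in for the real SDXL pipeline call in app.py.
models = {"Lightning 5": lambda **options: print("generating:", options["prompt"])}

def route(text: str) -> None:
    lower_text = text.lower().strip()
    # After this commit, only the @lightningv5 flag enters the image branch.
    if lower_text.startswith("@lightningv5"):
        # Strip the flag case-insensitively, then trim whitespace and quotes.
        prompt_clean = re.sub(r"@lightningv5", "", text, flags=re.IGNORECASE).strip().strip('"')
        models["Lightning 5"](prompt=prompt_clean)

route("@lightningv5 Chocolate dripping from a donut")
# -> generating: Chocolate dripping from a donut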