Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -114,10 +114,9 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
|
114 |
|
115 |
dtype = torch.float16 if device.type == "cuda" else torch.float32
|
116 |
|
117 |
-
# STABLE DIFFUSION IMAGE GENERATION
|
118 |
|
119 |
if torch.cuda.is_available():
|
120 |
-
# Lightning 5 model
|
121 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
122 |
"SG161222/RealVisXL_V5.0_Lightning",
|
123 |
torch_dtype=dtype,
|
@@ -133,24 +132,6 @@ if torch.cuda.is_available():
|
|
133 |
if USE_TORCH_COMPILE:
|
134 |
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
135 |
print("Model RealVisXL_V5.0_Lightning Compiled!")
|
136 |
-
|
137 |
-
# Lightning 4 model
|
138 |
-
pipe2 = StableDiffusionXLPipeline.from_pretrained(
|
139 |
-
"SG161222/RealVisXL_V4.0_Lightning",
|
140 |
-
torch_dtype=dtype,
|
141 |
-
use_safetensors=True,
|
142 |
-
add_watermarker=False,
|
143 |
-
).to(device)
|
144 |
-
pipe2.text_encoder = pipe2.text_encoder.half()
|
145 |
-
if ENABLE_CPU_OFFLOAD:
|
146 |
-
pipe2.enable_model_cpu_offload()
|
147 |
-
else:
|
148 |
-
pipe2.to(device)
|
149 |
-
print("Loaded RealVisXL_V4.0 on Device!")
|
150 |
-
if USE_TORCH_COMPILE:
|
151 |
-
pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
|
152 |
-
print("Model RealVisXL_V4.0 Compiled!")
|
153 |
-
|
154 |
else:
|
155 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
156 |
"SG161222/RealVisXL_V5.0_Lightning",
|
@@ -158,19 +139,11 @@ else:
|
|
158 |
use_safetensors=True,
|
159 |
add_watermarker=False
|
160 |
).to(device)
|
161 |
-
pipe2 = StableDiffusionXLPipeline.from_pretrained(
|
162 |
-
"SG161222/RealVisXL_V4.0_Lightning",
|
163 |
-
torch_dtype=dtype,
|
164 |
-
use_safetensors=True,
|
165 |
-
add_watermarker=False,
|
166 |
-
).to(device)
|
167 |
-
print("Running on CPU; models loaded in float32.")
|
168 |
|
169 |
DEFAULT_MODEL = "Lightning 5"
|
170 |
-
MODEL_CHOICES = [DEFAULT_MODEL, "Lightning 4"]
|
171 |
models = {
|
172 |
-
"Lightning 5": pipe,
|
173 |
-
"Lightning 4": pipe2
|
174 |
}
|
175 |
|
176 |
def save_image(img: Image.Image) -> str:
|
@@ -223,21 +196,10 @@ def generate(
|
|
223 |
|
224 |
lower_text = text.lower().strip()
|
225 |
|
226 |
-
# IMAGE GENERATION BRANCH (Stable Diffusion
|
227 |
-
if (lower_text.startswith("@lightningv5") or
|
228 |
-
lower_text.startswith("@lightningv4")):
|
229 |
-
|
230 |
-
# Determine model choice based on flag.
|
231 |
-
model_choice = None
|
232 |
-
if "@lightningv5" in lower_text:
|
233 |
-
model_choice = "Lightning 5"
|
234 |
-
elif "@lightningv4" in lower_text:
|
235 |
-
model_choice = "Lightning 4"
|
236 |
-
|
237 |
# Remove the model flag from the prompt.
|
238 |
-
prompt_clean = re.sub(r"@lightningv5", "", text, flags=re.IGNORECASE)
|
239 |
-
prompt_clean = re.sub(r"@lightningv4", "", prompt_clean, flags=re.IGNORECASE)
|
240 |
-
prompt_clean = prompt_clean.strip().strip('"')
|
241 |
|
242 |
# Default parameters for single image generation.
|
243 |
width = 1024
|
@@ -264,9 +226,8 @@ def generate(
|
|
264 |
if device.type == "cuda":
|
265 |
torch.cuda.empty_cache()
|
266 |
|
267 |
-
selected_pipe = models.get(model_choice, pipe)
|
268 |
yield progress_bar_html("Processing Image Generation")
|
269 |
-
images = selected_pipe(**options).images
|
270 |
image_path = save_image(images[0])
|
271 |
yield gr.Image(image_path)
|
272 |
return
|
@@ -321,7 +282,7 @@ def generate(
|
|
321 |
yield buffer
|
322 |
return
|
323 |
|
324 |
-
#
|
325 |
if lower_text.startswith("@video-infer"):
|
326 |
# Remove the video flag from the prompt.
|
327 |
prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
|
@@ -336,7 +297,6 @@ def generate(
|
|
336 |
# Append each frame as an image with a timestamp label.
|
337 |
for frame in frames:
|
338 |
image, timestamp = frame
|
339 |
-
# Save the frame image to a temporary unique filename.
|
340 |
image_path = f"video_frame_{uuid.uuid4().hex}.png"
|
341 |
image.save(image_path)
|
342 |
messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
|
@@ -465,14 +425,13 @@ demo = gr.ChatInterface(
|
|
465 |
['@lightningv5 Chocolate dripping from a donut'],
|
466 |
["Python Program for Array Rotation"],
|
467 |
["@tts1 Who is Nikola Tesla, and why did he die?"],
|
468 |
-
['@lightningv4 Cat holding a sign that says hello world'],
|
469 |
["@tts2 What causes rainbows to form?"],
|
470 |
],
|
471 |
cache_examples=False,
|
472 |
type="messages",
|
473 |
description="# **Gemma 3 `@gemma3-4b, @video-infer for video understanding`**",
|
474 |
fill_height=True,
|
475 |
-
textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="@gemma3-4b for multimodal, @video-infer for video, @lightningv5, @lightningv4 for image gen !"),
|
476 |
stop_btn="Stop Generation",
|
477 |
multimodal=True,
|
478 |
)
|
|
|
114 |
|
115 |
dtype = torch.float16 if device.type == "cuda" else torch.float32
|
116 |
|
117 |
+
# STABLE DIFFUSION IMAGE GENERATION MODEL (Lightning 5 only)
|
118 |
|
119 |
if torch.cuda.is_available():
|
|
|
120 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
121 |
"SG161222/RealVisXL_V5.0_Lightning",
|
122 |
torch_dtype=dtype,
|
|
|
132 |
if USE_TORCH_COMPILE:
|
133 |
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
134 |
print("Model RealVisXL_V5.0_Lightning Compiled!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
else:
|
136 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
137 |
"SG161222/RealVisXL_V5.0_Lightning",
|
|
|
139 |
use_safetensors=True,
|
140 |
add_watermarker=False
|
141 |
).to(device)
|
142 |
+
print("Running on CPU; model loaded in float32.")
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
DEFAULT_MODEL = "Lightning 5"
|
|
|
145 |
models = {
|
146 |
+
"Lightning 5": pipe
|
|
|
147 |
}
|
148 |
|
149 |
def save_image(img: Image.Image) -> str:
|
|
|
196 |
|
197 |
lower_text = text.lower().strip()
|
198 |
|
199 |
+
# IMAGE GENERATION BRANCH (Stable Diffusion model using @lightningv5)
|
200 |
+
if lower_text.startswith("@lightningv5"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
# Remove the model flag from the prompt.
|
202 |
+
prompt_clean = re.sub(r"@lightningv5", "", text, flags=re.IGNORECASE).strip().strip('"')
|
|
|
|
|
203 |
|
204 |
# Default parameters for single image generation.
|
205 |
width = 1024
|
|
|
226 |
if device.type == "cuda":
|
227 |
torch.cuda.empty_cache()
|
228 |
|
|
|
229 |
yield progress_bar_html("Processing Image Generation")
|
230 |
+
images = models["Lightning 5"](**options).images
|
231 |
image_path = save_image(images[0])
|
232 |
yield gr.Image(image_path)
|
233 |
return
|
|
|
282 |
yield buffer
|
283 |
return
|
284 |
|
285 |
+
# GEMMA3-4B VIDEO Branch
|
286 |
if lower_text.startswith("@video-infer"):
|
287 |
# Remove the video flag from the prompt.
|
288 |
prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
|
|
|
297 |
# Append each frame as an image with a timestamp label.
|
298 |
for frame in frames:
|
299 |
image, timestamp = frame
|
|
|
300 |
image_path = f"video_frame_{uuid.uuid4().hex}.png"
|
301 |
image.save(image_path)
|
302 |
messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
|
|
|
425 |
['@lightningv5 Chocolate dripping from a donut'],
|
426 |
["Python Program for Array Rotation"],
|
427 |
["@tts1 Who is Nikola Tesla, and why did he die?"],
|
|
|
428 |
["@tts2 What causes rainbows to form?"],
|
429 |
],
|
430 |
cache_examples=False,
|
431 |
type="messages",
|
432 |
description="# **Gemma 3 `@gemma3-4b, @video-infer for video understanding`**",
|
433 |
fill_height=True,
|
434 |
+
textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="@gemma3-4b for multimodal, @video-infer for video, @lightningv5 for image gen !"),
|
435 |
stop_btn="Stop Generation",
|
436 |
multimodal=True,
|
437 |
)
|