Update app.py
app.py
CHANGED
@@ -1128,38 +1128,49 @@ def handle_model_choice_change(selected_model):
 import gradio as gr
 import torch
 from diffusers import FluxPipeline
-
+import os
 
-#
-
-
+# Set PYTORCH_CUDA_ALLOC_CONF to handle memory fragmentation
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+# Check if CUDA (GPU) is available, otherwise fallback to CPU
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
-# Function to
-def
-
-
+# Function to initialize Flux bot model with GPU memory management
+def initialize_flux_bot():
+    try:
+        torch.cuda.empty_cache()  # Clear GPU memory cache
+        pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)  # Use FP16
+        pipe.to(device)  # Move the model to the correct device (GPU/CPU)
+    except torch.cuda.OutOfMemoryError:
+        print("CUDA out of memory, switching to CPU.")
+        pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float32)  # Use FP32 for CPU
+        pipe.to("cpu")
+    return pipe
+
+# Function to generate image using Flux bot on the specified device
+def generate_image_flux(prompt):
+    pipe = initialize_flux_bot()
+    image = pipe(
         prompt,
         guidance_scale=0.0,
-        num_inference_steps=
-        max_sequence_length=
-        generator=
+        num_inference_steps=2,  # Reduced steps to save memory
+        max_sequence_length=128,  # Reduced sequence length to save memory
+        generator=torch.Generator(device).manual_seed(0)
     ).images[0]
-
-    # Save image temporarily and return for display
-    temp_image_path = f"temp_flux_image_{hash(prompt)}.png"
-    image.save(temp_image_path)
-
-    return temp_image_path
+    return image
 
-# Hardcoded prompts for
+# Hardcoded prompts for the images
 hardcoded_prompt_1 = "A high quality cinematic image for Toyota Truck in Birmingham skyline shot in the style of Michael Mann"
 hardcoded_prompt_2 = "A high quality cinematic image for Alabama Quarterback close up emotional shot in the style of Michael Mann"
 hardcoded_prompt_3 = "A high quality cinematic image for Taylor Swift concert in Birmingham skyline style of Michael Mann"
 
-#
-
-
-
+# Function to update images
+def update_images():
+    image_1 = generate_image_flux(hardcoded_prompt_1)
+    image_2 = generate_image_flux(hardcoded_prompt_2)
+    image_3 = generate_image_flux(hardcoded_prompt_3)
+    return image_1, image_2, image_3
 
 
 
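Note (not part of the diff): a minimal usage sketch of the new helpers, assuming generate_image_flux and hardcoded_prompt_1 from the hunk above are in scope; the output filename is illustrative only. Because generate_image_flux() calls initialize_flux_bot() on every invocation, the FLUX.1-schnell weights are re-loaded for each image, so update_images() loads the pipeline three times per refresh.

# Usage sketch (illustrative, not part of the commit): generate one preview
# image with the helper defined above and save it locally.
image = generate_image_flux(hardcoded_prompt_1)   # returns a PIL.Image.Image
image.save("flux_preview.png")                    # hypothetical output path
print("Saved preview:", image.size)
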
@@ -1462,10 +1473,13 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
         events_output = gr.HTML(value=fetch_local_events())
 
         with gr.Column():
-
-
-
-
+            image_output_1 = gr.Image(value=generate_image_flux(hardcoded_prompt_1), width=400, height=400)
+            image_output_2 = gr.Image(value=generate_image_flux(hardcoded_prompt_2), width=400, height=400)
+            image_output_3 = gr.Image(value=generate_image_flux(hardcoded_prompt_3), width=400, height=400)
+
+            # Refresh button to update images
+            refresh_button = gr.Button("Refresh Images")
+            refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3])
 
 
 
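Note (not part of the diff): passing generate_image_flux(...) as the value= of each gr.Image means three FLUX generations run while the Blocks UI is being built, before the app is served. The refresh wiring itself follows the standard gr.Button.click pattern; below is a self-contained sketch of that pattern with solid-color placeholder images standing in for the FLUX outputs. All names here are illustrative and not taken from app.py.

# Self-contained sketch of the Button -> Image refresh pattern used above.
# Placeholder PIL images stand in for the FLUX outputs; names are illustrative.
import gradio as gr
from PIL import Image

def make_placeholder(color):
    # Solid 400x400 image in place of a generated one
    return Image.new("RGB", (400, 400), color)

def refresh_images():
    # One return value per output component wired into the click() call
    return make_placeholder("red"), make_placeholder("green"), make_placeholder("blue")

with gr.Blocks() as demo:
    with gr.Column():
        img1 = gr.Image(value=make_placeholder("red"), width=400, height=400)
        img2 = gr.Image(value=make_placeholder("green"), width=400, height=400)
        img3 = gr.Image(value=make_placeholder("blue"), width=400, height=400)
        refresh_button = gr.Button("Refresh Images")
        refresh_button.click(fn=refresh_images, inputs=None, outputs=[img1, img2, img3])

if __name__ == "__main__":
    demo.launch()

The click handler returns one value per output component, which is why update_images() in the first hunk returns three images for the three gr.Image components.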