Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -1125,47 +1125,72 @@ def handle_model_choice_change(selected_model):
|
|
1125 |
# Default case: allow interaction
|
1126 |
return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
|
1127 |
|
1128 |
-
import gradio as gr
|
1129 |
-
import torch
|
1130 |
-
from diffusers import FluxPipeline
|
1131 |
-
import os
|
1132 |
-
|
1133 |
-
# Set PYTORCH_CUDA_ALLOC_CONF to handle memory fragmentation
|
1134 |
-
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
1135 |
-
|
1136 |
-
# Check if CUDA (GPU) is available, otherwise fallback to CPU
|
1137 |
-
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1138 |
|
1139 |
-
# Function to initialize Flux bot model with GPU memory management
|
1140 |
-
def initialize_flux_bot():
|
1141 |
-
try:
|
1142 |
-
torch.cuda.empty_cache() # Clear GPU memory cache
|
1143 |
-
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16) # Use FP16
|
1144 |
-
pipe.to(device) # Move the model to the correct device (GPU/CPU)
|
1145 |
-
except torch.cuda.OutOfMemoryError:
|
1146 |
-
print("CUDA out of memory, switching to CPU.")
|
1147 |
-
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float32) # Use FP32 for CPU
|
1148 |
-
pipe.to("cpu")
|
1149 |
-
return pipe
|
1150 |
-
|
1151 |
-
# Function to generate image using Flux bot on the specified device
|
1152 |
-
def generate_image_flux(prompt):
|
1153 |
-
pipe = initialize_flux_bot()
|
1154 |
-
image = pipe(
|
1155 |
-
prompt,
|
1156 |
-
guidance_scale=0.0,
|
1157 |
-
num_inference_steps=2, # Reduced steps to save memory
|
1158 |
-
max_sequence_length=128, # Reduced sequence length to save memory
|
1159 |
-
generator=torch.Generator(device).manual_seed(0)
|
1160 |
-
).images[0]
|
1161 |
-
return image
|
1162 |
-
|
1163 |
-
# Hardcoded prompts for the images
|
1164 |
hardcoded_prompt_1 = "A high quality cinematic image for Toyota Truck in Birmingham skyline shot in the style of Michael Mann"
|
1165 |
hardcoded_prompt_2 = "A high quality cinematic image for Alabama Quarterback close up emotional shot in the style of Michael Mann"
|
1166 |
hardcoded_prompt_3 = "A high quality cinematic image for Taylor Swift concert in Birmingham skyline style of Michael Mann"
|
1167 |
|
1168 |
-
# Function to
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1169 |
def update_images():
|
1170 |
image_1 = generate_image_flux(hardcoded_prompt_1)
|
1171 |
image_2 = generate_image_flux(hardcoded_prompt_2)
|
@@ -1177,7 +1202,6 @@ def update_images():
|
|
1177 |
|
1178 |
|
1179 |
|
1180 |
-
|
1181 |
def format_restaurant_hotel_info(name, link, location, phone, rating, reviews, snippet):
|
1182 |
return f"""
|
1183 |
{name}
|
@@ -1473,14 +1497,22 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
|
|
1473 |
events_output = gr.HTML(value=fetch_local_events())
|
1474 |
|
1475 |
with gr.Column():
|
1476 |
-
image_output_1 = gr.Image(value=generate_image_flux(hardcoded_prompt_1), width=400, height=400)
|
1477 |
-
image_output_2 = gr.Image(value=generate_image_flux(hardcoded_prompt_2), width=400, height=400)
|
1478 |
-
image_output_3 = gr.Image(value=generate_image_flux(hardcoded_prompt_3), width=400, height=400)
|
1479 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1480 |
# Refresh button to update images
|
1481 |
refresh_button = gr.Button("Refresh Images")
|
1482 |
refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3])
|
1483 |
-
|
1484 |
|
1485 |
|
1486 |
|
|
|
1125 |
# Default case: allow interaction
|
1126 |
return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
|
1127 |
|
1128 |
+
# import gradio as gr
|
1129 |
+
# import torch
|
1130 |
+
# from diffusers import FluxPipeline
|
1131 |
+
# import os
|
1132 |
+
|
1133 |
+
# # Set PYTORCH_CUDA_ALLOC_CONF to handle memory fragmentation
|
1134 |
+
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
1135 |
+
|
1136 |
+
# # Check if CUDA (GPU) is available, otherwise fallback to CPU
|
1137 |
+
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
1138 |
+
|
1139 |
+
# # Function to initialize Flux bot model with GPU memory management
|
1140 |
+
# def initialize_flux_bot():
|
1141 |
+
# try:
|
1142 |
+
# torch.cuda.empty_cache() # Clear GPU memory cache
|
1143 |
+
# pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16) # Use FP16
|
1144 |
+
# pipe.to(device) # Move the model to the correct device (GPU/CPU)
|
1145 |
+
# except torch.cuda.OutOfMemoryError:
|
1146 |
+
# print("CUDA out of memory, switching to CPU.")
|
1147 |
+
# pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float32) # Use FP32 for CPU
|
1148 |
+
# pipe.to("cpu")
|
1149 |
+
# return pipe
|
1150 |
+
|
1151 |
+
# # Function to generate image using Flux bot on the specified device
|
1152 |
+
# def generate_image_flux(prompt):
|
1153 |
+
# pipe = initialize_flux_bot()
|
1154 |
+
# image = pipe(
|
1155 |
+
# prompt,
|
1156 |
+
# guidance_scale=0.0,
|
1157 |
+
# num_inference_steps=2, # Reduced steps to save memory
|
1158 |
+
# max_sequence_length=128, # Reduced sequence length to save memory
|
1159 |
+
# generator=torch.Generator(device).manual_seed(0)
|
1160 |
+
# ).images[0]
|
1161 |
+
# return image
|
1162 |
+
|
1163 |
+
# # Hardcoded prompts for the images
|
1164 |
+
# hardcoded_prompt_1 = "A high quality cinematic image for Toyota Truck in Birmingham skyline shot in the style of Michael Mann"
|
1165 |
+
# hardcoded_prompt_2 = "A high quality cinematic image for Alabama Quarterback close up emotional shot in the style of Michael Mann"
|
1166 |
+
# hardcoded_prompt_3 = "A high quality cinematic image for Taylor Swift concert in Birmingham skyline style of Michael Mann"
|
1167 |
+
|
1168 |
+
# # Function to update images
|
1169 |
+
# def update_images():
|
1170 |
+
# image_1 = generate_image_flux(hardcoded_prompt_1)
|
1171 |
+
# image_2 = generate_image_flux(hardcoded_prompt_2)
|
1172 |
+
# image_3 = generate_image_flux(hardcoded_prompt_3)
|
1173 |
+
# return image_1, image_2, image_3
|
1174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1175 |
hardcoded_prompt_1 = "A high quality cinematic image for Toyota Truck in Birmingham skyline shot in the style of Michael Mann"
|
1176 |
hardcoded_prompt_2 = "A high quality cinematic image for Alabama Quarterback close up emotional shot in the style of Michael Mann"
|
1177 |
hardcoded_prompt_3 = "A high quality cinematic image for Taylor Swift concert in Birmingham skyline style of Michael Mann"
|
1178 |
|
1179 |
+
# Function to call the Flux API and generate images
|
1180 |
+
def generate_image_flux(prompt):
|
1181 |
+
client = Client("black-forest-labs/FLUX.1-schnell")
|
1182 |
+
result = client.predict(
|
1183 |
+
prompt=prompt,
|
1184 |
+
seed=0,
|
1185 |
+
randomize_seed=True,
|
1186 |
+
width=400,
|
1187 |
+
height=400,
|
1188 |
+
num_inference_steps=4,
|
1189 |
+
api_name="/infer"
|
1190 |
+
)
|
1191 |
+
return result
|
1192 |
+
|
1193 |
+
# Function to update images with the three prompts
|
1194 |
def update_images():
|
1195 |
image_1 = generate_image_flux(hardcoded_prompt_1)
|
1196 |
image_2 = generate_image_flux(hardcoded_prompt_2)
|
|
|
1202 |
|
1203 |
|
1204 |
|
|
|
1205 |
def format_restaurant_hotel_info(name, link, location, phone, rating, reviews, snippet):
|
1206 |
return f"""
|
1207 |
{name}
|
|
|
1497 |
events_output = gr.HTML(value=fetch_local_events())
|
1498 |
|
1499 |
with gr.Column():
|
1500 |
+
# image_output_1 = gr.Image(value=generate_image_flux(hardcoded_prompt_1), width=400, height=400)
|
1501 |
+
# image_output_2 = gr.Image(value=generate_image_flux(hardcoded_prompt_2), width=400, height=400)
|
1502 |
+
# image_output_3 = gr.Image(value=generate_image_flux(hardcoded_prompt_3), width=400, height=400)
|
1503 |
|
1504 |
+
# # Refresh button to update images
|
1505 |
+
# refresh_button = gr.Button("Refresh Images")
|
1506 |
+
# refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3])
|
1507 |
+
|
1508 |
+
# Displaying the images generated using Flux API
|
1509 |
+
image_output_1 = gr.Image(label="Image 1", elem_id="flux_image_1", width=400, height=400)
|
1510 |
+
image_output_2 = gr.Image(label="Image 2", elem_id="flux_image_2", width=400, height=400)
|
1511 |
+
image_output_3 = gr.Image(label="Image 3", elem_id="flux_image_3", width=400, height=400)
|
1512 |
+
|
1513 |
# Refresh button to update images
|
1514 |
refresh_button = gr.Button("Refresh Images")
|
1515 |
refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3])
|
|
|
1516 |
|
1517 |
|
1518 |
|