Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -713,29 +713,53 @@ def generate_map(location_names):
|
|
713 |
# return image_1, image_2, image_3
|
714 |
|
715 |
|
|
|
716 |
import torch
|
717 |
from diffusers import FluxPipeline
|
|
|
718 |
|
719 |
-
|
720 |
-
|
721 |
-
pipe.enable_model_cpu_offload() # Offload to CPU to save VRAM
|
722 |
-
return pipe
|
723 |
|
724 |
-
#
|
725 |
-
|
726 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
727 |
|
728 |
-
|
729 |
-
|
730 |
-
|
|
|
731 |
prompt,
|
732 |
guidance_scale=0.0,
|
733 |
-
num_inference_steps=
|
734 |
-
max_sequence_length=
|
735 |
-
generator=torch.Generator(
|
736 |
).images[0]
|
737 |
return image
|
738 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
739 |
|
740 |
|
741 |
|
@@ -1481,20 +1505,20 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
|
|
1481 |
|
1482 |
with gr.Column():
|
1483 |
|
1484 |
-
|
1485 |
-
|
1486 |
-
|
1487 |
-
|
1488 |
-
# When the button is clicked, the image generation function is triggered
|
1489 |
-
flux_generate_button.click(fn=generate_flux_image, inputs=flux_prompt, outputs=flux_image_output)
|
1490 |
|
|
|
|
|
|
|
1491 |
|
1492 |
|
1493 |
|
1494 |
|
1495 |
-
|
1496 |
-
|
1497 |
-
# refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3])
|
1498 |
|
1499 |
demo.queue()
|
1500 |
demo.launch(show_error=True)
|
|
|
713 |
# return image_1, image_2, image_3
|
714 |
|
715 |
|
716 |
+
import gradio as gr
|
717 |
import torch
|
718 |
from diffusers import FluxPipeline
|
719 |
+
import os
|
720 |
|
721 |
+
# Set PYTORCH_CUDA_ALLOC_CONF to handle memory fragmentation
|
722 |
+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
|
|
|
|
723 |
|
724 |
+
# Check if CUDA (GPU) is available, otherwise fallback to CPU
|
725 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
726 |
|
727 |
+
# Function to initialize Flux bot model with GPU memory management
|
728 |
+
def initialize_flux_bot():
|
729 |
+
try:
|
730 |
+
torch.cuda.empty_cache() # Clear GPU memory cache
|
731 |
+
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16) # Use FP16
|
732 |
+
pipe.to(device) # Move the model to the correct device (GPU/CPU)
|
733 |
+
except torch.cuda.OutOfMemoryError:
|
734 |
+
print("CUDA out of memory, switching to CPU.")
|
735 |
+
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float32) # Use FP32 for CPU
|
736 |
+
pipe.to("cpu")
|
737 |
+
return pipe
|
738 |
|
739 |
+
# Function to generate image using Flux bot on the specified device
|
740 |
+
def generate_image_flux(prompt):
|
741 |
+
pipe = initialize_flux_bot()
|
742 |
+
image = pipe(
|
743 |
prompt,
|
744 |
guidance_scale=0.0,
|
745 |
+
num_inference_steps=2, # Reduced steps to save memory
|
746 |
+
max_sequence_length=128, # Reduced sequence length to save memory
|
747 |
+
generator=torch.Generator(device).manual_seed(0)
|
748 |
).images[0]
|
749 |
return image
|
750 |
|
751 |
+
# Hardcoded prompts for the images
|
752 |
+
hardcoded_prompt_1 = "A high quality cinematic image for Toyota Truck in Birmingham skyline shot in the style of Michael Mann"
|
753 |
+
hardcoded_prompt_2 = "A high quality cinematic image for Alabama Quarterback close up emotional shot in the style of Michael Mann"
|
754 |
+
hardcoded_prompt_3 = "A high quality cinematic image for Taylor Swift concert in Birmingham skyline style of Michael Mann"
|
755 |
+
|
756 |
+
# Function to update images
|
757 |
+
def update_images():
|
758 |
+
image_1 = generate_image_flux(hardcoded_prompt_1)
|
759 |
+
image_2 = generate_image_flux(hardcoded_prompt_2)
|
760 |
+
image_3 = generate_image_flux(hardcoded_prompt_3)
|
761 |
+
return image_1, image_2, image_3
|
762 |
+
|
763 |
|
764 |
|
765 |
|
|
|
1505 |
|
1506 |
with gr.Column():
|
1507 |
|
1508 |
+
# Display images
|
1509 |
+
image_output_1 = gr.Image(value=generate_image_flux(hardcoded_prompt_1), width=400, height=400)
|
1510 |
+
image_output_2 = gr.Image(value=generate_image_flux(hardcoded_prompt_2), width=400, height=400)
|
1511 |
+
image_output_3 = gr.Image(value=generate_image_flux(hardcoded_prompt_3), width=400, height=400)
|
|
|
|
|
1512 |
|
1513 |
+
# Refresh button to update images
|
1514 |
+
refresh_button = gr.Button("Refresh Images")
|
1515 |
+
refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3])
|
1516 |
|
1517 |
|
1518 |
|
1519 |
|
1520 |
+
|
1521 |
+
|
|
|
1522 |
|
1523 |
demo.queue()
|
1524 |
demo.launch(show_error=True)
|