Pijush2023 commited on
Commit
2985c37
·
verified ·
1 Parent(s): e02621f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -21
app.py CHANGED
@@ -713,29 +713,53 @@ def generate_map(location_names):
713
  # return image_1, image_2, image_3
714
 
715
 
 
716
  import torch
717
  from diffusers import FluxPipeline
 
718
 
719
- def initialize_flux_pipeline():
720
- pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
721
- pipe.enable_model_cpu_offload() # Offload to CPU to save VRAM
722
- return pipe
723
 
724
- # Initialize the model
725
- flux_pipe = initialize_flux_pipeline()
726
 
 
 
 
 
 
 
 
 
 
 
 
727
 
728
- def generate_flux_image(prompt):
729
- # Use the initialized flux_pipe to generate an image based on the input prompt
730
- image = flux_pipe(
 
731
  prompt,
732
  guidance_scale=0.0,
733
- num_inference_steps=4,
734
- max_sequence_length=256,
735
- generator=torch.Generator("cpu").manual_seed(0)
736
  ).images[0]
737
  return image
738
 
 
 
 
 
 
 
 
 
 
 
 
 
739
 
740
 
741
 
@@ -1481,20 +1505,20 @@ with gr.Blocks(theme='Pijush2023/scikit-learn-pijush') as demo:
1481
 
1482
  with gr.Column():
1483
 
1484
- flux_prompt = gr.Textbox(show_copy_button=True, label="Flux Prompt", placeholder="Enter prompt for Flux image generation")
1485
- flux_image_output = gr.Image()
1486
- flux_generate_button = gr.Button("Generate Flux Image")
1487
-
1488
- # When the button is clicked, the image generation function is triggered
1489
- flux_generate_button.click(fn=generate_flux_image, inputs=flux_prompt, outputs=flux_image_output)
1490
 
 
 
 
1491
 
1492
 
1493
 
1494
 
1495
- # Refresh button to update images
1496
- # refresh_button = gr.Button("Refresh Images")
1497
- # refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3])
1498
 
1499
  demo.queue()
1500
  demo.launch(show_error=True)
 
713
  # return image_1, image_2, image_3
714
 
715
 
716
+ import gradio as gr
717
  import torch
718
  from diffusers import FluxPipeline
719
+ import os
720
 
721
+ # Set PYTORCH_CUDA_ALLOC_CONF to handle memory fragmentation
722
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
 
 
723
 
724
+ # Check if CUDA (GPU) is available, otherwise fallback to CPU
725
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
726
 
727
+ # Function to initialize Flux bot model with GPU memory management
728
+ def initialize_flux_bot():
729
+ try:
730
+ torch.cuda.empty_cache() # Clear GPU memory cache
731
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16) # Use FP16
732
+ pipe.to(device) # Move the model to the correct device (GPU/CPU)
733
+ except torch.cuda.OutOfMemoryError:
734
+ print("CUDA out of memory, switching to CPU.")
735
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float32) # Use FP32 for CPU
736
+ pipe.to("cpu")
737
+ return pipe
738
 
739
+ # Function to generate image using Flux bot on the specified device
740
+ def generate_image_flux(prompt):
741
+ pipe = initialize_flux_bot()
742
+ image = pipe(
743
  prompt,
744
  guidance_scale=0.0,
745
+ num_inference_steps=2, # Reduced steps to save memory
746
+ max_sequence_length=128, # Reduced sequence length to save memory
747
+ generator=torch.Generator(device).manual_seed(0)
748
  ).images[0]
749
  return image
750
 
751
+ # Hardcoded prompts for the images
752
+ hardcoded_prompt_1 = "A high quality cinematic image for Toyota Truck in Birmingham skyline shot in the style of Michael Mann"
753
+ hardcoded_prompt_2 = "A high quality cinematic image for Alabama Quarterback close up emotional shot in the style of Michael Mann"
754
+ hardcoded_prompt_3 = "A high quality cinematic image for Taylor Swift concert in Birmingham skyline style of Michael Mann"
755
+
756
+ # Function to update images
757
+ def update_images():
758
+ image_1 = generate_image_flux(hardcoded_prompt_1)
759
+ image_2 = generate_image_flux(hardcoded_prompt_2)
760
+ image_3 = generate_image_flux(hardcoded_prompt_3)
761
+ return image_1, image_2, image_3
762
+
763
 
764
 
765
 
 
1505
 
1506
  with gr.Column():
1507
 
1508
+ # Display images
1509
+ image_output_1 = gr.Image(value=generate_image_flux(hardcoded_prompt_1), width=400, height=400)
1510
+ image_output_2 = gr.Image(value=generate_image_flux(hardcoded_prompt_2), width=400, height=400)
1511
+ image_output_3 = gr.Image(value=generate_image_flux(hardcoded_prompt_3), width=400, height=400)
 
 
1512
 
1513
+ # Refresh button to update images
1514
+ refresh_button = gr.Button("Refresh Images")
1515
+ refresh_button.click(fn=update_images, inputs=None, outputs=[image_output_1, image_output_2, image_output_3])
1516
 
1517
 
1518
 
1519
 
1520
+
1521
+
 
1522
 
1523
  demo.queue()
1524
  demo.launch(show_error=True)