Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -756,27 +756,36 @@ import torch
|
|
756 |
from diffusers import DiffusionPipeline
|
757 |
import os
|
758 |
|
759 |
-
# Set PYTORCH_CUDA_ALLOC_CONF to
|
760 |
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
761 |
|
762 |
-
#
|
763 |
-
torch.cuda.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
764 |
|
765 |
-
|
766 |
-
|
767 |
-
|
|
|
768 |
|
769 |
-
# Load the
|
770 |
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
|
771 |
|
772 |
-
# Reduce inference steps
|
773 |
-
def generate_image_flux(prompt, seed=42, width=
|
774 |
generator = torch.Generator(device).manual_seed(seed)
|
775 |
image = pipe(
|
776 |
prompt=prompt,
|
777 |
width=width,
|
778 |
height=height,
|
779 |
-
num_inference_steps=num_inference_steps, # Reduce
|
780 |
generator=generator,
|
781 |
guidance_scale=0.0
|
782 |
).images[0]
|
|
|
756 |
from diffusers import DiffusionPipeline
|
757 |
import os
|
758 |
|
759 |
+
# Set PYTORCH_CUDA_ALLOC_CONF to handle memory fragmentation
|
760 |
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
761 |
|
762 |
+
# Check GPU memory and fallback to CPU if necessary
|
763 |
+
if torch.cuda.is_available():
|
764 |
+
free_memory = torch.cuda.memory_reserved(0) - torch.cuda.memory_allocated(0)
|
765 |
+
if free_memory < 100 * 1024 * 1024: # If less than 100 MB is free
|
766 |
+
print("Low GPU memory, switching to CPU.")
|
767 |
+
device = "cpu"
|
768 |
+
else:
|
769 |
+
device = "cuda"
|
770 |
+
else:
|
771 |
+
device = "cpu"
|
772 |
|
773 |
+
dtype = torch.float16 if device == "cuda" else torch.float32 # Use float16 for GPU and float32 for CPU
|
774 |
+
|
775 |
+
# Clear any existing GPU memory cache
|
776 |
+
torch.cuda.empty_cache()
|
777 |
|
778 |
+
# Load the pipeline
|
779 |
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
|
780 |
|
781 |
+
# Reduce the inference steps and image dimensions
|
782 |
+
def generate_image_flux(prompt, seed=42, width=512, height=512, num_inference_steps=2):
|
783 |
generator = torch.Generator(device).manual_seed(seed)
|
784 |
image = pipe(
|
785 |
prompt=prompt,
|
786 |
width=width,
|
787 |
height=height,
|
788 |
+
num_inference_steps=num_inference_steps, # Reduce steps to save memory
|
789 |
generator=generator,
|
790 |
guidance_scale=0.0
|
791 |
).images[0]
|