Pijush2023 committed
Commit 726c820 · verified · parent: b89bf3b

Update app.py

Files changed (1): app.py +19 -10
app.py CHANGED
@@ -756,27 +756,36 @@ import torch
 from diffusers import DiffusionPipeline
 import os
 
-# Set PYTORCH_CUDA_ALLOC_CONF to avoid memory fragmentation
+# Set PYTORCH_CUDA_ALLOC_CONF to handle memory fragmentation
 os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
 
-# Clear CUDA cache before loading the model
-torch.cuda.empty_cache()
+# Check GPU memory and fall back to CPU if necessary
+if torch.cuda.is_available():
+    free_memory = torch.cuda.memory_reserved(0) - torch.cuda.memory_allocated(0)
+    if free_memory < 100 * 1024 * 1024:  # If less than 100 MB is free
+        print("Low GPU memory, switching to CPU.")
+        device = "cpu"
+    else:
+        device = "cuda"
+else:
+    device = "cpu"
 
-# Use a smaller dtype (e.g., torch.float16)
-dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 if device == "cuda" else torch.float32  # Use float16 for GPU and float32 for CPU
+
+# Clear any existing GPU memory cache
+torch.cuda.empty_cache()
 
-# Load the model with a smaller precision
+# Load the pipeline
 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
 
-# Reduce inference steps to save memory
-def generate_image_flux(prompt, seed=42, width=width, height=height, num_inference_steps=2):
+# Reduce the inference steps and image dimensions
+def generate_image_flux(prompt, seed=42, width=512, height=512, num_inference_steps=2):
     generator = torch.Generator(device).manual_seed(seed)
     image = pipe(
         prompt=prompt,
         width=width,
         height=height,
-        num_inference_steps=num_inference_steps,  # Reduce the number of inference steps
+        num_inference_steps=num_inference_steps,  # Reduce steps to save memory
         generator=generator,
         guidance_scale=0.0
     ).images[0]
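
A caveat on the new device check: torch.cuda.memory_reserved(0) - torch.cuda.memory_allocated(0) measures the slack inside PyTorch's caching allocator, which is typically 0 in a fresh process because nothing has been reserved yet, so the check can report low memory and fall back to CPU before any allocation has happened. Below is a minimal sketch of a probe against the device's actual free memory using torch.cuda.mem_get_info() instead; pick_device and min_free_bytes are illustrative names, not part of this commit, and the 100 MB threshold is kept from the commit only for comparison.

import torch

def pick_device(min_free_bytes=100 * 1024 * 1024):
    # Sketch only: choose "cuda" when the GPU itself reports enough free memory.
    if not torch.cuda.is_available():
        return "cpu"
    # mem_get_info() returns (free, total) in bytes for the whole device,
    # not just the free space inside PyTorch's caching-allocator pool.
    free, total = torch.cuda.mem_get_info(0)
    if free < min_free_bytes:
        print(f"Low GPU memory ({free / 2**20:.0f} MiB free), switching to CPU.")
        return "cpu"
    return "cuda"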
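
For reference, a hypothetical call to the updated function. The hunk ends at the pipe(...) call, so the function's return statement (if any) is outside this diff; the sketch assumes the function ends with return image.

# Hypothetical usage; assumes generate_image_flux ends with `return image`.
img = generate_image_flux("a watercolor lighthouse at dusk", seed=7,
                          width=512, height=512, num_inference_steps=2)
img.save("flux_output.png")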