Nick088 commited on
Commit
5f92874
·
verified ·
1 Parent(s): 1653571

Switching back to cpu offload instead of sequential one, empty cache

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -25,25 +25,25 @@ MAX_SEED = np.iinfo(np.int32).max
25
  sd3_medium_pipe = StableDiffusion3Pipeline.from_pretrained(
26
  "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
27
  )
28
- sd3_medium_pipe.enable_sequential_cpu_offload()
29
 
30
  # sd 2.1
31
  sd2_1_pipe = StableDiffusionPipeline.from_pretrained(
32
  "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
33
  )
34
- sd2_1_pipe.enable_sequential_cpu_offload()
35
 
36
  # sdxl
37
  sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
38
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
39
  )
40
- sdxl_pipe.enable_sequential_cpu_offload()
41
 
42
  # sdxl flash
43
  sdxl_flash_pipe = StableDiffusionXLPipeline.from_pretrained(
44
  "sd-community/sdxl-flash", torch_dtype=torch.float16
45
  )
46
- sdxl_flash_pipe.enable_sequential_cpu_offload()
47
  # Ensure sampler uses "trailing" timesteps for sdxl flash.
48
  sdxl_flash_pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
49
  sdxl_flash_pipe.scheduler.config, timestep_spacing="trailing"
@@ -53,18 +53,20 @@ sdxl_flash_pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
53
  stable_cascade_prior_pipe = StableCascadePriorPipeline.from_pretrained(
54
  "stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
55
  )
56
- stable_cascade_prior_pipe.enable_sequential_cpu_offload()
57
  stable_cascade_decoder_pipe = StableCascadeDecoderPipeline.from_pretrained(
58
  "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16
59
  )
60
- stable_cascade_decoder_pipe.enable_sequential_cpu_offload()
61
 
62
  # sd 1.5
63
  sd1_5_pipe = StableDiffusionPipeline.from_pretrained(
64
  "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
65
  )
66
- sd1_5_pipe.enable_sequential_cpu_offload()
67
 
 
 
68
 
69
  # Helper function to generate images for a single model
70
  @spaces.GPU(duration=80)
@@ -134,6 +136,9 @@ def generate_single_image(
134
  num_images_per_prompt=num_images_per_prompt,
135
  ).images
136
 
 
 
 
137
  return output
138
 
139
 
 
25
  sd3_medium_pipe = StableDiffusion3Pipeline.from_pretrained(
26
  "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
27
  )
28
+ sd3_medium_pipe.enable_model_cpu_offload()
29
 
30
  # sd 2.1
31
  sd2_1_pipe = StableDiffusionPipeline.from_pretrained(
32
  "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
33
  )
34
+ sd2_1_pipe.enable_model_cpu_offload()
35
 
36
  # sdxl
37
  sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
38
  "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
39
  )
40
+ sdxl_pipe.enable_model_cpu_offload()
41
 
42
  # sdxl flash
43
  sdxl_flash_pipe = StableDiffusionXLPipeline.from_pretrained(
44
  "sd-community/sdxl-flash", torch_dtype=torch.float16
45
  )
46
+ sdxl_flash_pipe.enable_model_cpu_offload()
47
  # Ensure sampler uses "trailing" timesteps for sdxl flash.
48
  sdxl_flash_pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
49
  sdxl_flash_pipe.scheduler.config, timestep_spacing="trailing"
 
53
  stable_cascade_prior_pipe = StableCascadePriorPipeline.from_pretrained(
54
  "stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=torch.bfloat16
55
  )
56
+ stable_cascade_prior_pipe.enable_model_cpu_offload()
57
  stable_cascade_decoder_pipe = StableCascadeDecoderPipeline.from_pretrained(
58
  "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16
59
  )
60
+ stable_cascade_decoder_pipe.enable_model_cpu_offload()
61
 
62
  # sd 1.5
63
  sd1_5_pipe = StableDiffusionPipeline.from_pretrained(
64
  "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
65
  )
66
+ sd1_5_pipe.enable_model_cpu_offload()
67
 
68
+ # empty cache to free up gpu memory before inference
69
+ torch.cuda.empty_cache()
70
 
71
  # Helper function to generate images for a single model
72
  @spaces.GPU(duration=80)
 
136
  num_images_per_prompt=num_images_per_prompt,
137
  ).images
138
 
139
+ # empty cache to free up gpu memory
140
+ torch.cuda.empty_cache()
141
+
142
  return output
143
 
144