ford442 committed on
Commit
097644d
·
verified ·
1 Parent(s): aa0c4ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -92,6 +92,7 @@ def scheduler_swap_callback(pipeline, step_index, timestep, callback_kwargs):
92
  torch.backends.cuda.preferred_blas_library="cublaslt"
93
  if step_index == int(pipeline.num_timesteps * 0.5):
94
  # torch.set_float32_matmul_precision("medium")
 
95
  pipe.unet.to(torch.float64)
96
  # pipe.guidance_scale=1.0
97
  # pipe.scheduler.set_timesteps(num_inference_steps*.70)
@@ -102,6 +103,7 @@ def scheduler_swap_callback(pipeline, step_index, timestep, callback_kwargs):
102
  torch.backends.cudnn.allow_tf32 = False
103
  torch.backends.cuda.matmul.allow_tf32 = False
104
  torch.set_float32_matmul_precision("highest")
 
105
  pipe.unet.to(torch.bfloat16)
106
  # pipe.vae = vae_a
107
  # pipe.unet = unet_a
 
92
  torch.backends.cuda.preferred_blas_library="cublaslt"
93
  if step_index == int(pipeline.num_timesteps * 0.5):
94
  # torch.set_float32_matmul_precision("medium")
95
+ callback_kwargs["latents"].to(torch.float64)
96
  pipe.unet.to(torch.float64)
97
  # pipe.guidance_scale=1.0
98
  # pipe.scheduler.set_timesteps(num_inference_steps*.70)
 
103
  torch.backends.cudnn.allow_tf32 = False
104
  torch.backends.cuda.matmul.allow_tf32 = False
105
  torch.set_float32_matmul_precision("highest")
106
+ callback_kwargs["latents"].to(torch.bfloat16)
107
  pipe.unet.to(torch.bfloat16)
108
  # pipe.vae = vae_a
109
  # pipe.unet = unet_a