ford442 committed on
Commit
3f9d242
·
verified ·
1 Parent(s): 64e7076

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -93,7 +93,14 @@ def scheduler_swap_callback(pipeline, step_index, timestep, callback_kwargs):
93
  if step_index == int(pipeline.num_timesteps * 0.5):
94
  # torch.set_float32_matmul_precision("medium")
95
  callback_kwargs["latents"] = callback_kwargs["latents"].to(torch.float64)
96
- pipe.unet.to(torch.float64)
 
 
 
 
 
 
 
97
  # pipe.guidance_scale=1.0
98
  # pipe.scheduler.set_timesteps(num_inference_steps*.70)
99
  # print(f"-- setting step {pipeline.num_timesteps * 0.1} --")
@@ -104,7 +111,14 @@ def scheduler_swap_callback(pipeline, step_index, timestep, callback_kwargs):
104
  torch.backends.cuda.matmul.allow_tf32 = False
105
  torch.set_float32_matmul_precision("highest")
106
  callback_kwargs["latents"] = callback_kwargs["latents"].to(torch.bfloat16)
107
- pipe.unet.to(torch.bfloat16)
 
 
 
 
 
 
 
108
  # pipe.vae = vae_a
109
  # pipe.unet = unet_a
110
  # torch.backends.cudnn.deterministic = False
 
93
  if step_index == int(pipeline.num_timesteps * 0.5):
94
  # torch.set_float32_matmul_precision("medium")
95
  callback_kwargs["latents"] = callback_kwargs["latents"].to(torch.float64)
96
+ def change_dtype(module):
97
+ for child in module.children():
98
+ if len(list(child.children())) > 0:
99
+ change_dtype(child)
100
+ for param in child.parameters():
101
+ param.data = param.data.to(torch.float64)
102
+
103
+ change_dtype(pipeline.unet)
104
  # pipe.guidance_scale=1.0
105
  # pipe.scheduler.set_timesteps(num_inference_steps*.70)
106
  # print(f"-- setting step {pipeline.num_timesteps * 0.1} --")
 
111
  torch.backends.cuda.matmul.allow_tf32 = False
112
  torch.set_float32_matmul_precision("highest")
113
  callback_kwargs["latents"] = callback_kwargs["latents"].to(torch.bfloat16)
114
+ def change_dtype(module):
115
+ for child in module.children():
116
+ if len(list(child.children())) > 0:
117
+ change_dtype(child)
118
+ for param in child.parameters():
119
+ param.data = param.data.to(torch.bfloat16)
120
+
121
+ change_dtype(pipeline.unet)
122
  # pipe.vae = vae_a
123
  # pipe.unet = unet_a
124
  # torch.backends.cudnn.deterministic = False