cocktailpeanut committed
Commit 4692982 · Parent(s): 4f0058c
Files changed (1):
  gradio_demo/app.py  +5 -3
gradio_demo/app.py CHANGED
@@ -117,7 +117,8 @@ def instantir_restore(
     print(f"use lora alpha {lora_alpha}")
     lora_alpha = pipe.prepare_previewers(lcm_repo_id, use_lcm=True)
     print(f"use lora alpha {lora_alpha}")
-    pipe.to(device=device, dtype=torch_dtype)
+    if not cpu_offload:
+        pipe.to(device=device, dtype=torch_dtype)
     pipe.scheduler = DDPMScheduler.from_pretrained(sdxl_repo_id, subfolder="scheduler")
     lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
 
@@ -125,10 +126,11 @@ def instantir_restore(
     print("Loading checkpoint...")
     aggregator_state_dict = torch.load(
         f"{instantir_path}/aggregator.pt",
-        # map_location="cpu"
+        map_location="cpu"
     )
     pipe.aggregator.load_state_dict(aggregator_state_dict, strict=True)
-    pipe.aggregator.to(device=device, dtype=torch_dtype)
+    if not cpu_offload:
+        pipe.aggregator.to(device=device, dtype=torch_dtype)
 
     print("******loaded")
 
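In effect, the change loads the aggregator checkpoint onto CPU (map_location="cpu") and only moves the pipeline and aggregator to the GPU when CPU offloading is not requested. A minimal sketch of that control flow, assuming cpu_offload is a boolean flag defined elsewhere in app.py and using small placeholder modules instead of the real InstantIR pipeline and aggregator:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the real pipeline components.
pipe_module = nn.Linear(4, 4)
aggregator = nn.Linear(4, 4)

cpu_offload = True  # assumed flag from the app's configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Load checkpoint weights on CPU first so the load itself never needs GPU
# memory; this is what passing map_location="cpu" to torch.load achieves.
state_dict = {k: v.clone() for k, v in aggregator.state_dict().items()}  # placeholder for torch.load(..., map_location="cpu")
aggregator.load_state_dict(state_dict, strict=True)

# Only move modules to the GPU when offloading is disabled; otherwise they
# stay on CPU until an offload mechanism moves them on demand.
if not cpu_offload:
    pipe_module.to(device=device, dtype=torch_dtype)
    aggregator.to(device=device, dtype=torch_dtype)

Keeping the weights on CPU until they are actually needed lowers peak GPU memory, which is the usual reason for gating the unconditional .to(device) calls behind a flag like this.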