cocktailpeanut commited on
Commit
d9a6cae
·
1 Parent(s): 1df1048
Files changed (1) hide show
  1. gradio_demo/app.py +5 -19
gradio_demo/app.py CHANGED
@@ -117,8 +117,7 @@ def instantir_restore(
117
  print(f"use lora alpha {lora_alpha}")
118
  lora_alpha = pipe.prepare_previewers(lcm_repo_id, use_lcm=True)
119
  print(f"use lora alpha {lora_alpha}")
120
- if not cpu_offload:
121
- pipe.to(device=device, dtype=torch_dtype)
122
  pipe.scheduler = DDPMScheduler.from_pretrained(sdxl_repo_id, subfolder="scheduler")
123
  lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
124
 
@@ -126,15 +125,10 @@ def instantir_restore(
126
  print("Loading checkpoint...")
127
  aggregator_state_dict = torch.load(
128
  f"{instantir_path}/aggregator.pt",
129
- # map_location = device
130
- # map_location = device if not cpu_offload else "cpu"
131
  map_location="cpu"
132
  )
133
  pipe.aggregator.load_state_dict(aggregator_state_dict, strict=True)
134
- if cpu_offload:
135
- pipe.aggregator.to(device="cpu")
136
- else:
137
- pipe.aggregator.to(device=device, dtype=torch_dtype)
138
 
139
  print("******loaded")
140
 
@@ -145,12 +139,8 @@ def instantir_restore(
145
  if "previewer" not in pipe.unet.active_adapters():
146
  pipe.unet.set_adapter('previewer')
147
 
148
- print('optimizing')
149
- # pipe.enable_vae_tiling()
150
- if cpu_offload:
151
- pipe.enable_model_cpu_offload()
152
- # pipe.enable_sequential_cpu_offload()
153
- print('done')
154
 
155
 
156
  if isinstance(guidance_end, int):
@@ -168,10 +158,6 @@ def instantir_restore(
168
  else:
169
  lq = [resize_img(lq.convert("RGB"), size=None)]
170
 
171
- #if cpu_offload:
172
- # generator = torch.Generator().manual_seed(seed)
173
- #else:
174
- # generator = torch.Generator(device=device).manual_seed(seed)
175
  generator = torch.Generator(device=device).manual_seed(seed)
176
  timesteps = [
177
  i * (1000//steps) + pipe.scheduler.config.steps_offset for i in range(0, steps)
@@ -266,7 +252,7 @@ with gr.Blocks() as demo:
266
  preview_start = gr.Slider(label="Preview Start", value=0, minimum=0, maximum=30, step=1)
267
  prompt = gr.Textbox(label="Restoration prompts (Optional)", placeholder="")
268
  mode = gr.Checkbox(label="Creative Restoration", value=False)
269
- cpu_offload = gr.Checkbox(label="CPU offload", info="If you have a lot of GPU VRAM, uncheck this option for faster generation", value=False)
270
 
271
 
272
  with gr.Row():
 
117
  print(f"use lora alpha {lora_alpha}")
118
  lora_alpha = pipe.prepare_previewers(lcm_repo_id, use_lcm=True)
119
  print(f"use lora alpha {lora_alpha}")
120
+ pipe.to(device=device, dtype=torch_dtype)
 
121
  pipe.scheduler = DDPMScheduler.from_pretrained(sdxl_repo_id, subfolder="scheduler")
122
  lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
123
 
 
125
  print("Loading checkpoint...")
126
  aggregator_state_dict = torch.load(
127
  f"{instantir_path}/aggregator.pt",
 
 
128
  map_location="cpu"
129
  )
130
  pipe.aggregator.load_state_dict(aggregator_state_dict, strict=True)
131
+ pipe.aggregator.to(device=device, dtype=torch_dtype)
 
 
 
132
 
133
  print("******loaded")
134
 
 
139
  if "previewer" not in pipe.unet.active_adapters():
140
  pipe.unet.set_adapter('previewer')
141
 
142
+ # if cpu_offload:
143
+ # pipe.enable_model_cpu_offload()
 
 
 
 
144
 
145
 
146
  if isinstance(guidance_end, int):
 
158
  else:
159
  lq = [resize_img(lq.convert("RGB"), size=None)]
160
 
 
 
 
 
161
  generator = torch.Generator(device=device).manual_seed(seed)
162
  timesteps = [
163
  i * (1000//steps) + pipe.scheduler.config.steps_offset for i in range(0, steps)
 
252
  preview_start = gr.Slider(label="Preview Start", value=0, minimum=0, maximum=30, step=1)
253
  prompt = gr.Textbox(label="Restoration prompts (Optional)", placeholder="")
254
  mode = gr.Checkbox(label="Creative Restoration", value=False)
255
+ cpu_offload = gr.Checkbox(label="CPU offload", info="If you have a lot of GPU VRAM, uncheck this option for faster generation", value=False, visible=False)
256
 
257
 
258
  with gr.Row():