Commit d9a6cae (parent: 1df1048): update

Disables CPU offload in the gradio demo: the pipeline and the aggregator are now moved to the target device unconditionally, the cpu_offload code paths are removed or commented out, and the "CPU offload" checkbox is hidden in the UI.
gradio_demo/app.py (+5 -19) CHANGED
@@ -117,8 +117,7 @@ def instantir_restore(
     print(f"use lora alpha {lora_alpha}")
     lora_alpha = pipe.prepare_previewers(lcm_repo_id, use_lcm=True)
     print(f"use lora alpha {lora_alpha}")
-    if not cpu_offload:
-        pipe.to(device=device, dtype=torch_dtype)
+    pipe.to(device=device, dtype=torch_dtype)
     pipe.scheduler = DDPMScheduler.from_pretrained(sdxl_repo_id, subfolder="scheduler")
     lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
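With offload gone, the whole SDXL pipeline is moved to the accelerator up front. A minimal sketch of the device and dtype setup this relies on (an assumption: `device` and `torch_dtype` are defined elsewhere in app.py, outside this diff):

import torch

# Assumed setup: pick the accelerator if present, else fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is the usual choice for SDXL on GPU; float32 on CPU.
torch_dtype = torch.float16 if device == "cuda" else torch.float32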
@@ -126,15 +125,10 @@ def instantir_restore(
     print("Loading checkpoint...")
     aggregator_state_dict = torch.load(
         f"{instantir_path}/aggregator.pt",
-        # map_location = device
-        # map_location = device if not cpu_offload else "cpu"
         map_location="cpu"
     )
     pipe.aggregator.load_state_dict(aggregator_state_dict, strict=True)
-    if cpu_offload:
-        pipe.aggregator.to(device="cpu")
-    else:
-        pipe.aggregator.to(device=device, dtype=torch_dtype)
+    pipe.aggregator.to(device=device, dtype=torch_dtype)
 
     print("******loaded")
 
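The checkpoint is still deserialized with map_location="cpu", which keeps torch.load from allocating GPU memory for the temporary state dict; only afterwards is pipe.aggregator moved to the device. A short, self-contained sketch of that load-then-move pattern (the nn.Linear stand-in and file name are illustrative, not the demo's real module):

import torch
from torch import nn

model = nn.Linear(4, 4)  # stand-in for pipe.aggregator
torch.save(model.state_dict(), "aggregator.pt")  # illustrative checkpoint

# Deserialize into host RAM first, regardless of where the tensors were saved.
state_dict = torch.load("aggregator.pt", map_location="cpu")
# strict=True: every checkpoint key must match a parameter name exactly.
model.load_state_dict(state_dict, strict=True)
# Then move the weights to the accelerator once, in the target precision.
if torch.cuda.is_available():
    model.to(device="cuda", dtype=torch.float16)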
@@ -145,12 +139,8 @@ def instantir_restore(
     if "previewer" not in pipe.unet.active_adapters():
         pipe.unet.set_adapter('previewer')
 
-
-    #
-    if cpu_offload:
-        pipe.enable_model_cpu_offload()
-        # pipe.enable_sequential_cpu_offload()
-    print('done')
+    # if cpu_offload:
+    #     pipe.enable_model_cpu_offload()
 
 
     if isinstance(guidance_end, int):
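Model CPU offload is now commented out rather than gated on the checkbox, which fits the unconditional pipe.to(device=...) above: once offload is enabled, its hooks manage device placement themselves and a manual move should not be mixed in. For reference, roughly how the diffusers API is used when offload is wanted (a sketch; it requires accelerate, and the checkpoint id here is just the stock SDXL base, not InstantIR's):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# Submodules (text encoders, UNet, VAE) hop to the GPU only while they run,
# trading speed for a much smaller VRAM footprint.
pipe.enable_model_cpu_offload()
# Do not also call pipe.to("cuda"): offload owns device placement from here on.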
@@ -168,10 +158,6 @@ def instantir_restore(
     else:
         lq = [resize_img(lq.convert("RGB"), size=None)]
 
-    #if cpu_offload:
-    #    generator = torch.Generator().manual_seed(seed)
-    #else:
-    #    generator = torch.Generator(device=device).manual_seed(seed)
     generator = torch.Generator(device=device).manual_seed(seed)
     timesteps = [
         i * (1000//steps) + pipe.scheduler.config.steps_offset for i in range(0, steps)
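The commented-out branch that picked a CPU generator under offload is dropped; the generator is always created on the target device. This matters because PyTorch random streams are device-specific, as the sketch below shows (it assumes a CUDA machine):

import torch

seed = 42
gen_cpu = torch.Generator().manual_seed(seed)               # CPU stream
gen_gpu = torch.Generator(device="cuda").manual_seed(seed)  # CUDA stream
# Same seed, different streams: these two tensors will not match, so pinning
# the generator's device also pins reproducibility across runs.
print(torch.randn(3, generator=gen_cpu))
print(torch.randn(3, generator=gen_gpu, device="cuda"))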
@@ -266,7 +252,7 @@ with gr.Blocks() as demo:
            preview_start = gr.Slider(label="Preview Start", value=0, minimum=0, maximum=30, step=1)
            prompt = gr.Textbox(label="Restoration prompts (Optional)", placeholder="")
            mode = gr.Checkbox(label="Creative Restoration", value=False)
-            cpu_offload = gr.Checkbox(label="CPU offload", info="If you have a lot of GPU VRAM, uncheck this option for faster generation", value=False)
+            cpu_offload = gr.Checkbox(label="CPU offload", info="If you have a lot of GPU VRAM, uncheck this option for faster generation", value=False, visible=False)
 
 
        with gr.Row():
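Hiding the checkbox with visible=False, instead of deleting it, keeps cpu_offload wired into the event handlers: the restore function's signature is unchanged and it simply always receives False. A small self-contained sketch of that behavior (component names here are illustrative, not the demo's):

import gradio as gr

with gr.Blocks() as demo:
    # Hidden, but still a live component: its value is passed to callbacks.
    cpu_offload = gr.Checkbox(label="CPU offload", value=False, visible=False)
    status = gr.Textbox(label="Placement")
    check = gr.Button("Check")
    check.click(
        lambda offload: "offloading to CPU" if offload else "resident on GPU",
        inputs=cpu_offload,
        outputs=status,
    )

demo.launch()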