Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
9914b63
1
Parent(s):
54f9225
Optimize Infinity model loading by clearing CUDA cache and adjusting device assignment; remove redundant calls
Browse files
app.py
CHANGED
@@ -197,7 +197,8 @@ def load_infinity(
|
|
197 |
):
|
198 |
print(f'[Loading Infinity]')
|
199 |
text_maxlen = 512
|
200 |
-
|
|
|
201 |
infinity_test: Infinity = Infinity(
|
202 |
vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
|
203 |
shared_aln=True, raw_scale_schedule=scale_schedule,
|
@@ -215,7 +216,7 @@ def load_infinity(
|
|
215 |
inference_mode=True,
|
216 |
train_h_div_w_list=[1.0],
|
217 |
**model_kwargs,
|
218 |
-
).to(device
|
219 |
print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')
|
220 |
|
221 |
if bf16:
|
@@ -225,9 +226,6 @@ def load_infinity(
|
|
225 |
infinity_test.eval()
|
226 |
infinity_test.requires_grad_(False)
|
227 |
|
228 |
-
infinity_test.cuda()
|
229 |
-
torch.cuda.empty_cache()
|
230 |
-
|
231 |
print(f'[Load Infinity weights]')
|
232 |
state_dict = torch.load(model_path, map_location=device)
|
233 |
print(infinity_test.load_state_dict(state_dict))
|
@@ -529,7 +527,6 @@ with gr.Blocks() as demo:
|
|
529 |
# Output Section
|
530 |
gr.Markdown("### Generated Image")
|
531 |
output_image = gr.Image(label="Generated Image", type="pil")
|
532 |
-
gr.Markdown("**Tip:** Right-click the image to save it.")
|
533 |
|
534 |
# Error Handling
|
535 |
error_message = gr.Textbox(label="Error Message", visible=False)
|
|
|
197 |
):
|
198 |
print(f'[Loading Infinity]')
|
199 |
text_maxlen = 512
|
200 |
+
torch.cuda.empty_cache()
|
201 |
+
with torch.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
|
202 |
infinity_test: Infinity = Infinity(
|
203 |
vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
|
204 |
shared_aln=True, raw_scale_schedule=scale_schedule,
|
|
|
216 |
inference_mode=True,
|
217 |
train_h_div_w_list=[1.0],
|
218 |
**model_kwargs,
|
219 |
+
).to(device)
|
220 |
print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')
|
221 |
|
222 |
if bf16:
|
|
|
226 |
infinity_test.eval()
|
227 |
infinity_test.requires_grad_(False)
|
228 |
|
|
|
|
|
|
|
229 |
print(f'[Load Infinity weights]')
|
230 |
state_dict = torch.load(model_path, map_location=device)
|
231 |
print(infinity_test.load_state_dict(state_dict))
|
|
|
527 |
# Output Section
|
528 |
gr.Markdown("### Generated Image")
|
529 |
output_image = gr.Image(label="Generated Image", type="pil")
|
|
|
530 |
|
531 |
# Error Handling
|
532 |
error_message = gr.Textbox(label="Error Message", visible=False)
|