MohamedRashad committed
Commit 9914b63 · 1 Parent(s): 54f9225

Optimize Infinity model loading by clearing CUDA cache and adjusting device assignment; remove redundant calls

Files changed (1): app.py (+3 -6)
app.py CHANGED

@@ -197,7 +197,8 @@ def load_infinity(
 ):
     print(f'[Loading Infinity]')
     text_maxlen = 512
-    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
+    torch.cuda.empty_cache()
+    with torch.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
         infinity_test: Infinity = Infinity(
             vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
             shared_aln=True, raw_scale_schedule=scale_schedule,
@@ -215,7 +216,7 @@ def load_infinity(
             inference_mode=True,
             train_h_div_w_list=[1.0],
             **model_kwargs,
-        ).to(device=device)
+        ).to(device)
         print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')

         if bf16:
@@ -225,9 +226,6 @@ def load_infinity(
         infinity_test.eval()
         infinity_test.requires_grad_(False)

-        infinity_test.cuda()
-        torch.cuda.empty_cache()
-
         print(f'[Load Infinity weights]')
         state_dict = torch.load(model_path, map_location=device)
         print(infinity_test.load_state_dict(state_dict))
@@ -529,7 +527,6 @@ with gr.Blocks() as demo:
         # Output Section
         gr.Markdown("### Generated Image")
         output_image = gr.Image(label="Generated Image", type="pil")
-        gr.Markdown("**Tip:** Right-click the image to save it.")

         # Error Handling
         error_message = gr.Textbox(label="Error Message", visible=False)
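For context, the net effect of the commit on load_infinity is roughly the pattern sketched below: free cached CUDA memory before the large allocation, build the model under bf16 autocast and no_grad, move it to the target device with a single .to(device), and load the checkpoint with map_location=device so no follow-up .cuda() or empty_cache() call is needed. This is a minimal illustrative sketch, not the app's exact code; the helper name and its parameters (load_model_sketch, model_cls, model_kwargs, model_path) are placeholders. Note that torch.amp.autocast takes an explicit device_type argument ("cuda" below), whereas the deprecated torch.cuda.amp.autocast variant implied it.

import torch

def load_model_sketch(model_cls, model_kwargs, model_path, device="cuda"):
    # Free cached blocks before the big allocation, as the commit now does
    # up front instead of after construction.
    torch.cuda.empty_cache()

    # Build under bf16 autocast without tracking gradients, then move to the
    # target device once. torch.amp.autocast supersedes torch.cuda.amp.autocast
    # and takes an explicit device_type.
    with torch.amp.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
        model = model_cls(**model_kwargs).to(device)

    model.eval()
    model.requires_grad_(False)

    # Load weights directly onto the target device; no extra model.cuda() or
    # torch.cuda.empty_cache() call is needed afterwards.
    state_dict = torch.load(model_path, map_location=device)
    print(model.load_state_dict(state_dict))
    return model

Since the model is already on the device at construction time, the later infinity_test.cuda() and second torch.cuda.empty_cache() added nothing, which is what the commit message refers to as redundant calls.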