Dakerqi commited on
Commit
6d80425
·
verified ·
1 Parent(s): d6a69c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -84,6 +84,7 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt
84
  attention_mask=prompt_masks,
85
  output_hidden_states=True,
86
  ).hidden_states[-2]
 
87
 
88
  return prompt_embeds, prompt_masks
89
 
@@ -242,13 +243,13 @@ def model_main(args, master_port, rank, request_queue, response_queue, mp_barrie
242
  torch.random.manual_seed(int(seed))
243
  z = torch.randn([1, 16, latent_h, latent_w], device="cuda").to(dtype)
244
  z = z.repeat(2, 1, 1, 1)
245
-
246
  with torch.no_grad():
247
  if neg_cap != "":
248
  cap_feats, cap_mask = encode_prompt([cap] + [neg_cap], text_encoder, tokenizer, 0.0)
249
  else:
250
  cap_feats, cap_mask = encode_prompt([cap] + [""], text_encoder, tokenizer, 0.0)
251
-
252
  cap_mask = cap_mask.to(cap_feats.device)
253
 
254
  model_kwargs = dict(
 
84
  attention_mask=prompt_masks,
85
  output_hidden_states=True,
86
  ).hidden_states[-2]
87
+ text_encoder.cpu()
88
 
89
  return prompt_embeds, prompt_masks
90
 
 
243
  torch.random.manual_seed(int(seed))
244
  z = torch.randn([1, 16, latent_h, latent_w], device="cuda").to(dtype)
245
  z = z.repeat(2, 1, 1, 1)
246
+ model.cpu()
247
  with torch.no_grad():
248
  if neg_cap != "":
249
  cap_feats, cap_mask = encode_prompt([cap] + [neg_cap], text_encoder, tokenizer, 0.0)
250
  else:
251
  cap_feats, cap_mask = encode_prompt([cap] + [""], text_encoder, tokenizer, 0.0)
252
+ model.cuda()
253
  cap_mask = cap_mask.to(cap_feats.device)
254
 
255
  model_kwargs = dict(