KingNish committed (verified) · Commit ba1f3a6 · Parent: 6f6b5d7

Update app.py

Files changed (1):
  1. app.py +2 -1
app.py CHANGED
@@ -206,7 +206,7 @@ def stage2_generate(model_stage2, prompt, batch_size=16):
     for frames_idx in range(codec_ids_tensor.shape[1]):
         cb0 = codec_ids_tensor[:, frames_idx:frames_idx+1]
         prompt_ids_tensor = torch.cat([prompt_ids_tensor, cb0], dim=1)
-        with torch.no_grad():
+        with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
             stage2_output = model_stage2.generate(
                 input_ids=prompt_ids_tensor,
                 min_new_tokens=7,
@@ -214,6 +214,7 @@ def stage2_generate(model_stage2, prompt, batch_size=16):
                 eos_token_id=mmtokenizer.eoa,
                 pad_token_id=mmtokenizer.eoa,
                 logits_processor=block_list,
+                use_cache=True
             )
         # Ensure exactly 7 new tokens were added.
         assert stage2_output.shape[1] - prompt_ids_tensor.shape[1] == 7, (
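
For context on the first change: torch.inference_mode() is a stricter, slightly faster replacement for torch.no_grad() (it skips autograd bookkeeping such as tensor version counters entirely), and torch.cuda.amp.autocast(dtype=torch.float16) runs eligible CUDA ops in half precision to cut memory use and latency. Below is a minimal standalone sketch of the pattern, not the repo's actual stage-2 pipeline: the gpt2 checkpoint, prompt, and max_new_tokens value are placeholders, and a CUDA device is assumed.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model and prompt; the commit applies this pattern to model_stage2.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda().eval()
input_ids = tokenizer("hello world", return_tensors="pt").input_ids.cuda()

# inference_mode() disables gradient tracking; autocast(float16) casts
# eligible CUDA ops to half precision for the duration of the block.
with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
    output = model.generate(
        input_ids=input_ids,
        min_new_tokens=7,
        max_new_tokens=7,   # assumed here so exactly 7 tokens are produced
        use_cache=True,     # reuse cached key/value states between decode steps
    )
assert output.shape[1] - input_ids.shape[1] == 7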
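On the second change: use_cache=True enables the key/value cache inside generate(), so each decode step attends over cached activations instead of recomputing them for the full prompt. Recent transformers versions already default to cached generation for most models, so the explicit flag chiefly documents intent and guards against checkpoints whose config ships with the cache disabled.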