Update app.py
app.py CHANGED
@@ -206,7 +206,7 @@ def stage2_generate(model_stage2, prompt, batch_size=16):
     for frames_idx in range(codec_ids_tensor.shape[1]):
         cb0 = codec_ids_tensor[:, frames_idx:frames_idx+1]
         prompt_ids_tensor = torch.cat([prompt_ids_tensor, cb0], dim=1)
-        with torch.
+        with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
             stage2_output = model_stage2.generate(
                 input_ids=prompt_ids_tensor,
                 min_new_tokens=7,
@@ -214,6 +214,7 @@ def stage2_generate(model_stage2, prompt, batch_size=16):
                 eos_token_id=mmtokenizer.eoa,
                 pad_token_id=mmtokenizer.eoa,
                 logits_processor=block_list,
+                use_cache=True
             )
             # Ensure exactly 7 new tokens were added.
             assert stage2_output.shape[1] - prompt_ids_tensor.shape[1] == 7, (
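For context, the pattern this change introduces is: disable autograd bookkeeping with torch.inference_mode(), run generation under float16 autocast on CUDA, and keep the key/value cache enabled so past attention states are reused instead of recomputed each step. A minimal standalone sketch of the same pattern is below; the gpt2 checkpoint and tokenizer are placeholders for illustration, not the repo's actual Stage 2 model or mmtokenizer.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model/tokenizer; app.py loads its own Stage 2 checkpoint elsewhere.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda").eval()

input_ids = tokenizer("a short prompt", return_tensors="pt").input_ids.to("cuda")

# inference_mode() is a stricter, slightly faster variant of no_grad(), and
# autocast runs CUDA matmuls in float16, reducing memory use and latency.
with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
    out = model.generate(
        input_ids=input_ids,
        min_new_tokens=7,          # suppress EOS until 7 tokens exist...
        max_new_tokens=7,          # ...and stop right after, as the assert expects
        pad_token_id=tokenizer.eos_token_id,
        use_cache=True,            # reuse past key/values across decode steps
    )

assert out.shape[1] - input_ids.shape[1] == 7

Two notes on the diff itself: use_cache=True is already the default for most Hugging Face generate() configs, so the added argument mainly makes the intent explicit; and recent PyTorch releases deprecate torch.cuda.amp.autocast(...) in favor of torch.autocast(device_type="cuda", ...), which is the spelling the sketch uses.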