Update app.py
app.py CHANGED
@@ -215,23 +215,24 @@ def generate_music(
         Performs model inference to generate music tokens.
         This function is decorated with @spaces.GPU for GPU usage in Gradio Spaces.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
+            output_seq = model.generate(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=100,  # Keep min_new_tokens to avoid short generations
+                do_sample=True,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                eos_token_id=mmtokenizer.eoa,
+                pad_token_id=mmtokenizer.eoa,
+                logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+                guidance_scale=guidance_scale,
+                use_cache=True
+            )
+            if output_seq[0][-1].item() != mmtokenizer.eoa:
+                tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+                output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
         return output_seq

     output_seq = model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)
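The `logits_processor` entry above references `BlockTokenRangeProcessor`, whose definition sits outside this hunk. A minimal sketch of such a processor, assuming it blocks a half-open range [start_id, end_id) of token ids by forcing their logits to -inf (note that under this convention the second call site, (32016, 32016), would block an empty range, so an inclusive upper bound is also plausible):

import torch
from transformers import LogitsProcessor

class BlockTokenRangeProcessor(LogitsProcessor):
    """Sketch: forbid sampling of token ids in [start_id, end_id)."""

    def __init__(self, start_id: int, end_id: int):
        self.blocked_token_ids = list(range(start_id, end_id))

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # -inf logits make these ids unsamplable under top-p/temperature sampling.
        if self.blocked_token_ids:
            scores[:, self.blocked_token_ids] = -float("inf")
        return scores

Presumably the blocked ranges confine sampling to the audio-token region of the vocabulary; the trailing `if` in the hunk then guarantees the sequence ends with `mmtokenizer.eoa` even when generation is cut off by `max_new_tokens`.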
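The docstring mentions the `@spaces.GPU` decoration, which is the standard way to request a GPU slice on ZeroGPU Spaces. A hedged sketch of the wrapper, with the signature taken from the call site (the `duration` value is an assumption; a bare `@spaces.GPU` also works):

import spaces

@spaces.GPU(duration=120)  # assumed duration; the GPU is held only while this call runs
def model_inference(input_ids, max_new_tokens, top_p, temperature,
                    repetition_penalty, guidance_scale):
    # Body as in the hunk above: torch.inference_mode() + fp16 autocast
    # around model.generate, then the eoa fix-up.
    ...

Isolating inference in a decorated helper is what lets ZeroGPU attach the device only for the duration of the call, which is presumably why the fp16 autocast lives here rather than inline in `generate_music`.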