KingNish committed
Commit be4c769 · verified · 1 Parent(s): 5b4f482

Update app.py

Files changed (1)
  1. app.py +18 -17
app.py CHANGED
@@ -215,23 +215,24 @@ def generate_music(
         Performs model inference to generate music tokens.
         This function is decorated with @spaces.GPU for GPU usage in Gradio Spaces.
         """
-        output_seq = model.generate(
-            input_ids=input_ids,
-            max_new_tokens=max_new_tokens,
-            min_new_tokens=100, # Keep min_new_tokens to avoid short generations
-            do_sample=True,
-            top_p=top_p,
-            temperature=temperature,
-            repetition_penalty=repetition_penalty,
-            eos_token_id=mmtokenizer.eoa,
-            pad_token_id=mmtokenizer.eoa,
-            logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
-            guidance_scale=guidance_scale,
-            use_cache=True
-        )
-        if output_seq[0][-1].item() != mmtokenizer.eoa:
-            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
-            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+        with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
+            output_seq = model.generate(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=100, # Keep min_new_tokens to avoid short generations
+                do_sample=True,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                eos_token_id=mmtokenizer.eoa,
+                pad_token_id=mmtokenizer.eoa,
+                logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+                guidance_scale=guidance_scale,
+                use_cache=True
+            )
+            if output_seq[0][-1].item() != mmtokenizer.eoa:
+                tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+                output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
         return output_seq
 
     output_seq = model_inference(input_ids, max_new_tokens, top_p, temperature, repetition_penalty, guidance_scale)
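For context, the change wraps the generate() call in torch.inference_mode() (no autograd bookkeeping) combined with torch.autocast (eligible CUDA ops run in float16), which reduces VRAM use and latency without touching any sampling parameter. Below is a minimal, self-contained sketch of the same pattern. It is not the Space's actual code: a gpt2 checkpoint stands in for the music model, and BlockTokenRangeProcessor is reimplemented here as an assumption of how such a helper is typically written (masking a token-id range to -inf), since the real one lives elsewhere in app.py.

# Sketch of the pattern this commit introduces, NOT the Space's actual code.
# Assumptions: a Hugging Face transformers causal LM stands in for the music
# model, and BlockTokenRangeProcessor is a guess at the app.py helper: it
# masks the half-open token-id range [start_id, end_id) to -inf.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
)

class BlockTokenRangeProcessor(LogitsProcessor):
    """Forbid sampling of token ids in [start_id, end_id) (assumed semantics)."""
    def __init__(self, start_id: int, end_id: int):
        self.start_id = start_id
        self.end_id = end_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores[:, self.start_id:self.end_id] = float("-inf")
        return scores

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")

input_ids = tokenizer("a short prompt", return_tensors="pt").input_ids.to("cuda")

# inference_mode() skips autograd tracking entirely; autocast runs matmuls in
# fp16 on CUDA. Together they lower memory use and latency during generate().
with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
    output_seq = model.generate(
        input_ids=input_ids,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.9,
        temperature=1.0,
        pad_token_id=tokenizer.eos_token_id,
        logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(50000, 50256)]),
        use_cache=True,
    )

# Mirror the commit's post-processing: force-append EOS if generation stopped
# at the token budget rather than on the end-of-sequence id.
eos_id = tokenizer.eos_token_id
if output_seq[0][-1].item() != eos_id:
    tensor_eos = torch.as_tensor([[eos_id]]).to(model.device)
    output_seq = torch.cat((output_seq, tensor_eos), dim=1)

Note that autocast only affects ops executed inside the context, so the EOS concatenation afterwards runs in the default dtype, and tensors created under inference_mode() cannot later participate in autograd, which is fine for a pure generation endpoint like this one.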