KingNish committed on
Commit
d6882b3
·
verified ·
1 Parent(s): 954ab16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -0
app.py CHANGED
@@ -77,9 +77,18 @@ model = AutoModelForCausalLM.from_pretrained(
77
  torch_dtype=torch.float16,
78
  attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
79
  ).to(device)
 
 
 
 
 
 
80
  # model = torch.compile(model)
 
81
  model.eval()
82
 
 
 
83
  basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
84
  resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
85
  config_path = './xcodec_mini_infer/decoders/config.yaml'
@@ -231,6 +240,7 @@ def generate_music(
231
  use_cache=True,
232
  top_k=50,
233
  num_beams=1,
 
234
  )
235
  if output_seq[0][-1].item() != mmtokenizer.eoa:
236
  tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
 
77
  torch_dtype=torch.float16,
78
  attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
79
  ).to(device)
80
+ assistant_model = AutoModelForCausalLM.from_pretrained(
81
+ "m-a-p/YuE-s2-1B-general",
82
+ torch_dtype=torch.float16,
83
+ attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
84
+ ).to(device)
85
+ # assistant_model = torch.compile(assistant_model)
86
  # model = torch.compile(model)
87
+ assistant_model.eval()
88
  model.eval()
89
 
90
+
91
+
92
  basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
93
  resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
94
  config_path = './xcodec_mini_infer/decoders/config.yaml'
 
240
  use_cache=True,
241
  top_k=50,
242
  num_beams=1,
243
+ assistant_model=assistant_model
244
  )
245
  if output_seq[0][-1].item() != mmtokenizer.eoa:
246
  tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)