Update app.py
Browse files
app.py
CHANGED
@@ -77,9 +77,18 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
77 |
torch_dtype=torch.float16,
|
78 |
attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
|
79 |
).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
# model = torch.compile(model)
|
|
|
81 |
model.eval()
|
82 |
|
|
|
|
|
83 |
basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
|
84 |
resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
|
85 |
config_path = './xcodec_mini_infer/decoders/config.yaml'
|
@@ -231,6 +240,7 @@ def generate_music(
|
|
231 |
use_cache=True,
|
232 |
top_k=50,
|
233 |
num_beams=1,
|
|
|
234 |
)
|
235 |
if output_seq[0][-1].item() != mmtokenizer.eoa:
|
236 |
tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
|
|
|
77 |
torch_dtype=torch.float16,
|
78 |
attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
|
79 |
).to(device)
|
80 |
+
assistant_model = AutoModelForCausalLM.from_pretrained(
|
81 |
+
"m-a-p/YuE-s2-1B-general",
|
82 |
+
torch_dtype=torch.float16,
|
83 |
+
attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
|
84 |
+
).to(device)
|
85 |
+
# assistant_model = torch.compile(assistant_model)
|
86 |
# model = torch.compile(model)
|
87 |
+
assistant_model.eval()
|
88 |
model.eval()
|
89 |
|
90 |
+
|
91 |
+
|
92 |
basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
|
93 |
resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
|
94 |
config_path = './xcodec_mini_infer/decoders/config.yaml'
|
|
|
240 |
use_cache=True,
|
241 |
top_k=50,
|
242 |
num_beams=1,
|
243 |
+
assistant_model=assistant_model
|
244 |
)
|
245 |
if output_seq[0][-1].item() != mmtokenizer.eoa:
|
246 |
tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
|