Update app.py
Browse files
app.py
CHANGED
@@ -76,8 +76,8 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
76 |
"m-a-p/YuE-s1-7B-anneal-en-cot",
|
77 |
torch_dtype=torch.float16,
|
78 |
attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
|
79 |
-
)
|
80 |
-
model.
|
81 |
model.eval()
|
82 |
|
83 |
basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
|
@@ -90,15 +90,19 @@ mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model"
|
|
90 |
|
91 |
codectool = CodecManipulator("xcodec", 0, 1)
|
92 |
model_config = OmegaConf.load(basic_model_config)
|
|
|
93 |
codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
|
94 |
parameter_dict = torch.load(resume_path, map_location='cpu')
|
95 |
codec_model.load_state_dict(parameter_dict['codec_model'])
|
96 |
-
codec_model.
|
97 |
codec_model.eval()
|
98 |
|
|
|
99 |
vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
|
100 |
vocal_decoder.to(device)
|
101 |
inst_decoder.to(device)
|
|
|
|
|
102 |
vocal_decoder.eval()
|
103 |
inst_decoder.eval()
|
104 |
|
|
|
76 |
"m-a-p/YuE-s1-7B-anneal-en-cot",
|
77 |
torch_dtype=torch.float16,
|
78 |
attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
|
79 |
+
).to(device)
|
80 |
+
model = torch.compile(model)
|
81 |
model.eval()
|
82 |
|
83 |
basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
|
|
|
90 |
|
91 |
codectool = CodecManipulator("xcodec", 0, 1)
|
92 |
model_config = OmegaConf.load(basic_model_config)
|
93 |
+
# Load codec model
|
94 |
codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
|
95 |
parameter_dict = torch.load(resume_path, map_location='cpu')
|
96 |
codec_model.load_state_dict(parameter_dict['codec_model'])
|
97 |
+
codec_model = torch.compile(codec_model)
|
98 |
codec_model.eval()
|
99 |
|
100 |
+
# Preload and compile vocoders
|
101 |
vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
|
102 |
vocal_decoder.to(device)
|
103 |
inst_decoder.to(device)
|
104 |
+
vocal_decoder = torch.compile(vocal_decoder)
|
105 |
+
inst_decoder = torch.compile(inst_decoder)
|
106 |
vocal_decoder.eval()
|
107 |
inst_decoder.eval()
|
108 |
|