KingNish committed
Commit e2fefec · verified · 1 Parent(s): a02a3fd

Update app.py

Files changed (1)
  1. app.py +7 -3
app.py CHANGED
@@ -76,8 +76,8 @@ model = AutoModelForCausalLM.from_pretrained(
     "m-a-p/YuE-s1-7B-anneal-en-cot",
     torch_dtype=torch.float16,
     attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
-)
-model.to(device)
+).to(device)
+model = torch.compile(model)
 model.eval()
 
 basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
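
For context, the pattern this hunk moves to: device placement is chained directly onto from_pretrained, the module is wrapped by torch.compile, and eval() is called on the compiled wrapper (OptimizedModule delegates attribute access to the wrapped module, so this works). A minimal self-contained sketch, assuming a CUDA device and an installed flash-attn package, with names mirroring the diff:

import torch
from transformers import AutoModelForCausalLM

device = "cuda"  # assumption: flash_attention_2 requires a CUDA device

model = AutoModelForCausalLM.from_pretrained(
    "m-a-p/YuE-s1-7B-anneal-en-cot",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
).to(device)

model = torch.compile(model)  # returns an OptimizedModule wrapper
model.eval()  # delegated through the wrapper to the underlying module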
@@ -90,15 +90,19 @@ mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model"
 
 codectool = CodecManipulator("xcodec", 0, 1)
 model_config = OmegaConf.load(basic_model_config)
+# Load codec model
 codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
 parameter_dict = torch.load(resume_path, map_location='cpu')
 codec_model.load_state_dict(parameter_dict['codec_model'])
-codec_model.to(device)
+codec_model = torch.compile(codec_model)
 codec_model.eval()
 
+# Preload and compile vocoders
 vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
 vocal_decoder.to(device)
 inst_decoder.to(device)
+vocal_decoder = torch.compile(vocal_decoder)
+inst_decoder = torch.compile(inst_decoder)
 vocal_decoder.eval()
 inst_decoder.eval()
 
 
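One behavioral note on this change: torch.compile is lazy, so compilation is deferred until each module's first forward pass, and the first request after startup pays the one-time graph-capture cost. A hedged sketch of an optional warm-up step (warm_up and the dummy input are hypothetical, not part of this commit):

import torch

@torch.inference_mode()
def warm_up(module: torch.nn.Module, example: torch.Tensor) -> None:
    # Hypothetical helper: run one dummy forward so graph capture
    # happens at startup instead of on the first user request.
    module(example)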