KingNish committed on
Commit
c3cdb06
·
verified ·
1 Parent(s): 75625eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -7
app.py CHANGED
@@ -90,15 +90,19 @@ mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model"
90
 
91
  codectool = CodecManipulator("xcodec", 0, 1)
92
  model_config = OmegaConf.load(basic_model_config)
 
93
  codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
94
  parameter_dict = torch.load(resume_path, map_location='cpu')
95
  codec_model.load_state_dict(parameter_dict['codec_model'])
96
- codec_model.to(device)
97
  codec_model.eval()
98
 
 
99
  vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
100
  vocal_decoder.to(device)
101
  inst_decoder.to(device)
 
 
102
  vocal_decoder.eval()
103
  inst_decoder.eval()
104
 
@@ -112,13 +116,14 @@ def generate_music(
112
  audio_prompt_path="",
113
  prompt_start_time=0.0,
114
  prompt_end_time=30.0,
115
- cuda_idx=0,
116
  rescale=False,
117
  ):
118
  if use_audio_prompt and not audio_prompt_path:
119
  raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
120
- cuda_idx = cuda_idx
 
121
  max_new_tokens = max_new_tokens * 100
 
122
 
123
  with tempfile.TemporaryDirectory() as output_dir:
124
  stage1_output_dir = os.path.join(output_dir, f"stage1")
@@ -173,7 +178,17 @@ def generate_music(
173
  # Format text prompt
174
  run_n_segments = min(run_n_segments + 1, len(lyrics))
175
 
176
- print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
 
 
 
 
 
 
 
 
 
 
177
 
178
  for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
179
  section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
@@ -210,7 +225,7 @@ def generate_music(
210
  print(
211
  f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
212
  input_ids = input_ids[:, -(max_context):]
213
- with torch.no_grad():
214
  output_seq = model.generate(
215
  input_ids=input_ids,
216
  max_new_tokens=max_new_tokens,
@@ -374,8 +389,7 @@ def generate_music(
374
  def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=10):
375
  # Execute the command
376
  try:
377
- audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
378
- cuda_idx=0, max_new_tokens=max_new_tokens)
379
  return audio_data
380
  except Exception as e:
381
  gr.Warning("An Error Occured: " + str(e))
 
90
 
91
  codectool = CodecManipulator("xcodec", 0, 1)
92
  model_config = OmegaConf.load(basic_model_config)
93
+ # Load codec model
94
  codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
95
  parameter_dict = torch.load(resume_path, map_location='cpu')
96
  codec_model.load_state_dict(parameter_dict['codec_model'])
97
+ codec_model = torch.compile(codec_model)
98
  codec_model.eval()
99
 
100
+ # Preload and compile vocoders
101
  vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
102
  vocal_decoder.to(device)
103
  inst_decoder.to(device)
104
+ vocal_decoder = torch.compile(vocal_decoder)
105
+ inst_decoder = torch.compile(inst_decoder)
106
  vocal_decoder.eval()
107
  inst_decoder.eval()
108
 
 
116
  audio_prompt_path="",
117
  prompt_start_time=0.0,
118
  prompt_end_time=30.0,
 
119
  rescale=False,
120
  ):
121
  if use_audio_prompt and not audio_prompt_path:
122
  raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
123
+ # Initial setup with memory-only processing
124
+ # ------------------------------------------
125
  max_new_tokens = max_new_tokens * 100
126
+ stage1_output_data = {}
127
 
128
  with tempfile.TemporaryDirectory() as output_dir:
129
  stage1_output_dir = os.path.join(output_dir, f"stage1")
 
178
  # Format text prompt
179
  run_n_segments = min(run_n_segments + 1, len(lyrics))
180
 
181
+ generation_config = {
182
+ 'top_p': 0.93,
183
+ 'temperature': 1.0,
184
+ 'repetition_penalty': 1.2,
185
+ 'top_k': 50, # Faster than top_p alone
186
+ 'num_beams': 1, # Disable beam search
187
+ 'max_new_tokens': max_new_tokens,
188
+ 'min_new_tokens': 100,
189
+ 'do_sample': True,
190
+ 'use_cache': True,
191
+ }
192
 
193
  for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
194
  section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
 
225
  print(
226
  f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
227
  input_ids = input_ids[:, -(max_context):]
228
+ with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
229
  output_seq = model.generate(
230
  input_ids=input_ids,
231
  max_new_tokens=max_new_tokens,
 
389
  def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=10):
390
  # Execute the command
391
  try:
392
+ audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments, max_new_tokens=max_new_tokens)
 
393
  return audio_data
394
  except Exception as e:
395
  gr.Warning("An Error Occured: " + str(e))