KingNish committed
Commit a02a3fd · verified · Parent: ab8cd62

Update app.py

Files changed (1): app.py (+8 -23)
app.py CHANGED
@@ -70,7 +70,7 @@ from models.soundstream_hubert_new import SoundStream
 from vocoder import build_codec_model, process_audio
 from post_process_audio import replace_low_freq_with_energy_matched
 
-device = "cuda"
+device = "cuda:0"
 
 model = AutoModelForCausalLM.from_pretrained(
     "m-a-p/YuE-s1-7B-anneal-en-cot",
@@ -90,23 +90,18 @@ mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model"
 
 codectool = CodecManipulator("xcodec", 0, 1)
 model_config = OmegaConf.load(basic_model_config)
-# Load codec model
 codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
 parameter_dict = torch.load(resume_path, map_location='cpu')
 codec_model.load_state_dict(parameter_dict['codec_model'])
-codec_model = torch.compile(codec_model)
+codec_model.to(device)
 codec_model.eval()
 
-# Preload and compile vocoders
 vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
 vocal_decoder.to(device)
 inst_decoder.to(device)
-vocal_decoder = torch.compile(vocal_decoder)
-inst_decoder = torch.compile(inst_decoder)
 vocal_decoder.eval()
 inst_decoder.eval()
 
-cuda_idx = 0
 
 def generate_music(
     max_new_tokens=5,
@@ -117,14 +112,13 @@ def generate_music(
     audio_prompt_path="",
     prompt_start_time=0.0,
     prompt_end_time=30.0,
+    cuda_idx=0,
     rescale=False,
 ):
     if use_audio_prompt and not audio_prompt_path:
         raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
-    # Initial setup with memory-only processing
-    # ------------------------------------------
+    cuda_idx = cuda_idx
     max_new_tokens = max_new_tokens * 100
-    stage1_output_data = {}
 
     with tempfile.TemporaryDirectory() as output_dir:
         stage1_output_dir = os.path.join(output_dir, f"stage1")
@@ -179,17 +173,7 @@ def generate_music(
         # Format text prompt
         run_n_segments = min(run_n_segments + 1, len(lyrics))
 
-        generation_config = {
-            'top_p': 0.93,
-            'temperature': 1.0,
-            'repetition_penalty': 1.2,
-            'top_k': 50,  # Faster than top_p alone
-            'num_beams': 1,  # Disable beam search
-            'max_new_tokens': max_new_tokens,
-            'min_new_tokens': 100,
-            'do_sample': True,
-            'use_cache': True,
-        }
+        print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
 
         for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
             section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
@@ -226,7 +210,7 @@ def generate_music(
             print(
                 f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
             input_ids = input_ids[:, -(max_context):]
-            with torch.inference_mode(), torch.autocast(device_type=device, dtype=torch.float16):
+            with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
                 output_seq = model.generate(
                     input_ids=input_ids,
                     max_new_tokens=max_new_tokens,
@@ -390,7 +374,8 @@ def generate_music(
 def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=10):
     # Execute the command
     try:
-        audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments, max_new_tokens=max_new_tokens)
+        audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
+                                    cuda_idx=0, max_new_tokens=max_new_tokens)
         return audio_data
     except Exception as e:
         gr.Warning("An Error Occured: " + str(e))
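
The patch pins the module-level device to "cuda:0" while also threading a cuda_idx argument into generate_music(), so the GPU index now lives in two places. A minimal sketch of deriving one from the other, assuming nothing beyond stock PyTorch (resolve_device is a hypothetical helper, not part of this commit):

import torch

def resolve_device(cuda_idx: int = 0) -> str:
    # Hypothetical helper: build the device string from the index, falling
    # back to CPU so the app can still be smoke-tested without a GPU.
    return f"cuda:{cuda_idx}" if torch.cuda.is_available() else "cpu"

device = resolve_device(0)  # "cuda:0" on a GPU box, matching the patch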
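
The commit also drops the torch.compile() wrappers around the codec model and both vocoders. torch.compile trades a large first-call compilation cost for faster steady-state inference, which is a poor fit for a cold-starting demo. If compilation were still wanted as an opt-in, one sketch is to gate it behind a flag; COMPILE_MODELS is a hypothetical environment variable, not something app.py defines:

import os
import torch

# Hypothetical opt-in: compile only when explicitly requested, since the
# first call through a compiled module pays the full compilation latency.
if os.environ.get("COMPILE_MODELS") == "1":
    codec_model = torch.compile(codec_model)
    vocal_decoder = torch.compile(vocal_decoder)
    inst_decoder = torch.compile(inst_decoder)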
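
Finally, the generation call now passes device_type='cuda' to torch.autocast instead of the device variable. Per the PyTorch docs, torch.autocast takes a device type such as 'cuda', not an indexed device string such as 'cuda:0', so once device became "cuda:0" the literal is the safe choice. A sketch of the pattern in isolation, with model and input_ids standing in for the app's objects:

import torch

def generate_fp16(model, input_ids, max_new_tokens):
    # No-grad inference with fp16 autocast; note the device *type*, not index.
    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
        return model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)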