KingNish committed
Commit 44d4a2f · verified · 1 Parent(s): 589972a

Update app.py

Files changed (1): app.py (+211 -159)
app.py CHANGED
@@ -46,6 +46,9 @@ except FileNotFoundError:
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
 
+
+# don't change above code
+
 import argparse
 import numpy as np
 import json
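
Note: the hunk above only relocates the "don't change above code" marker; the two `sys.path.append` context lines are what make the vendored `xcodec_mini_infer` checkout importable. A minimal sketch of that bootstrap pattern (illustrative, mirroring the lines above):

    import os
    import sys

    # Make a checkout that sits next to this file importable as a package.
    ROOT = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(os.path.join(ROOT, 'xcodec_mini_infer'))
    sys.path.append(os.path.join(ROOT, 'xcodec_mini_infer', 'descriptaudiocodec'))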
@@ -64,8 +67,8 @@ import time
 import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
-
-# don't change above code
+#from vocoder import build_codec_model, process_audio # removed vocoder
+#from post_process_audio import replace_low_freq_with_energy_matched # removed post process
 
 device = "cuda:0"
 
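
Note: the commit disables the vocoder stage by commenting out its imports rather than deleting them, so the code path can be restored later. If the dependency should stay optional instead, a guarded import is a common alternative; a sketch assuming the same `vocoder` module shipped with the repo (the `HAS_VOCODER` flag is hypothetical, not part of the commit):

    # Hypothetical feature flag; the commit itself simply comments the imports out.
    try:
        from vocoder import build_codec_model, process_audio
        HAS_VOCODER = True
    except ImportError:
        HAS_VOCODER = False  # fall back to raw codec output only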
@@ -79,6 +82,9 @@ model.eval()
 
 basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
 resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
+#config_path = './xcodec_mini_infer/decoders/config.yaml' # removed vocoder
+#vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth' # removed vocoder
+#inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth' # removed vocoder
 
 mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
 
@@ -88,8 +94,19 @@ model_config = OmegaConf.load(basic_model_config)
 codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
 parameter_dict = torch.load(resume_path, map_location='cpu')
 codec_model.load_state_dict(parameter_dict['codec_model'])
+# codec_model = torch.compile(codec_model)
 codec_model.eval()
 
+# Preload and compile vocoders # removed vocoder
+#vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
+#vocal_decoder.to(device)
+#inst_decoder.to(device)
+#vocal_decoder = torch.compile(vocal_decoder)
+#inst_decoder = torch.compile(inst_decoder)
+#vocal_decoder.eval()
+#inst_decoder.eval()
+
+
 @spaces.GPU(duration=120)
 def generate_music(
         max_new_tokens=5,
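
Note: `torch.compile(codec_model)` is left commented out, presumably because compilation costs warm-up time on a short-lived ZeroGPU worker. A sketch of making it opt-in instead (the `COMPILE_CODEC` variable is hypothetical, not part of the commit):

    import os
    import torch

    # Opt into PyTorch 2.x compilation only when explicitly requested.
    if os.environ.get("COMPILE_CODEC") == "1" and hasattr(torch, "compile"):
        codec_model = torch.compile(codec_model)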
@@ -108,164 +125,198 @@ def generate_music(
     cuda_idx = cuda_idx
     max_new_tokens = max_new_tokens * 100
 
-    class BlockTokenRangeProcessor(LogitsProcessor):
-        def __init__(self, start_id, end_id):
-            self.blocked_token_ids = list(range(start_id, end_id))
-
-        def __call__(self, input_ids, scores):
-            scores[:, self.blocked_token_ids] = -float("inf")
-            return scores
-
-    def load_audio_mono(filepath, sampling_rate=16000):
-        audio, sr = torchaudio.load(filepath)
-        # Convert to mono
-        audio = torch.mean(audio, dim=0, keepdim=True)
-        # Resample if needed
-        if sr != sampling_rate:
-            resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
-            audio = resampler(audio)
-        return audio
-
-    def split_lyrics(lyrics: str):
-        pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
-        segments = re.findall(pattern, lyrics, re.DOTALL)
-        structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
-        return structured_lyrics
-
-    # Call the function and print the result
-    stage1_output_set = []
-
-    genres = genre_txt.strip()
-    lyrics = split_lyrics(lyrics_txt + "\n")
-    # intruction
-    full_lyrics = "\n".join(lyrics)
-    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
-    prompt_texts += lyrics
-
-    random_id = uuid.uuid4()
-    output_seq = None
-    # Here is suggested decoding config
-    top_p = 0.93
-    temperature = 1.0
-    repetition_penalty = 1.2
-    # special tokens
-    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
-
-    raw_output = None
-
-    # Format text prompt
-    run_n_segments = min(run_n_segments + 1, len(lyrics))
-
-    print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
-
-    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
-        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-        guidance_scale = 1.5 if i <= 1 else 1.2
-        if i == 0:
-            continue
-        if i == 1:
-            if use_audio_prompt:
-                audio_prompt = load_audio_mono(audio_prompt_path)
-                audio_prompt.unsqueeze_(0)
-                with torch.no_grad():
-                    raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
-                raw_codes = raw_codes.transpose(0, 1)
-                raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-                # Format audio prompt
-                code_ids = codectool.npy2ids(raw_codes[0])
-                audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)] # 50 is tps of xcodec
-                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
-                    mmtokenizer.eoa]
-                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
-                    "[end_of_reference]")
-                head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
-            else:
-                head_id = mmtokenizer.tokenize(prompt_texts[0])
-            prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-        else:
-            prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
-
-        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
-        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
-        # Use window slicing in case output sequence exceeds the context of model
-        max_context = 16384 - max_new_tokens - 1
-        if input_ids.shape[-1] > max_context:
-            print(
-                f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
-            input_ids = input_ids[:, -(max_context):]
-        with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
-            output_seq = model.generate(
-                input_ids=input_ids,
-                max_new_tokens=max_new_tokens,
-                min_new_tokens=100,
-                do_sample=True,
-                top_p=top_p,
-                temperature=temperature,
-                repetition_penalty=repetition_penalty,
-                eos_token_id=mmtokenizer.eoa,
-                pad_token_id=mmtokenizer.eoa,
-                logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
-                guidance_scale=guidance_scale,
-                use_cache=True
-            )
-        if output_seq[0][-1].item() != mmtokenizer.eoa:
-            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
-            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
-        if i > 1:
-            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
-        else:
-            raw_output = output_seq
-        print(len(raw_output))
-
-    # save raw output and check sanity
-    ids = raw_output[0].cpu().numpy()
-    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
-    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
-    if len(soa_idx) != len(eoa_idx):
-        raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
-
-    vocals = []
-    instrumentals = []
-    range_begin = 1 if use_audio_prompt else 0
-    for i in range(range_begin, len(soa_idx)):
-        codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
-        if codec_ids[0] == 32016:
-            codec_ids = codec_ids[1:]
-        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
-        vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
-        vocals.append(vocals_ids)
-        instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
-        instrumentals.append(instrumentals_ids)
-    vocals = np.concatenate(vocals, axis=1)
-    instrumentals = np.concatenate(instrumentals, axis=1)
-
-    vocal_audio = None
-    instrumental_audio = None
-    mixed_audio = None
-
-    # convert audio tokens to audio
-    def convert_to_audio(codec_result, rescale):
-        with torch.no_grad():
-            decoded_waveform = codec_model.decode(
-                # Corrected line: Convert numpy array to PyTorch tensor with appropriate type
-                torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)
-            )
-        decoded_waveform = decoded_waveform.cpu().squeeze(0)
-
-        limit = 0.99
-        max_val = decoded_waveform.abs().max()
-        scaled_waveform = decoded_waveform * min(limit / max_val, 1) if rescale else decoded_waveform.clamp(-limit, limit)
-        # Corrected line: Convert to numpy array before casting to int16
-        return (16000, (scaled_waveform * 32767).detach().cpu().numpy().astype(np.int16))
-
-    vocal_audio = convert_to_audio(vocals, rescale)
-    instrumental_audio = convert_to_audio(instrumentals, rescale)
-
-    mix_stem = (vocal_audio[1] + instrumental_audio[1]) / 1 # mixing by summing and dividing
-    mixed_audio = (vocal_audio[0], mix_stem) # same sample rate
-
-    return (vocal_audio[0], (mix_stem * 32767).astype(np.int16)), None, None
+    with tempfile.TemporaryDirectory() as output_dir:
+        stage1_output_dir = os.path.join(output_dir, f"stage1")
+        os.makedirs(stage1_output_dir, exist_ok=True)
+
+        class BlockTokenRangeProcessor(LogitsProcessor):
+            def __init__(self, start_id, end_id):
+                self.blocked_token_ids = list(range(start_id, end_id))
+
+            def __call__(self, input_ids, scores):
+                scores[:, self.blocked_token_ids] = -float("inf")
+                return scores
+
+        def load_audio_mono(filepath, sampling_rate=16000):
+            audio, sr = torchaudio.load(filepath)
+            # Convert to mono
+            audio = torch.mean(audio, dim=0, keepdim=True)
+            # Resample if needed
+            if sr != sampling_rate:
+                resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
+                audio = resampler(audio)
+            return audio
+
+        def split_lyrics(lyrics: str):
+            pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+            segments = re.findall(pattern, lyrics, re.DOTALL)
+            structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+            return structured_lyrics
+
+        # Call the function and print the result
+        stage1_output_set = []
+
+        genres = genre_txt.strip()
+        lyrics = split_lyrics(lyrics_txt + "\n")
+        # intruction
+        full_lyrics = "\n".join(lyrics)
+        prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
+        prompt_texts += lyrics
+
+        random_id = uuid.uuid4()
+        output_seq = None
+        # Here is suggested decoding config
+        top_p = 0.93
+        temperature = 1.0
+        repetition_penalty = 1.2
+        # special tokens
+        start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
+        end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+
+        raw_output = None
+
+        # Format text prompt
+        run_n_segments = min(run_n_segments + 1, len(lyrics))
+
+        print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
+
+        for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+            section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+            guidance_scale = 1.5 if i <= 1 else 1.2
+            if i == 0:
+                continue
+            if i == 1:
+                if use_audio_prompt:
+                    audio_prompt = load_audio_mono(audio_prompt_path)
+                    audio_prompt.unsqueeze_(0)
+                    with torch.no_grad():
+                        raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
+                    raw_codes = raw_codes.transpose(0, 1)
+                    raw_codes = raw_codes.cpu().numpy().astype(np.int16)
+                    # Format audio prompt
+                    code_ids = codectool.npy2ids(raw_codes[0])
+                    audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)] # 50 is tps of xcodec
+                    audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
+                        mmtokenizer.eoa]
+                    sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
+                        "[end_of_reference]")
+                    head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
+                else:
+                    head_id = mmtokenizer.tokenize(prompt_texts[0])
+                prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+            else:
+                prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+
+            prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
+            input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+            # Use window slicing in case output sequence exceeds the context of model
+            max_context = 16384 - max_new_tokens - 1
+            if input_ids.shape[-1] > max_context:
+                print(
+                    f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
+                input_ids = input_ids[:, -(max_context):]
+            with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
+                output_seq = model.generate(
+                    input_ids=input_ids,
+                    max_new_tokens=max_new_tokens,
+                    min_new_tokens=100,
+                    do_sample=True,
+                    top_p=top_p,
+                    temperature=temperature,
+                    repetition_penalty=repetition_penalty,
+                    eos_token_id=mmtokenizer.eoa,
+                    pad_token_id=mmtokenizer.eoa,
+                    logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+                    guidance_scale=guidance_scale,
+                    use_cache=True
+                )
+            if output_seq[0][-1].item() != mmtokenizer.eoa:
+                tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+                output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+            if i > 1:
+                raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+            else:
+                raw_output = output_seq
+            print(len(raw_output))
+
+        # save raw output and check sanity
+        ids = raw_output[0].cpu().numpy()
+        soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+        eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
+        if len(soa_idx) != len(eoa_idx):
+            raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
+
+        vocals = []
+        instrumentals = []
+        range_begin = 1 if use_audio_prompt else 0
+        for i in range(range_begin, len(soa_idx)):
+            codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
+            if codec_ids[0] == 32016:
+                codec_ids = codec_ids[1:]
+            codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+            vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
+            vocals.append(vocals_ids)
+            instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
+            instrumentals.append(instrumentals_ids)
+        vocals = np.concatenate(vocals, axis=1)
+        instrumentals = np.concatenate(instrumentals, axis=1)
+
+        vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
+        inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
+        np.save(vocal_save_path, vocals)
+        np.save(inst_save_path, instrumentals)
+        stage1_output_set.append(vocal_save_path)
+        stage1_output_set.append(inst_save_path)
+
+        print("Converting to Audio...")
+
+        # convert audio tokens to audio
+        def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
+            folder_path = os.path.dirname(path)
+            if not os.path.exists(folder_path):
+                os.makedirs(folder_path)
+            limit = 0.99
+            max_val = wav.abs().max()
+            wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
+            torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
+
+        # reconstruct tracks
+        recons_output_dir = os.path.join(output_dir, "recons")
+        recons_mix_dir = os.path.join(recons_output_dir, 'mix')
+        os.makedirs(recons_mix_dir, exist_ok=True)
+        tracks = []
+        for npy in stage1_output_set:
+            codec_result = np.load(npy)
+            decodec_rlt = []
+            with torch.no_grad():
+                decoded_waveform = codec_model.decode(
+                    torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
+                        device))
+            decoded_waveform = decoded_waveform.cpu().squeeze(0)
+            decodec_rlt.append(torch.as_tensor(decoded_waveform))
+            decodec_rlt = torch.cat(decodec_rlt, dim=-1)
+            save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
+            tracks.append(save_path)
+            save_audio(decodec_rlt, save_path, 16000)
+        # mix tracks
+        for inst_path in tracks:
+            try:
+                if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
+                        and 'instrumental' in inst_path:
+                    # find pair
+                    vocal_path = inst_path.replace('instrumental', 'vocal')
+                    if not os.path.exists(vocal_path):
+                        continue
+                    # mix
+                    recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
+                    vocal_stem, sr = sf.read(inst_path)
+                    instrumental_stem, _ = sf.read(vocal_path)
+                    mix_stem = (vocal_stem + instrumental_stem) / 1
+                    return (sr, (mix_stem * 32767).astype(np.int16)), (sr, (vocal_stem * 32767).astype(np.int16)), (sr, (instrumental_stem * 32767).astype(np.int16))
+            except Exception as e:
+                print(e)
+                return None, None, None
+
 
 def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
     # Execute the command
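
Note: both versions constrain sampling with `BlockTokenRangeProcessor`, which masks whole id ranges so only audio-codec tokens can be emitted; observe that `BlockTokenRangeProcessor(32016, 32016)` spans an empty range and therefore blocks nothing. A self-contained sketch of the masking behaviour (toy vocabulary size; shapes follow the transformers convention):

    import torch
    from transformers import LogitsProcessor, LogitsProcessorList

    class BlockTokenRangeProcessor(LogitsProcessor):
        """Set logits of ids in [start_id, end_id) to -inf before sampling."""
        def __init__(self, start_id, end_id):
            self.blocked_token_ids = list(range(start_id, end_id))

        def __call__(self, input_ids, scores):
            if self.blocked_token_ids:
                scores[:, self.blocked_token_ids] = -float("inf")
            return scores

    scores = torch.zeros(1, 8)  # batch of 1, toy vocab of 8
    processors = LogitsProcessorList([BlockTokenRangeProcessor(0, 5)])
    masked = processors(torch.tensor([[1]]), scores)
    assert torch.isinf(masked[0, :5]).all()
    assert not torch.isinf(masked[0, 5:]).any()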
@@ -279,6 +330,7 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=
     finally:
         print("Temporary files deleted.")
 
+
 # Gradio
 with gr.Blocks() as demo:
     with gr.Column():
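
Note on the new decode/mix stage: the stage-1 token stream interleaves the two stems frame by frame, and `rearrange(codec_ids, "(n b) -> b n", b=2)` splits them back into a vocal row and an instrumental row. A toy sketch of that de-interleaving:

    import numpy as np
    from einops import rearrange

    codec_ids = np.array([10, 20, 11, 21, 12, 22])  # [v0, i0, v1, i1, v2, i2]
    stems = rearrange(codec_ids, "(n b) -> b n", b=2)
    print(stems[0])  # vocal ids:        [10 11 12]
    print(stems[1])  # instrumental ids: [20 21 22]

One thing worth flagging in the new mixing loop: `vocal_stem, sr = sf.read(inst_path)` reads the instrumental file into `vocal_stem` and vice versa. The summed mix is unaffected, but the second and third returned stems appear swapped.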
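
Note: the main structural change of this commit is running the whole pipeline inside `tempfile.TemporaryDirectory()`, which guarantees cleanup of the stage-1 `.npy` files and the reconstructed tracks. Anything the function returns must therefore be materialized in memory before the block exits, as the new code does via `sf.read`. A minimal illustration of that constraint:

    import os
    import tempfile

    import numpy as np

    with tempfile.TemporaryDirectory() as output_dir:
        path = os.path.join(output_dir, "stems.npy")
        np.save(path, np.zeros(4))
        stems = np.load(path)        # safe: the file still exists here
    assert not os.path.exists(path)  # directory and file are gone on exit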