KingNish committed on
Commit b1201e2 · verified · 1 Parent(s): 9897c6e

Update app.py

Files changed (1)
  1. app.py +157 -233
app.py CHANGED
@@ -46,8 +46,6 @@ except FileNotFoundError:
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer'))
 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'xcodec_mini_infer', 'descriptaudiocodec'))
 
-# don't change above code
-
 import argparse
 import numpy as np
 import json
@@ -66,8 +64,8 @@ import time
 import copy
 from collections import Counter
 from models.soundstream_hubert_new import SoundStream
-#from vocoder import build_codec_model, process_audio # removed vocoder
-#from post_process_audio import replace_low_freq_with_energy_matched # removed post process
+
+# don't change above code
 
 device = "cuda:0"
 
@@ -81,9 +79,6 @@ model.eval()
 
 basic_model_config = './xcodec_mini_infer/final_ckpt/config.yaml'
 resume_path = './xcodec_mini_infer/final_ckpt/ckpt_00360000.pth'
-#config_path = './xcodec_mini_infer/decoders/config.yaml' # removed vocoder
-#vocal_decoder_path = './xcodec_mini_infer/decoders/decoder_131000.pth' # removed vocoder
-#inst_decoder_path = './xcodec_mini_infer/decoders/decoder_151000.pth' # removed vocoder
 
 mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
 
@@ -93,18 +88,8 @@ model_config = OmegaConf.load(basic_model_config)
 codec_model = eval(model_config.generator.name)(**model_config.generator.config).to(device)
 parameter_dict = torch.load(resume_path, map_location='cpu')
 codec_model.load_state_dict(parameter_dict['codec_model'])
-# codec_model = torch.compile(codec_model)
 codec_model.eval()
 
-# Preload and compile vocoders # removed vocoder
-#vocal_decoder, inst_decoder = build_codec_model(config_path, vocal_decoder_path, inst_decoder_path)
-#vocal_decoder.to(device)
-#inst_decoder.to(device)
-#vocal_decoder = torch.compile(vocal_decoder)
-#inst_decoder = torch.compile(inst_decoder)
-#vocal_decoder.eval()
-#inst_decoder.eval()
-
 @spaces.GPU(duration=120)
 def generate_music(
     max_new_tokens=5,
@@ -117,234 +102,174 @@ def generate_music(
     prompt_end_time=30.0,
     cuda_idx=0,
     rescale=False,
-    batch_size=1
 ):
     if use_audio_prompt and not audio_prompt_path:
         raise FileNotFoundError("Please offer audio prompt filepath using '--audio_prompt_path', when you enable 'use_audio_prompt'!")
     cuda_idx = cuda_idx
     max_new_tokens = max_new_tokens * 100
 
-    with tempfile.TemporaryDirectory() as output_dir:
-        stage1_output_dir = os.path.join(output_dir, f"stage1")
-        os.makedirs(stage1_output_dir, exist_ok=True)
-
-        class BlockTokenRangeProcessor(LogitsProcessor):
-            def __init__(self, start_id, end_id):
-                self.blocked_token_ids = list(range(start_id, end_id))
-
-            def __call__(self, input_ids, scores):
-                scores[:, self.blocked_token_ids] = -float("inf")
-                return scores
-
-        def load_audio_mono(filepath, sampling_rate=16000):
-            audio, sr = torchaudio.load(filepath)
-            # Convert to mono
-            audio = torch.mean(audio, dim=0, keepdim=True)
-            # Resample if needed
-            if sr != sampling_rate:
-                resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
-                audio = resampler(audio)
-            return audio
-
-        def split_lyrics(lyrics: str):
-            pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
-            segments = re.findall(pattern, lyrics, re.DOTALL)
-            structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
-            return structured_lyrics
-
-        # Call the function and print the result
-        stage1_output_set = []
-        vocals_list = []
-        instrumentals_list = []
-        genres = genre_txt.strip()
-        lyrics = split_lyrics(lyrics_txt + "\n")
-        # instruction
-        full_lyrics = "\n".join(lyrics)
-        prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
-        prompt_texts += lyrics
-
-        # special tokens
-        start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
-        end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
-
-        # Format text prompt
-        run_n_segments = min(run_n_segments + 1, len(lyrics))
-
-        batches = [prompt_texts[i:i + batch_size] for i in range(0, run_n_segments, batch_size)]
-
-        print(batches)
-
-        for batch_idx, batch in enumerate(tqdm(batches)):
-            random_ids = [uuid.uuid4() for _ in range(len(batch))]
-            raw_outputs = [None] * len(batch)
-
-            # Here is suggested decoding config
-            top_p = 0.93
-            temperature = 1.0
-            repetition_penalty = 1.2
-
-            for i, p in enumerate(batch):
-                section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
-                # Adjust guidance scale for the first two sections to be lower
-                guidance_scale = 1.5 if (batch_idx * batch_size + i) <= 1 else 1.2
-
-                if (batch_idx * batch_size + i) == 0:
-                    continue  # Skip the first instruction
-
-                if (batch_idx * batch_size + i) == 1:
-                    if use_audio_prompt:
-                        audio_prompt = load_audio_mono(audio_prompt_path)
-                        audio_prompt.unsqueeze_(0)
-                        with torch.no_grad():
-                            raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
-                        raw_codes = raw_codes.transpose(0, 1)
-                        raw_codes = raw_codes.cpu().numpy().astype(np.int16)
-                        # Format audio prompt
-                        code_ids = codectool.npy2ids(raw_codes[0])
-                        audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
-                        audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
-                            mmtokenizer.eoa]
-                        sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
-                            "[end_of_reference]")
-                        head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
-                    else:
-                        head_id = mmtokenizer.tokenize(prompt_texts[0])
-                    prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [
-                        mmtokenizer.soa] + codectool.sep_ids
-                else:
-                    prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [
-                        mmtokenizer.soa] + codectool.sep_ids
-
-                prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
-                input_ids = torch.cat([raw_outputs[i], prompt_ids], dim=1) if (batch_idx * batch_size + i) > 1 else prompt_ids
-
-                # Use window slicing in case output sequence exceeds the context of model
-                max_context = 16384 - max_new_tokens - 1
-                if input_ids.shape[-1] > max_context:
-                    print(
-                        f'Section {(batch_idx * batch_size + i)}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
-                    input_ids = input_ids[:, -(max_context):]
-
-                with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
-                    output_seq = model.generate(
-                        input_ids=input_ids,
-                        max_new_tokens=max_new_tokens,
-                        min_new_tokens=100,
-                        do_sample=True,
-                        top_p=top_p,
-                        temperature=temperature,
-                        repetition_penalty=repetition_penalty,
-                        eos_token_id=mmtokenizer.eoa,
-                        pad_token_id=mmtokenizer.eoa,
-                        logits_processor=LogitsProcessorList(
-                            [BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
-                        guidance_scale=guidance_scale,
-                        use_cache=True
-                    )
-                if output_seq[0][-1].item() != mmtokenizer.eoa:
-                    tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
-                    output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
-
-                if (batch_idx * batch_size + i) > 1:
-                    raw_outputs[i] = torch.cat([raw_outputs[i], prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
-                else:
-                    raw_outputs[i] = output_seq
-
-            for i, raw_output in enumerate(raw_outputs):
-                # save raw output and check sanity
-                ids = raw_output[0].cpu().numpy()
-                soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
-                eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
-                if len(soa_idx) != len(eoa_idx):
-                    raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
-
-                range_begin = 1 if use_audio_prompt and batch_idx == 0 else 0
-
-                vocals_batch = []
-                instrumentals_batch = []
-                for j in range(range_begin, len(soa_idx)):
-                    codec_ids = ids[soa_idx[j] + 1:eoa_idx[j]]
-                    if codec_ids[0] == 32016:
-                        codec_ids = codec_ids[1:]
-                    codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
-                    vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
-                    vocals_batch.append(vocals_ids)
-                    instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
-                    instrumentals_batch.append(instrumentals_ids)
-
-                vocals_batch = np.concatenate(vocals_batch, axis=1)
-                instrumentals_batch = np.concatenate(instrumentals_batch, axis=1)
-
-                vocals_list.append(vocals_batch)
-                instrumentals_list.append(instrumentals_batch)
-
-        vocals = np.concatenate(vocals_list, axis=1)
-        instrumentals = np.concatenate(instrumentals_list, axis=1)
-
-        vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_ids[0]}".replace('.', '@') + '.npy')
-        inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_ids[0]}".replace('.', '@') + '.npy')
-        np.save(vocal_save_path, vocals)
-        np.save(inst_save_path, instrumentals)
-        stage1_output_set.append(vocal_save_path)
-        stage1_output_set.append(inst_save_path)
-
-        print("Converting to Audio...")
-
-        # convert audio tokens to audio
-        def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
-            folder_path = os.path.dirname(path)
-            if not os.path.exists(folder_path):
-                os.makedirs(folder_path)
-            limit = 0.99
-            max_val = wav.abs().max()
-            wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
-            torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
-
-        # reconstruct tracks
-        recons_output_dir = os.path.join(output_dir, "recons")
-        recons_mix_dir = os.path.join(recons_output_dir, 'mix')
-        os.makedirs(recons_mix_dir, exist_ok=True)
-        tracks = []
-
-        vocal_stem = None
-        instrumental_stem = None
-        sr = None
-
-        for npy in stage1_output_set:
-            codec_result = np.load(npy)
-            decodec_rlt = []
-            with torch.no_grad():
-                decoded_waveform = codec_model.decode(
-                    torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
-                        device))
-            decoded_waveform = decoded_waveform.cpu().squeeze(0)
-            decodec_rlt.append(torch.as_tensor(decoded_waveform))
-            decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-
-            #save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
-            #tracks.append(save_path)
-            #save_audio(decodec_rlt, save_path, 16000)
-            if 'vocal' in npy:
-                vocal_stem = decodec_rlt.numpy()
-            elif 'instrumental' in npy:
-                instrumental_stem = decodec_rlt.numpy()
-            sr = 16000
+    class BlockTokenRangeProcessor(LogitsProcessor):
+        def __init__(self, start_id, end_id):
+            self.blocked_token_ids = list(range(start_id, end_id))
+
+        def __call__(self, input_ids, scores):
+            scores[:, self.blocked_token_ids] = -float("inf")
+            return scores
+
+    def load_audio_mono(filepath, sampling_rate=16000):
+        audio, sr = torchaudio.load(filepath)
+        # Convert to mono
+        audio = torch.mean(audio, dim=0, keepdim=True)
+        # Resample if needed
+        if sr != sampling_rate:
+            resampler = Resample(orig_freq=sr, new_freq=sampling_rate)
+            audio = resampler(audio)
+        return audio
+
+    def split_lyrics(lyrics: str):
+        pattern = r"\[(\w+)\](.*?)\n(?=\[|\Z)"
+        segments = re.findall(pattern, lyrics, re.DOTALL)
+        structured_lyrics = [f"[{seg[0]}]\n{seg[1].strip()}\n\n" for seg in segments]
+        return structured_lyrics
+
+    # Call the function and print the result
+    stage1_output_set = []
+
+    genres = genre_txt.strip()
+    lyrics = split_lyrics(lyrics_txt + "\n")
+    # instruction
+    full_lyrics = "\n".join(lyrics)
+    prompt_texts = [f"Generate music from the given lyrics segment by segment.\n[Genre] {genres}\n{full_lyrics}"]
+    prompt_texts += lyrics
+
+    random_id = uuid.uuid4()
+    output_seq = None
+    # Here is suggested decoding config
+    top_p = 0.93
+    temperature = 1.0
+    repetition_penalty = 1.2
+    # special tokens
+    start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
+    end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+
+    raw_output = None
+
+    # Format text prompt
+    run_n_segments = min(run_n_segments + 1, len(lyrics))
+
+    print(list(enumerate(tqdm(prompt_texts[:run_n_segments]))))
+
+    for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
+        section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
+        guidance_scale = 1.5 if i <= 1 else 1.2
+        if i == 0:
+            continue
+        if i == 1:
+            if use_audio_prompt:
+                audio_prompt = load_audio_mono(audio_prompt_path)
+                audio_prompt.unsqueeze_(0)
+                with torch.no_grad():
+                    raw_codes = codec_model.encode(audio_prompt.to(device), target_bw=0.5)
+                raw_codes = raw_codes.transpose(0, 1)
+                raw_codes = raw_codes.cpu().numpy().astype(np.int16)
+                # Format audio prompt
+                code_ids = codectool.npy2ids(raw_codes[0])
+                audio_prompt_codec = code_ids[int(prompt_start_time * 50): int(prompt_end_time * 50)]  # 50 is tps of xcodec
+                audio_prompt_codec_ids = [mmtokenizer.soa] + codectool.sep_ids + audio_prompt_codec + [
+                    mmtokenizer.eoa]
+                sentence_ids = mmtokenizer.tokenize("[start_of_reference]") + audio_prompt_codec_ids + mmtokenizer.tokenize(
+                    "[end_of_reference]")
+                head_id = mmtokenizer.tokenize(prompt_texts[0]) + sentence_ids
+            else:
+                head_id = mmtokenizer.tokenize(prompt_texts[0])
+            prompt_ids = head_id + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+        else:
+            prompt_ids = end_of_segment + start_of_segment + mmtokenizer.tokenize(section_text) + [mmtokenizer.soa] + codectool.sep_ids
+
+        prompt_ids = torch.as_tensor(prompt_ids).unsqueeze(0).to(device)
+        input_ids = torch.cat([raw_output, prompt_ids], dim=1) if i > 1 else prompt_ids
+        # Use window slicing in case output sequence exceeds the context of model
+        max_context = 16384 - max_new_tokens - 1
+        if input_ids.shape[-1] > max_context:
+            print(
+                f'Section {i}: output length {input_ids.shape[-1]} exceeding context length {max_context}, now using the last {max_context} tokens.')
+            input_ids = input_ids[:, -(max_context):]
+        with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.float16):
+            output_seq = model.generate(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=100,
+                do_sample=True,
+                top_p=top_p,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                eos_token_id=mmtokenizer.eoa,
+                pad_token_id=mmtokenizer.eoa,
+                logits_processor=LogitsProcessorList([BlockTokenRangeProcessor(0, 32002), BlockTokenRangeProcessor(32016, 32016)]),
+                guidance_scale=guidance_scale,
+                use_cache=True
+            )
+        if output_seq[0][-1].item() != mmtokenizer.eoa:
+            tensor_eoa = torch.as_tensor([[mmtokenizer.eoa]]).to(model.device)
+            output_seq = torch.cat((output_seq, tensor_eoa), dim=1)
+        if i > 1:
+            raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
+        else:
+            raw_output = output_seq
+        print(len(raw_output))
+
+    # save raw output and check sanity
+    ids = raw_output[0].cpu().numpy()
+    soa_idx = np.where(ids == mmtokenizer.soa)[0].tolist()
+    eoa_idx = np.where(ids == mmtokenizer.eoa)[0].tolist()
+    if len(soa_idx) != len(eoa_idx):
+        raise ValueError(f'invalid pairs of soa and eoa, Num of soa: {len(soa_idx)}, Num of eoa: {len(eoa_idx)}')
+
+    vocals = []
+    instrumentals = []
+    range_begin = 1 if use_audio_prompt else 0
+    for i in range(range_begin, len(soa_idx)):
+        codec_ids = ids[soa_idx[i] + 1:eoa_idx[i]]
+        if codec_ids[0] == 32016:
+            codec_ids = codec_ids[1:]
+        codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]
+        vocals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[0])
+        vocals.append(vocals_ids)
+        instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
+        instrumentals.append(instrumentals_ids)
+    vocals = np.concatenate(vocals, axis=1)
+    instrumentals = np.concatenate(instrumentals, axis=1)
+
+    vocal_audio = None
+    instrumental_audio = None
+    mixed_audio = None
+
+    # convert audio tokens to audio
+    def convert_to_audio(codec_result, rescale):
+        with torch.no_grad():
+            decoded_waveform = codec_model.decode(
+                torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
+                    device))
+        decoded_waveform = decoded_waveform.cpu().squeeze(0)
 
-        # mix tracks
-        if vocal_stem is not None and instrumental_stem is not None:
-            mix_stem = (vocal_stem + instrumental_stem) / 1
-            return (sr, (mix_stem * 32767).astype(np.int16)), (sr, (vocal_stem * 32767).astype(np.int16)), (
-                sr, (instrumental_stem * 32767).astype(np.int16))
+        limit = 0.99
+        max_val = decoded_waveform.abs().max()
+        scaled_waveform = decoded_waveform * min(limit / max_val, 1) if rescale else decoded_waveform.clamp(-limit, limit)
+        return (16000, (scaled_waveform.numpy() * 32767).astype(np.int16))  # .numpy() first: torch tensors have no .astype
 
-        else:
-            print("Missing Vocal or Instrumental Stem")
-            return None, None, None
+    vocal_audio = convert_to_audio(vocals, rescale)
+    instrumental_audio = convert_to_audio(instrumentals, rescale)
+
+    mix_stem = (vocal_audio[1] + instrumental_audio[1]) / 1  # mix by summing the stems
+    mixed_audio = (vocal_audio[0], mix_stem)  # same sample rate
+
+    return mixed_audio, vocal_audio, instrumental_audio
 
 def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
     # Execute the command
     try:
         mixed_audio_data, vocal_audio_data, instrumental_audio_data = generate_music(genre_txt=genre_txt_content, lyrics_txt=lyrics_txt_content, run_n_segments=num_segments,
-                                                                                     cuda_idx=0, max_new_tokens=max_new_tokens, batch_size=4)
+                                                                                     cuda_idx=0, max_new_tokens=max_new_tokens)
         return mixed_audio_data, vocal_audio_data, instrumental_audio_data
     except Exception as e:
         gr.Warning("An Error Occurred: " + str(e))
@@ -430,7 +355,6 @@ Living out my dreams with this mic and a deal
     inputs=[genre_txt, lyrics_txt],
    outputs=[music_out, vocal_out, instrumental_out],
     cache_examples=True,
-    cache_mode="eager",
     fn=infer
 )
 
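Both the old and new stage-1 loops rely on BlockTokenRangeProcessor to keep text-token ids from being sampled into the audio stream. Below is a minimal, self-contained sketch of that masking technique; the 16-token vocabulary and the blocked range 0..10 are toy values, not YuE's real vocabulary split:

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class BlockTokenRangeProcessor(LogitsProcessor):
    # Set the scores of a contiguous id range to -inf so sampling can never pick them.
    def __init__(self, start_id, end_id):
        self.blocked_token_ids = list(range(start_id, end_id))

    def __call__(self, input_ids, scores):
        scores[:, self.blocked_token_ids] = -float("inf")
        return scores

processors = LogitsProcessorList([BlockTokenRangeProcessor(0, 10)])
scores = processors(torch.tensor([[0]]), torch.zeros(1, 16))
assert torch.isinf(scores[0, :10]).all() and not torch.isinf(scores[0, 10:]).any()

Note that the app's second processor, BlockTokenRangeProcessor(32016, 32016), builds an empty range and therefore blocks nothing; blocking id 32016 itself would require BlockTokenRangeProcessor(32016, 32017).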
 
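Because each segment is generated with the whole running transcript as its prefix, input_ids grows on every iteration; the max_context = 16384 - max_new_tokens - 1 guard in both versions left-truncates the prompt so the model's 16K context is never exceeded. A toy sketch of that sliding window (the sizes are illustrative only):

import torch

max_new_tokens = 3000
max_context = 16384 - max_new_tokens - 1  # reserve room for the tokens about to be generated

input_ids = torch.zeros(1, 20000, dtype=torch.long)  # pretend the transcript outgrew the context
if input_ids.shape[-1] > max_context:
    input_ids = input_ids[:, -max_context:]  # keep only the most recent tokens
print(input_ids.shape)  # torch.Size([1, 13383])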
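
The stage-1 model emits vocal and instrumental codec ids interleaved token by token, which is why both versions trim codec_ids to an even length before calling rearrange(codec_ids, "(n b) -> b n", b=2). A toy sketch of the de-interleave with made-up ids:

import numpy as np
from einops import rearrange

codec_ids = np.array([10, 20, 11, 21, 12, 22, 13])     # interleaved [v0, i0, v1, i1, ...]
codec_ids = codec_ids[:2 * (codec_ids.shape[0] // 2)]  # drop any odd trailing token

tracks = rearrange(codec_ids, "(n b) -> b n", b=2)     # row 0: vocals, row 1: instrumentals
print(tracks[0])  # [10 11 12]
print(tracks[1])  # [20 21 22]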
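
convert_to_audio returns (sample_rate, int16_array) tuples, the numpy format Gradio audio components accept, which lets the new code skip writing MP3 files to disk. A standalone sketch of the peak-limit-then-quantize step, assuming a float CPU waveform roughly in [-1, 1]; the tensor must become a NumPy array before .astype, hence the .numpy() call in the listing above. to_int16_audio is a hypothetical helper name, not a function in app.py:

import numpy as np
import torch

def to_int16_audio(waveform: torch.Tensor, sample_rate: int = 16000, rescale: bool = False):
    # Peak-limit to avoid int16 overflow, then quantize for a Gradio (sr, ndarray) tuple.
    limit = 0.99
    max_val = waveform.abs().max()
    wav = waveform * min(limit / max_val, 1) if rescale else waveform.clamp(-limit, limit)
    return (sample_rate, (wav.numpy() * 32767).astype(np.int16))

sr, pcm = to_int16_audio(torch.sin(torch.linspace(0, 100, 16000)))
print(sr, pcm.dtype, pcm.shape)  # 16000 int16 (16000,)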