KingNish committed
Commit c06dce9 · verified · 1 Parent(s): 827d1b5

Update app.py

Files changed (1)
  1. app.py +23 -50
app.py CHANGED
@@ -258,64 +258,37 @@ def generate_music(
         vocals.append(vocals_ids)
         instrumentals_ids = codectool.ids2npy(rearrange(codec_ids, "(n b) -> b n", b=2)[1])
         instrumentals.append(instrumentals_ids)
+
     vocals = np.concatenate(vocals, axis=1)
     instrumentals = np.concatenate(instrumentals, axis=1)
 
-    vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
-    inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
-    np.save(vocal_save_path, vocals)
-    np.save(inst_save_path, instrumentals)
-    stage1_output_set.append(vocal_save_path)
-    stage1_output_set.append(inst_save_path)
-
     print("Converting to Audio...")
 
-    # convert audio tokens to audio
-    def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
-        folder_path = os.path.dirname(path)
-        if not os.path.exists(folder_path):
-            os.makedirs(folder_path)
-        limit = 0.99
-        max_val = wav.abs().max()
-        wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
-        torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
+    # batching audio
+    def decode_audio_batch(codec_result, batch_size=4):
+        decoded_waveforms = []
+        with torch.no_grad():
+            for i in range(0, codec_result.shape[-1], batch_size):
+                batch = codec_result[:, i:i+batch_size]
+                batch_tensor = torch.as_tensor(batch.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(device)
+                decoded_waveform = codec_model.decode(batch_tensor)
+                decoded_waveforms.append(decoded_waveform)
+        decoded_waveforms = torch.cat(decoded_waveforms, dim=-1).squeeze(0).cpu()
+        return decoded_waveforms
+
 
     # reconstruct tracks
-    recons_output_dir = os.path.join(output_dir, "recons")
-    recons_mix_dir = os.path.join(recons_output_dir, 'mix')
-    os.makedirs(recons_mix_dir, exist_ok=True)
-    tracks = []
-    for npy in stage1_output_set:
-        codec_result = np.load(npy)
-        decodec_rlt = []
-        with torch.no_grad():
-            decoded_waveform = codec_model.decode(
-                torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
-                    device))
-        decoded_waveform = decoded_waveform.cpu().squeeze(0)
-        decodec_rlt.append(torch.as_tensor(decoded_waveform))
-        decodec_rlt = torch.cat(decodec_rlt, dim=-1)
-        save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
-        tracks.append(save_path)
-        save_audio(decodec_rlt, save_path, 16000)
+    vocal_waveform = decode_audio_batch(vocals)
+    instrumental_waveform = decode_audio_batch(instrumentals)
+
     # mix tracks
-    for inst_path in tracks:
-        try:
-            if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
-                    and 'instrumental' in inst_path:
-                # find pair
-                vocal_path = inst_path.replace('instrumental', 'vocal')
-                if not os.path.exists(vocal_path):
-                    continue
-                # mix
-                recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
-                vocal_stem, sr = sf.read(inst_path)
-                instrumental_stem, _ = sf.read(vocal_path)
-                mix_stem = (vocal_stem + instrumental_stem) / 1
-                return (sr, (mix_stem * 32767).astype(np.int16)), (sr, (vocal_stem * 32767).astype(np.int16)), (sr, (instrumental_stem * 32767).astype(np.int16))
-        except Exception as e:
-            print(e)
-            return None, None, None
+    try:
+        mix_waveform = (vocal_waveform + instrumental_waveform) / 1
+        return (16000, (mix_waveform * 32767).numpy().astype(np.int16)), (16000, (vocal_waveform * 32767).numpy().astype(np.int16)), (16000, (instrumental_waveform * 32767).numpy().astype(np.int16))
+    except Exception as e:
+        print(e)
+        return None, None, None
+
 
 
 def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
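
Note on the new decode path: decode_audio_batch slices the token matrix along its frame axis and decodes batch_size frames per call, so only a small chunk occupies GPU memory at a time, instead of the whole track as in the removed per-file loop. Below is a minimal standalone sketch of the same pattern with a stub decoder in place of codec_model; the decode_in_batches name, the stub, and its 320-samples-per-frame ratio are assumptions for illustration, not part of this commit.

import numpy as np
import torch

def decode_in_batches(token_ids: np.ndarray, decode_fn, batch_size: int = 4) -> torch.Tensor:
    # token_ids: [codebooks, frames]; decode chunk-by-chunk along the frame
    # axis so at most batch_size frames are in flight at once
    chunks = []
    with torch.no_grad():
        for i in range(0, token_ids.shape[-1], batch_size):
            sl = token_ids[:, i:i + batch_size]
            t = torch.as_tensor(sl, dtype=torch.long).unsqueeze(0).permute(1, 0, 2)
            chunks.append(decode_fn(t))
    # stitch the per-chunk waveforms back together on the time axis
    return torch.cat(chunks, dim=-1).squeeze(0).cpu()

# stub decoder mapping each token frame to 320 zero samples (made-up ratio)
stub_decode = lambda ids: torch.zeros(ids.shape[0], 1, ids.shape[-1] * 320)

tokens = np.zeros((1, 10), dtype=np.int16)            # 1 codebook, 10 frames
print(decode_in_batches(tokens, stub_decode).shape)   # torch.Size([1, 3200])

One caveat worth flagging in review: convolutional codec decoders carry context across neighbouring frames, so decoding in 4-frame slices with no overlap may introduce audible seams at chunk boundaries relative to the single decode call it replaces.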
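
The three return values now use the (sample_rate, int16 ndarray) pairs that Gradio's Audio component accepts directly, with the float waveforms scaled by 32767 into 16-bit PCM range. A sketch of that conversion with an explicit clamp follows; the to_gradio_audio helper is illustrative, not in the commit. The committed code sums the two stems and casts without clamping, so a vocal-plus-instrumental peak above 1.0 can overflow the int16 cast and distort, which clamping first avoids.

import numpy as np
import torch

def to_gradio_audio(waveform: torch.Tensor, sample_rate: int = 16000):
    # float waveform in [-1, 1] -> (sr, 16-bit PCM) pair for a gr.Audio output
    pcm = (waveform.clamp(-1.0, 1.0) * 32767).to(torch.int16).numpy()
    return sample_rate, pcm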