KingNish committed
Commit 1bb807a · verified · 1 Parent(s): fdbe6f4

Update app.py

Files changed (1):
  1. app.py +55 -28
app.py CHANGED
@@ -261,34 +261,61 @@ def generate_music(
     vocals = np.concatenate(vocals, axis=1)
     instrumentals = np.concatenate(instrumentals, axis=1)
 
-    # convert audio tokens to audio
-    with torch.no_grad():
-        decoded_vocals = codec_model.decode(
-            torch.as_tensor(vocals.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
-                device))
-        decoded_instrumentals = codec_model.decode(
-            torch.as_tensor(instrumentals.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
-                device))
-
-    decoded_vocals = decoded_vocals.cpu().squeeze(0)
-    decoded_instrumentals = decoded_instrumentals.cpu().squeeze(0)
-    mixed_audio = (decoded_vocals + decoded_instrumentals) / 2
-
-    # Scale to be between -1 and 1 and convert to int16
-    limit = 0.99
-    max_val = mixed_audio.abs().max()
-    mixed_audio = mixed_audio * min(limit / max_val, 1) if rescale else mixed_audio.clamp(-limit, limit)
-    mixed_audio = (mixed_audio * 32767).to(torch.int16).numpy()
-
-    max_val = decoded_vocals.abs().max()
-    decoded_vocals = decoded_vocals * min(limit / max_val, 1) if rescale else decoded_vocals.clamp(-limit, limit)
-    decoded_vocals = (decoded_vocals * 32767).to(torch.int16).numpy()
-
-    max_val = decoded_instrumentals.abs().max()
-    decoded_instrumentals = decoded_instrumentals * min(limit / max_val, 1) if rescale else decoded_instrumentals.clamp(-limit, limit)
-    decoded_instrumentals = (decoded_instrumentals * 32767).to(torch.int16).numpy()
-
-    return (16000, mixed_audio), (16000, decoded_vocals), (16000, decoded_instrumentals)
+    vocal_save_path = os.path.join(stage1_output_dir, f"vocal_{random_id}".replace('.', '@') + '.npy')
+    inst_save_path = os.path.join(stage1_output_dir, f"instrumental_{random_id}".replace('.', '@') + '.npy')
+    np.save(vocal_save_path, vocals)
+    np.save(inst_save_path, instrumentals)
+    stage1_output_set.append(vocal_save_path)
+    stage1_output_set.append(inst_save_path)
+
+    print("Converting to Audio...")
+
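+    # helper reused for every stem below: peak-limit the waveform, then write 16-bit PCM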
+    # convert audio tokens to audio
+    def save_audio(wav: torch.Tensor, path, sample_rate: int, rescale: bool = False):
+        folder_path = os.path.dirname(path)
+        if not os.path.exists(folder_path):
+            os.makedirs(folder_path)
+        limit = 0.99
+        max_val = wav.abs().max()
+        wav = wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)
+        torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)
+
+    # reconstruct tracks
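+    # each stage-1 .npy holds codec token ids; decode them back to waveforms saved as .mp3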
+    recons_output_dir = os.path.join(output_dir, "recons")
+    recons_mix_dir = os.path.join(recons_output_dir, 'mix')
+    os.makedirs(recons_mix_dir, exist_ok=True)
+    tracks = []
+    for npy in stage1_output_set:
+        codec_result = np.load(npy)
+        decodec_rlt = []
+        with torch.no_grad():
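+            # unsqueeze + permute reshape the saved (codebooks, frames) token array
+            # into the (codebooks, batch=1, frames) layout the codec decoder expects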
+            decoded_waveform = codec_model.decode(
+                torch.as_tensor(codec_result.astype(np.int16), dtype=torch.long).unsqueeze(0).permute(1, 0, 2).to(
+                    device))
+        decoded_waveform = decoded_waveform.cpu().squeeze(0)
+        decodec_rlt.append(torch.as_tensor(decoded_waveform))
+        decodec_rlt = torch.cat(decodec_rlt, dim=-1)
+        save_path = os.path.join(recons_output_dir, os.path.splitext(os.path.basename(npy))[0] + ".mp3")
+        tracks.append(save_path)
+        save_audio(decodec_rlt, save_path, 16000)
+    # mix tracks
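+    # stems pair by filename: each "instrumental_*" file has a "vocal_*" counterpart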
+    for inst_path in tracks:
+        try:
+            if (inst_path.endswith('.wav') or inst_path.endswith('.mp3')) \
+                    and 'instrumental' in inst_path:
+                # find pair
+                vocal_path = inst_path.replace('instrumental', 'vocal')
+                if not os.path.exists(vocal_path):
+                    continue
+                # mix
+                recons_mix = os.path.join(recons_mix_dir, os.path.basename(inst_path).replace('instrumental', 'mixed'))
+                vocal_stem, sr = sf.read(vocal_path)
+                instrumental_stem, _ = sf.read(inst_path)
+                mix_stem = (vocal_stem + instrumental_stem) / 1
+                return (16000, mix_stem), (16000, vocal_stem), (16000, instrumental_stem)
+        except Exception as e:
+            print(e)
+    return None, None, None
 
 
  def infer(genre_txt_content, lyrics_txt_content, num_segments=2, max_new_tokens=15):
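
The peak handling in the new save_audio helper keeps the behavior of the deleted inline code: with rescale=True the whole waveform is scaled down only when its peak exceeds 0.99, otherwise out-of-range samples are hard-clamped to ±0.99. A minimal standalone sketch of that expression (peak_limit is a hypothetical name and the sample values are made up):

    import torch

    def peak_limit(wav: torch.Tensor, rescale: bool = False, limit: float = 0.99) -> torch.Tensor:
        # Same expression as in save_audio: scale the whole signal so its peak
        # fits under `limit`, or clip anything outside [-limit, limit].
        max_val = wav.abs().max()
        return wav * min(limit / max_val, 1) if rescale else wav.clamp(-limit, limit)

    x = torch.tensor([0.5, -1.5, 2.0])   # made-up samples with a 2.0 peak
    print(peak_limit(x, rescale=True))   # tensor([ 0.2475, -0.7425,  0.9900]) - dynamics preserved
    print(peak_limit(x, rescale=False))  # tensor([ 0.5000, -0.9900,  0.9900]) - peaks clipped

Rescaling keeps the relative dynamics of the track intact, while clamping (the default) only distorts samples that are already out of range. The (16000, ndarray) tuples returned by generate_music match the (sample_rate, data) format that Gradio's gr.Audio component accepts directly, assuming this app.py wires them to Audio outputs.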