# Source: ALeLacheur, commit 5a9b731 ("uploading audio diffusion attacks")
"""
test_audioldm.py
Desc: Using an example clip from the Free Music Archive (FMA), test AudioLDM2's mel-spectrogram round trip through its vocoder and through its VAE.
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy
import torch
import torchaudio
import os
import ast
import soundfile as sf
# Old Code for Importing AudioLDM
# from audioldm.pipeline import build_model
# HF Code for AudioLDM2
# from diffusers import AudioLDM2Pipeline
from audioldm.audio import wav_to_fbank, TacotronSTFT
try:
from audioldm2 import build_model
except:
from audioldm2 import build_model
# Change the below to desired GPU if using GPU
# os.environ['CUDA_VISIBLE_DEVICES'] = '4'
def peak_normalize(waveform, peak=0.8):
    """Scale ``waveform`` so that its largest absolute sample equals ``peak``.

    Args:
        waveform: Array-like of samples (any shape); converted with ``np.asarray``.
            Non-floating input is promoted to float64 before scaling.
        peak: Target peak amplitude after scaling. Defaults to 0.8, the
            headroom this script uses for every WAV file it writes.

    Returns:
        A floating-point NumPy array with the same shape as ``waveform``.
        Empty or all-zero (silent) input is returned as an unscaled copy
        instead of raising (empty) or filling the output with NaNs (silent).
    """
    waveform = np.asarray(waveform)
    if not np.issubdtype(waveform.dtype, np.floating):
        waveform = waveform.astype(np.float64)
    if waveform.size == 0:
        return waveform.copy()
    max_abs = np.max(np.abs(waveform))
    if max_abs == 0:
        # Silent input: nothing to scale, and dividing by zero would give NaNs.
        return waveform.copy()
    return (waveform / max_abs) * peak


if __name__ == '__main__':
    data_loc = '/data/robbizorg/music_datasets/fma/'
    out_dir = './assets/audios'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load AudioLDM2; its VAE (first_stage_model) and vocoder are exercised below.
    model = build_model()

    example_audio_loc = os.path.join(data_loc, 'data/fma_large/000/000420.mp3')
    audio, sr = torchaudio.load(example_audio_loc)
    mono_audio = torch.mean(audio, dim=0)  # Convert to mono by averaging channels
    print("Audio Size and Sampling Rate")
    print(audio.shape, sr)
    print(f"Audio is {np.round(audio.shape[-1]/sr/60, 2)} minutes long")

    # Mel-spectrogram front-end config: 16 kHz audio, 1024-point STFT with
    # hop 160, and 64 mel bins spanning 0-8000 Hz.
    default_mel_config = {
        "preprocessing": {
            "audio": {
                "sampling_rate": 16000,
                "max_wav_value": 32768,
                "duration": 10.24,
            },
            "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
            "mel": {"n_mel_channels": 64, "mel_fmin": 0, "mel_fmax": 8000},
        }
    }
    audio_cfg = default_mel_config["preprocessing"]["audio"]
    stft_cfg = default_mel_config["preprocessing"]["stft"]
    mel_cfg = default_mel_config["preprocessing"]["mel"]
    sample_rate = audio_cfg["sampling_rate"]

    fn_STFT = TacotronSTFT(
        stft_cfg["filter_length"],
        stft_cfg["hop_length"],
        stft_cfg["win_length"],
        mel_cfg["n_mel_channels"],
        sample_rate,
        mel_cfg["mel_fmin"],
        mel_cfg["mel_fmax"],
    )

    # Test vocoder abilities. Use the detected device rather than hard-coding
    # 'cuda', so the script also runs on CPU-only machines.
    mono_audio = mono_audio.to(device)
    resamp_16k = torchaudio.functional.resample(mono_audio, sr, sample_rate)

    # One mel frame per hop (16000 / 160 = 100 frames per second). Integer
    # division avoids float rounding in the old int(duration * 100).
    target_length = resamp_16k.shape[0] // stft_cfg["hop_length"]
    mel_1, _, _ = wav_to_fbank(resamp_16k.cpu(), target_length=target_length, fn_STFT=fn_STFT)

    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        # Vocoder only: mel -> waveform. The result is indexed and normalized
        # with NumPy below, so it is assumed to be a (batch, channels, samples)
        # array -- TODO confirm against AudioLDM2's mel_spectrogram_to_waveform.
        wav = model.mel_spectrogram_to_waveform(mel_1.unsqueeze(0).to(device), save=False)
    wav = wav[0, :, :]
    # Mel and back appears to lose some volume, so peak-normalize before saving.
    todo_waveform = peak_normalize(wav, peak=0.8)
    sf.write(os.path.join(out_dir, 'example_reconst.wav'), todo_waveform[0, :], samplerate=sample_rate)
    sf.write(os.path.join(out_dir, 'example_reconst_orig.wav'), resamp_16k.cpu().numpy(), samplerate=sample_rate)

    # Let's test the VAE: mel -> latent (`.mean` of the encoder output) -> mel -> waveform.
    with torch.no_grad():
        mel_in = mel_1.unsqueeze(0).unsqueeze(0).to(device)
        encode = model.first_stage_model.encode(mel_in).mean
        decode = model.first_stage_model.decode(encode)
        reconst_wav = model.mel_spectrogram_to_waveform(decode.squeeze(0), save=False)
    # Normalize the VAE reconstruction itself. Previously the vocoder-only `wav`
    # was re-normalized here, so this file merely duplicated example_reconst.wav.
    reconst_waveform = peak_normalize(reconst_wav[0, :, :], peak=0.8)
    sf.write(os.path.join(out_dir, 'example_reconst_frommel.wav'), reconst_waveform[0, :], samplerate=sample_rate)