|
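"""Dataset definitions for FoleyCrafter training.

AudioSetStrong loads pre-extracted features (mel spectrograms, event labels,
text embeddings) together with the source videos via a Petrel storage client;
VGGSound loads pre-extracted visual embeddings, mel audio, and text from
local .pt files.
"""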
import torch
import torch.distributed as dist
import torchaudio
import torchvision
import torchvision.io
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset

import os, io, csv, math, random
import os.path as osp
from pathlib import Path
import numpy as np
import pandas as pd
from einops import rearrange
import glob

import decord
from decord import VideoReader, AudioReader
from copy import deepcopy
import pickle

from petrel_client.client import Client

import sys
sys.path.append('./')

from foleycrafter.data import video_transforms
from foleycrafter.utils.util import (
    random_audio_video_clip, get_full_indices, video_tensor_to_np, get_video_frames,
)
from foleycrafter.utils.spec_to_mel import wav_tensor_to_fbank, read_wav_file_io, load_audio, normalize_wav, pad_wav
from foleycrafter.utils.converter import get_mel_spectrogram_from_audio, pad_spec, normalize, normalize_spectrogram


def zero_rank_print(s):
    """Print a tagged message once per job (rank 0, or any process if torch.distributed is uninitialized)."""
    if not dist.is_initialized() or dist.get_rank() == 0:
        print("### " + s, flush=True)


@torch.no_grad()
def get_mel(audio_data, audio_cfg):
    # Build the mel-spectrogram transform on the same device as the input audio.
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=audio_cfg["sample_rate"],
        n_fft=audio_cfg["window_size"],
        win_length=audio_cfg["window_size"],
        hop_length=audio_cfg["hop_size"],
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm=None,
        onesided=True,
        n_mels=64,
        f_min=audio_cfg["fmin"],
        f_max=audio_cfg["fmax"],
    ).to(audio_data.device)
    mel = mel(audio_data)
    # Convert the power spectrogram to decibels.
    mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
    return mel
|
|
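# Usage sketch for get_mel, with hypothetical config values (the keys are the
# ones the function reads above):
#   audio_cfg = {"sample_rate": 16000, "window_size": 1024,
#                "hop_size": 160, "fmin": 50, "fmax": 8000}
#   mel = get_mel(waveform, audio_cfg)   # 64-bin mel spectrogram in dB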
|
def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    clip_val: lower bound applied before the log to avoid log(0)
    """
    return normalize_fun(torch.clamp(x, min=clip_val) * C)
|
|
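# With the defaults this is log(clamp(x, min=1e-5)), the usual log-compression
# applied to spectrogram magnitudes before modeling.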
|
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that remaps CUDA tensor storages to CPU while loading."""

    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)
|
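# Usage sketch ('feature.pkl' is a hypothetical path):
#   with open('feature.pkl', 'rb') as f:
#       feature = CPU_Unpickler(f).load()   # tensors land on CPU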
|
|
class AudioSetStrong(Dataset):
    """Pre-extracted AudioSetStrong features (mel, labels, text embeds) plus the source videos."""

    def __init__(self):
        super().__init__()
        # Petrel object-storage client used to list and fetch features and videos.
        self._client = Client()
        self.data_path = 'data/AudioSetStrong/train/feature'
        self.data_list = list(self._client.list(self.data_path))
        self.length = len(self.data_list)

        self.video_path = 'data/AudioSetStrong/train/video'
        vision_transform_list = [
            transforms.Resize((128, 128)),
            transforms.CenterCrop((112, 112)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.video_transform = transforms.Compose(vision_transform_list)

    def get_batch(self, idx):
        # Fetch the pickled feature dict; CPU_Unpickler keeps its tensors on CPU
        # even if they were saved from a CUDA device.
        data_bytes = self._client.Get(osp.join(self.data_path, self.data_list[idx]))
        embeds = CPU_Unpickler(io.BytesIO(data_bytes)).load()

        mel = embeds['mel']
        save_bsz = mel.shape[0]
        audio_info = embeds['audio_info']
        text_embeds = embeds['text_embeds']

        # Join each clip's event labels into a comma-separated prompt.
        audio_info_array = np.array(audio_info['label_list'])
        prompts = []
        for i in range(save_bsz):
            prompts.append(', '.join(audio_info_array[i, :audio_info['event_num'][i]].tolist()))

        # Decode every video in the batch, take 150 frames, and normalize.
        videos = None
        for video_name in audio_info['audio_name']:
            video_bytes = self._client.Get(osp.join(self.video_path, video_name + '.mp4'))
            video_reader = VideoReader(io.BytesIO(video_bytes))
            video = video_reader.get_batch(get_full_indices(video_reader)).asnumpy()
            video = get_video_frames(video, 150)
            video = torch.from_numpy(video).permute(0, 3, 1, 2).contiguous().float()
            video = self.video_transform(video)
            video = video.unsqueeze(0)
            videos = video if videos is None else torch.cat([videos, video], dim=0)

        assert videos is not None, 'no video read'
        return mel, audio_info, text_embeds, prompts, videos

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Retry with a random index until a sample loads successfully.
        while True:
            try:
                mel, audio_info, text_embeds, prompts, videos = self.get_batch(idx)
                break
            except Exception as e:
                zero_rank_print(f' >>> load error: {e} <<<')
                idx = random.randint(0, self.length - 1)
        sample = dict(mel=mel, audio_info=audio_info, text_embeds=text_embeds, prompts=prompts, videos=videos)
        return sample
|
|
|
class VGGSound(Dataset):
    """Pre-extracted VGGSound visual embeddings paired with mel audio and text."""

    def __init__(self):
        super().__init__()
        self.data_path = 'data/VGGSound/train/video'
        self.visual_data_path = 'data/VGGSound/train/feature'
        self.embeds_list = glob.glob(f'{self.data_path}/*.pt')
        self.visual_list = glob.glob(f'{self.visual_data_path}/*.pt')
        self.length = len(self.embeds_list)

    def get_batch(self, idx):
        embeds = torch.load(self.embeds_list[idx], map_location='cpu')
        visual_embeds = torch.load(self.visual_list[idx], map_location='cpu')

        visual_embeds = visual_embeds['visual_embeds']
        video_name = embeds['video_name']
        text = embeds['text']
        mel = embeds['mel']

        audio = mel
        return visual_embeds, audio, text

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Retry with a random index until a sample loads successfully.
        while True:
            try:
                visual_embeds, audio, text = self.get_batch(idx)
                break
            except Exception as e:
                zero_rank_print(f'load error: {e}')
                idx = random.randint(0, self.length - 1)
        sample = dict(visual_embeds=visual_embeds, audio=audio, text=text)
        return sample
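

if __name__ == '__main__':
    # Minimal smoke test (a sketch: it assumes the data/VGGSound directories
    # above are populated; AudioSetStrong additionally needs Petrel storage).
    from torch.utils.data import DataLoader

    dataset = VGGSound()
    loader = DataLoader(dataset, batch_size=2, num_workers=0)
    sample = next(iter(loader))
    zero_rank_print(f'visual_embeds shape: {sample["visual_embeds"].shape}')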