import importlib import numpy as np import io import os import posixpath import random import re import subprocess import time import torch import torchaudio import webdataset as wds import pandas as pd from aeiou.core import is_silence from os import path from pathlib import Path from pedalboard.io import AudioFile from torchaudio import transforms as T from typing import Optional, Callable, List import bisect from .utils import FOA, Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, PadCrop_Video_Normalized_T, PadCrop_Video_Hiera_Normalized_T, PadCrop_Video_Image_Normalized_T, PadCrop_DualVideo_Normalized_T AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus") # fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py def fast_scandir( dir:str, # top-level directory at which to begin scanning ext:list, # list of allowed file extensions, #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB ): "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243" subfolders, files = [], [] ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed try: # hope to avoid 'permission denied' by this try for f in os.scandir(dir): try: # 'hope to avoid too many levels of symbolic links' error if f.is_dir(): subfolders.append(f.path) elif f.is_file(): file_ext = os.path.splitext(f.name)[1].lower() is_hidden = os.path.basename(f.path).startswith(".") if file_ext in ext and not is_hidden: files.append(f.path) except: pass except: pass for dir in list(subfolders): sf, f = fast_scandir(dir, ext) subfolders.extend(sf) files.extend(f) return subfolders, files def keyword_scandir( dir: str, # top-level directory at which to begin scanning ext: list, # list of allowed file extensions keywords: list, # list of keywords to search for in the file name ): "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243" subfolders, files = [], [] # make keywords case insensitive keywords = [keyword.lower() for keyword in keywords] # add starting period to extensions if needed ext = ['.'+x if x[0] != '.' else x for x in ext] banned_words = ["paxheader", "__macosx"] try: # hope to avoid 'permission denied' by this try for f in os.scandir(dir): try: # 'hope to avoid too many levels of symbolic links' error if f.is_dir(): subfolders.append(f.path) elif f.is_file(): is_hidden = f.name.split("/")[-1][0] == '.' has_ext = os.path.splitext(f.name)[1].lower() in ext name_lower = f.name.lower() has_keyword = any( [keyword in name_lower for keyword in keywords]) has_banned = any( [banned_word in name_lower for banned_word in banned_words]) if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"): files.append(f.path) except: pass except: pass for dir in list(subfolders): sf, f = keyword_scandir(dir, ext, keywords) subfolders.extend(sf) files.extend(f) return subfolders, files def get_audio_filenames( paths: list, # directories in which to search keywords=None, exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus'] ): "recursively get a list of audio filenames" filenames = [] if type(paths) is str: paths = [paths] for path in paths: # get a list of relevant filenames if keywords is not None: subfolders, files = keyword_scandir(path, exts, keywords) else: subfolders, files = fast_scandir(path, exts) filenames.extend(files) return filenames class LocalDatasetConfig: def __init__( self, id: str, path: str, split_path: str, audio_dir: str = None, extra_cot: str = None, custom_metadata_fn: Optional[Callable[[str], str]] = None ): self.id = id self.path = path self.split_path = split_path self.audio_dir = audio_dir self.custom_metadata_fn = custom_metadata_fn self.extra_cot = extra_cot class SampleDataset(torch.utils.data.Dataset): def __init__( self, configs, sample_size=65536, sample_rate=48000, keywords=None, random_crop=True, input_type="prompt", fps=4, force_channels="stereo" ): super().__init__() self.filenames = [] self.augs = torch.nn.Sequential( PhaseFlipper(), ) self.root_paths = [] if input_type == 'video': self.pad_crop = PadCrop_Video_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) elif input_type == 'video_hiera': self.pad_crop = PadCrop_Video_Hiera_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) elif input_type == 'video_image': self.pad_crop = PadCrop_Video_Image_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) elif input_type == 'dual_video': self.pad_crop = PadCrop_DualVideo_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) else: self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop) self.force_channels = force_channels print('######################') print(f'input channels is: {force_channels}') print('######################') self.encoding = torch.nn.Sequential( FOA() if self.force_channels == "foa" else torch.nn.Identity(), Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), Mono() if self.force_channels == "mono" else torch.nn.Identity(), ) self.input_type = input_type self.sr = sample_rate self.custom_metadata_fns = {} for config in configs: self.root_paths.append(config.path) def add_prefix(s): return str(os.path.join(config.path,f'{s.strip()}')) with open(config.split_path,'r') as f: item_names = f.readlines() filenames = list(map(add_prefix, item_names)) self.filenames.extend(filenames) # self.filenames.extend(get_audio_filenames(config.path, keywords)) if config.custom_metadata_fn is not None: self.custom_metadata_fns[config.path] = config.custom_metadata_fn print(f'Found {len(self.filenames)} files') def load_file(self, filename): ext = filename.split(".")[-1] if ext == "mp3": with AudioFile(filename) as f: audio = f.read(f.frames) audio = torch.from_numpy(audio) in_sr = f.samplerate else: audio, in_sr = torchaudio.load(filename, format=ext) if in_sr != self.sr: try: resample_tf = T.Resample(in_sr, self.sr) audio = resample_tf(audio) except: print(f'{filename} resample errors') assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' return audio def __len__(self): return len(self.filenames) def __getitem__(self, idx): audio_filename = self.filenames[idx] assert os.path.exists(audio_filename), f'{audio_filename}: file not exists' try: start_time = time.time() audio = self.load_file(audio_filename) info = {} info["path"] = audio_filename for root_path in self.root_paths: if root_path in audio_filename: info["relpath"] = path.relpath(audio_filename, root_path) for custom_md_path in self.custom_metadata_fns.keys(): if custom_md_path in audio_filename: custom_metadata_fn = self.custom_metadata_fns[custom_md_path] custom_metadata = custom_metadata_fn(info, audio) info.update(custom_metadata) if "__reject__" in info and info["__reject__"]: return self[random.randrange(len(self))] if self.input_type == 'video': audio, video, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['video']) info['video'] = video elif self.input_type == 'dual_video': audio, video_360, video_fov, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['video'], info['video_fov']) info['video_360'] = video_360 info['video_fov'] = video_fov else: audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio) assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' # Run augmentations on this sample (including random crop) if self.augs is not None: audio = self.augs(audio) audio = audio.clamp(-1, 1) # Encode the file to assist in prediction if self.encoding is not None: audio = self.encoding(audio) info["timestamps"] = (t_start, t_end) info["seconds_start"] = seconds_start info["seconds_total"] = seconds_total info["padding_mask"] = padding_mask end_time = time.time() info["load_time"] = end_time - start_time return (audio, info) except Exception as e: print(f'Couldn\'t load file {audio_filename}: {e}') return self[random.randrange(len(self))] class LatentDataset(torch.utils.data.Dataset): def __init__( self, configs, sample_size=65536, sample_rate=48000, keywords=None, random_crop=True, input_type="prompt", fps=4, force_channels="stereo" ): super().__init__() self.filenames = [] self.augs = torch.nn.Sequential( PhaseFlipper(), ) self.root_paths = [] self.force_channels = force_channels print('######################') print(f'input channels is: {force_channels}') print('######################') self.encoding = torch.nn.Sequential( FOA() if self.force_channels == "foa" else torch.nn.Identity(), Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), Mono() if self.force_channels == "mono" else torch.nn.Identity(), ) self.input_type = input_type self.sr = sample_rate for config in configs: self.root_paths.append(config.path) def add_prefix(s): return str(os.path.join(config.path,f'{s.strip()}')) with open(config.split_path,'r') as f: item_names = f.readlines() filenames = list(map(add_prefix, item_names)) self.filenames.extend(filenames) # self.filenames.extend(get_audio_filenames(config.path, keywords)) print(f'Found {len(self.filenames)} files') def load_file(self, filename, info): # try: npz_file = filename.replace('.pth','.npz') if os.path.exists(filename) and '.npz' not in filename: data = torch.load(filename, weights_only=False) elif os.path.exists(npz_file): # print(filename) npz_data = np.load(npz_file,allow_pickle=True) data = {key: npz_data[key] for key in npz_data.files} # print("data.keys()",data.keys()) for key in data.keys(): if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number): data[key] = torch.from_numpy(data[key]) else: raise ValueError(f'error load file: {filename}') info.update(data) audio = data['latent'] # except: # print(f'error load file: {filename}') return audio, info['metaclip_features'] def __len__(self): return len(self.filenames) def __getitem__(self, idx): audio_filename = self.filenames[idx] assert os.path.exists(audio_filename) or audio_filename.replace('.pth','.npz'), f'{audio_filename}: file not exists' # try: start_time = time.time() info = {} audio, video = self.load_file(audio_filename, info) info["path"] = audio_filename assert audio.shape == (64,194), f'{audio.shape} input error, id: {id}' assert video.shape == (72,1024), f'{video.shape} input error, id: {id}' info['id'] = Path(audio_filename).stem for root_path in self.root_paths: if root_path in audio_filename: info["relpath"] = path.relpath(audio_filename, root_path) return (audio, info) class AudioDataset(torch.utils.data.Dataset): def __init__( self, configs, sample_size=65536, sample_rate=48000, keywords=None, random_crop=True, input_type="prompt", fps=4, force_channels="stereo" ): super().__init__() self.filenames = [] self.augs = torch.nn.Sequential( PhaseFlipper(), ) self.root_paths = [] self.force_channels = force_channels print('######################') print(f'input channels is: {force_channels}') print('######################') self.encoding = torch.nn.Sequential( FOA() if self.force_channels == "foa" else torch.nn.Identity(), Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), Mono() if self.force_channels == "mono" else torch.nn.Identity(), ) self.fake_clip_features = torch.zeros(72, 1024) self.fake_sync_features = torch.zeros(216, 768) self.video_exist = torch.tensor(0, dtype=torch.bool) self.input_type = input_type self.sr = sample_rate for config in configs: self.root_paths.append(config.path) def add_prefix(s): return str(os.path.join(config.path,f'{s.strip()}')) with open(config.split_path,'r') as f: item_names = f.readlines() filenames = list(map(add_prefix, item_names)) self.filenames.extend(filenames) # self.filenames.extend(get_audio_filenames(config.path, keywords)) print(f'Found {len(self.filenames)} files') def load_file(self, filename, info): # try: npz_file = filename.replace('.pth','.npz') if os.path.exists(filename) and '.npz' not in filename: data = torch.load(filename, weights_only=False) elif os.path.exists(npz_file): # print(filename) npz_data = np.load(npz_file,allow_pickle=True) data = {key: npz_data[key] for key in npz_data.files} # print("data.keys()",data.keys()) for key in data.keys(): if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number): data[key] = torch.from_numpy(data[key]) else: raise ValueError(f'error load file: {filename}') info.update(data) audio = data['latent'] info['metaclip_features'] = self.fake_clip_features info['sync_features'] = self.fake_sync_features info['video_exist'] = self.video_exist # except: # print(f'error load file: {filename}') return audio, info['metaclip_features'] def __len__(self): return len(self.filenames) def __getitem__(self, idx): audio_filename = self.filenames[idx] assert os.path.exists(audio_filename) or audio_filename.replace('.pth','.npz'), f'{audio_filename}: file not exists' # try: start_time = time.time() info = {} audio, video = self.load_file(audio_filename, info) info["path"] = audio_filename assert audio.shape == (64,194), f'{audio.shape} input error, id: {id}' assert video.shape == (72,1024), f'{video.shape} input error, id: {id}' info['id'] = Path(audio_filename).stem for root_path in self.root_paths: if root_path in audio_filename: info["relpath"] = path.relpath(audio_filename, root_path) return (audio, info) class VideoDataset(torch.utils.data.Dataset): def __init__( self, configs, sample_size=65536, sample_rate=48000, keywords=None, random_crop=True, input_type="prompt", fps=4, force_channels="stereo", ): super().__init__() self.filenames = [] print(f'configs: {configs[0]}') if configs[0].extra_cot is not None: self.extra_cot = configs[0].extra_cot print(f'load extra cot from {self.extra_cot}') else: self.extra_cot = None self.augs = torch.nn.Sequential( PhaseFlipper(), ) self.root_paths = [] self.force_channels = force_channels print('######################') print(f'input channels is: {force_channels}') print('######################') self.encoding = torch.nn.Sequential( FOA() if self.force_channels == "foa" else torch.nn.Identity(), Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), Mono() if self.force_channels == "mono" else torch.nn.Identity(), ) self.input_type = input_type self.sr = sample_rate self.video_exist = torch.tensor(1, dtype=torch.bool) for config in configs: self.root_paths.append(config.path) def add_prefix(s): return str(os.path.join(config.path,f'{s.strip()}')) with open(config.split_path,'r') as f: item_names = f.readlines() filenames = list(map(add_prefix, item_names)) self.filenames.extend(filenames) # self.filenames.extend(get_audio_filenames(config.path, keywords)) print(f'Found {len(self.filenames)} files') def load_file(self, filename, info): # try: npz_file = filename.replace('.pth','.npz') if os.path.exists(filename) and '.npz' not in filename: data = torch.load(filename, weights_only=False) elif os.path.exists(npz_file): # print(filename) npz_data = np.load(npz_file,allow_pickle=True) data = {key: npz_data[key] for key in npz_data.files} # print("data.keys()",data.keys()) for key in data.keys(): if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number): data[key] = torch.from_numpy(data[key]) if self.extra_cot is not None: extra_pth = filename.replace('.npz','.pth') extra_pth = os.path.join(self.extra_cot, os.path.basename(extra_pth)) if os.path.exists(extra_pth): extra_data = torch.load(extra_pth, weights_only=False) for key in extra_data.keys(): if isinstance(extra_data[key], torch.Tensor): # print(f'load extra cot {key}') data[key] = extra_data[key] else: raise ValueError(f'error load file: {filename}') info.update(data) if 'latent' in data.keys(): audio = data['latent'] else: audio = torch.zeros(64,194) info['video_exist'] = self.video_exist # except: # print(f'error load file: {filename}') return audio, info['metaclip_features'] def __len__(self): return len(self.filenames) def __getitem__(self, idx): audio_filename = self.filenames[idx] assert os.path.exists(audio_filename) or audio_filename.replace('.pth','.npz'), f'{audio_filename}: file not exists' # try: start_time = time.time() info = {} audio, video = self.load_file(audio_filename, info) info["path"] = audio_filename assert audio is None or audio.shape == (64,194), f'{audio.shape} input error, id: {id}' assert video.shape == (72,1024), f'{video.shape} input error, id: {id}' info['id'] = Path(audio_filename).stem for root_path in self.root_paths: if root_path in audio_filename: info["relpath"] = path.relpath(audio_filename, root_path) return (audio, info) # modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset class MultiModalDataset(torch.utils.data.Dataset): datasets: list[torch.utils.data.Dataset] cumulative_sizes: list[int] @staticmethod def cumsum(sequence): r, s = [], 0 for e in sequence: l = len(e) r.append(l + s) s += l return r def __init__(self, video_datasets: list[torch.utils.data.Dataset], audio_datasets: list[torch.utils.data.Dataset]): super().__init__() self.video_datasets = list(video_datasets) self.audio_datasets = list(audio_datasets) self.datasets = self.video_datasets + self.audio_datasets self.cumulative_sizes = self.cumsum(self.datasets) print(f'Found {self.cumulative_sizes[-1]} files') def __len__(self): return self.cumulative_sizes[-1] def __getitem__(self, idx): if idx < 0: if -idx > len(self): raise ValueError("absolute value of index should not exceed dataset length") idx = len(self) + idx dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) if dataset_idx == 0: sample_idx = idx else: sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] return self.datasets[dataset_idx][sample_idx] def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]: return self.video_datasets[0].compute_latent_stats() # class MultiModalDataset(torch.utils.data.Dataset): # def __init__( # self, # configs, # sample_size=65536, # sample_rate=48000, # keywords=None, # random_crop=True, # input_type="prompt", # fps=4, # force_channels="stereo" # ): # super().__init__() # self.filenames = [] # self.captions = [] # self.caption_t5s = [] # self.ids = [] # self.augs = torch.nn.Sequential( # PhaseFlipper(), # ) # self.root_paths = [] # if input_type == 'video': # self.pad_crop = PadCrop_Video_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) # elif input_type == 'video_hiera': # self.pad_crop = PadCrop_Video_Hiera_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) # elif input_type == 'video_image': # self.pad_crop = PadCrop_Video_Image_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) # elif input_type == 'dual_video': # self.pad_crop = PadCrop_DualVideo_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) # else: # self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop) # self.force_channels = force_channels # print('######################') # print(f'input channels is: {force_channels}') # print('######################') # self.encoding = torch.nn.Sequential( # FOA() if self.force_channels == "foa" else torch.nn.Identity(), # Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), # Mono() if self.force_channels == "mono" else torch.nn.Identity(), # ) # self.input_type = input_type # self.sr = sample_rate # self.custom_metadata_fns = {} # for config in configs: # print(config.split_path) # self.root_paths.append(config.path) # def add_prefix(s): # return str(os.path.join(config.path,f'{s.strip()}')) # with open(config.split_path,'r') as f: # item_names = f.readlines() # csv_path = config.split_path.replace('.txt','.csv') # df = pd.read_csv(csv_path) # # 检查是否存在 'caption_t5' 列,如果不存在则创建并复制 'caption' 的值 # if 'caption_t5' not in df.columns: # df['caption_t5'] = df['caption'] # captions = df['caption'].tolist() # caption_t5s = df['caption_t5'].tolist() # filenames = list(map(add_prefix, item_names)) # assert len(captions) == len(caption_t5s) and len(captions) == len(filenames), f'{config.path} has wrong filename and caption' # if config.id == 'vggsound': # self.filenames.extend(filenames*5) # self.captions.extend(captions*5) # self.caption_t5s.extend(caption_t5s*5) # self.ids.extend(df['id'].tolist()*5) # else: # self.filenames.extend(filenames) # self.captions.extend(captions) # self.caption_t5s.extend(caption_t5s) # self.ids.extend(df['id'].tolist()) # # self.filenames.extend(get_audio_filenames(config.path, keywords)) # if config.custom_metadata_fn is not None: # self.custom_metadata_fns[config.path] = config.custom_metadata_fn # assert len(self.ids) == len(self.captions) and len(self.caption_t5s) == len(self.filenames), 'length need to be same' # print(f'Found {len(self.filenames)} files') # def load_file(self, filename): # ext = filename.split(".")[-1] # if ext == "mp3": # with AudioFile(filename) as f: # audio = f.read(f.frames) # audio = torch.from_numpy(audio) # in_sr = f.samplerate # else: # audio, in_sr = torchaudio.load(filename, format=ext) # if in_sr != self.sr: # try: # resample_tf = T.Resample(in_sr, self.sr) # audio = resample_tf(audio) # except: # print(f'{filename} resample errors') # assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' # return audio # def __len__(self): # return len(self.filenames) # def __getitem__(self, idx): # audio_filename = self.filenames[idx] # id = self.ids[idx] # assert str(id) == str(Path(audio_filename).stem), f'audio_file: {audio_filename} needs to be same as {id} ' # assert os.path.exists(audio_filename), f'{audio_filename}: file not exists' # try: # start_time = time.time() # audio = self.load_file(audio_filename) # caption = self.captions[idx] # caption_t5 = self.caption_t5s[idx] # if pd.isna(caption_t5) or caption_t5 == '': # caption_t5 = caption # info = {} # info["path"] = audio_filename # info['caption'] = caption # info['caption_t5'] = caption_t5 # for root_path in self.root_paths: # if root_path in audio_filename: # info["relpath"] = path.relpath(audio_filename, root_path) # for custom_md_path in self.custom_metadata_fns.keys(): # if custom_md_path in audio_filename: # custom_metadata_fn = self.custom_metadata_fns[custom_md_path] # custom_metadata = custom_metadata_fn(info, audio) # info.update(custom_metadata) # if "__reject__" in info and info["__reject__"]: # return self[random.randrange(len(self))] # # if self.input_type == 'video': # # audio, video, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['clip_features']) # # info['clip_features'] = video # # else: # if info['flag']: # audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio,randomize=False) # else: # audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio,randomize=True) # assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' # # Run augmentations on this sample (including random crop) # if self.augs is not None: # audio = self.augs(audio) # audio = audio.clamp(-1, 1) # # Encode the file to assist in prediction # if self.encoding is not None: # audio = self.encoding(audio) # info["timestamps"] = (t_start, t_end) # info["seconds_start"] = seconds_start # info["seconds_total"] = seconds_total # info["padding_mask"] = padding_mask # end_time = time.time() # info["load_time"] = end_time - start_time # return (audio, info) # except Exception as e: # print(f'Couldn\'t load file {audio_filename}: {e}') # return self[random.randrange(len(self))] def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None): """Return function over iterator that groups key, value pairs into samples. :param keys: function that splits the key into key and extension (base_plus_ext) :param lcase: convert suffixes to lower case (Default value = True) """ current_sample = None for filesample in data: assert isinstance(filesample, dict) fname, value = filesample["fname"], filesample["data"] prefix, suffix = keys(fname) if wds.tariterators.trace: print( prefix, suffix, current_sample.keys() if isinstance(current_sample, dict) else None, ) if prefix is None: continue if lcase: suffix = suffix.lower() if current_sample is None or prefix != current_sample["__key__"]: if wds.tariterators.valid_sample(current_sample): yield current_sample current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) if suffix in current_sample: print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}") if suffixes is None or suffix in suffixes: current_sample[suffix] = value if wds.tariterators.valid_sample(current_sample): yield current_sample wds.tariterators.group_by_keys = group_by_keys # S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None): """ Returns a list of full S3 paths to files in a given S3 bucket and directory path. """ # Ensure dataset_path ends with a trailing slash if dataset_path != '' and not dataset_path.endswith('/'): dataset_path += '/' # Use posixpath to construct the S3 URL path bucket_path = posixpath.join(s3_url_prefix or '', dataset_path) # Construct the `aws s3 ls` command cmd = ['aws', 's3', 'ls', bucket_path] if profile is not None: cmd.extend(['--profile', profile]) if recursive: # Add the --recursive flag if requested cmd.append('--recursive') # Run the `aws s3 ls` command and capture the output run_ls = subprocess.run(cmd, capture_output=True, check=True) # Split the output into lines and strip whitespace from each line contents = run_ls.stdout.decode('utf-8').split('\n') contents = [x.strip() for x in contents if x] # Remove the timestamp from lines that begin with a timestamp contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x) if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents] # Construct a full S3 path for each file in the contents list contents = [posixpath.join(s3_url_prefix or '', x) for x in contents if not x.endswith('/')] # Apply the filter, if specified if filter: contents = [x for x in contents if filter in x] # Remove redundant directory names in the S3 URL if recursive: # Get the main directory name from the S3 URL main_dir = "/".join(bucket_path.split('/')[3:]) # Remove the redundant directory names from each file path contents = [x.replace(f'{main_dir}', '').replace( '//', '/') for x in contents] # Print debugging information, if requested if debug: print("contents = \n", contents) # Return the list of S3 paths to files return contents def get_all_s3_urls( names=[], # list of all valid [LAION AudioDataset] dataset names # list of subsets you want from those datasets, e.g. ['train','valid'] subsets=[''], s3_url_prefix=None, # prefix for those dataset names recursive=True, # recursively list all tar files in all subdirs filter_str='tar', # only grab files with this substring # print debugging info -- note: info displayed likely to change at dev's whims debug=False, profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'} ): "get urls of shards (tar files) for multiple datasets in one s3 bucket" urls = [] for name in names: # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list if s3_url_prefix is None: contents_str = name else: # Construct the S3 path using the s3_url_prefix and the current name value contents_str = posixpath.join(s3_url_prefix, name) if debug: print(f"get_all_s3_urls: {contents_str}:") for subset in subsets: subset_str = posixpath.join(contents_str, subset) if debug: print(f"subset_str = {subset_str}") # Get the list of tar files in the current subset directory profile = profiles.get(name, None) tar_list = get_s3_contents( subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile) for tar in tar_list: # Escape spaces and parentheses in the tar filename for use in the shell command tar = tar.replace(" ", "\ ").replace( "(", "\(").replace(")", "\)") # Construct the S3 path to the current tar file s3_path = posixpath.join(name, subset, tar) + " -" # Construct the AWS CLI command to download the current tar file if s3_url_prefix is None: request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}" else: request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}" if profiles.get(name): request_str += f" --profile {profiles.get(name)}" if debug: print("request_str = ", request_str) # Add the constructed URL to the list of URLs urls.append(request_str) return urls def log_and_continue(exn): """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" print(f"Handling webdataset error ({repr(exn)}). Ignoring.") return True def is_valid_sample(sample): has_json = "json" in sample has_audio = "audio" in sample is_silent = is_silence(sample["audio"]) is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"] return has_json and has_audio and not is_silent and not is_rejected class S3DatasetConfig: def __init__( self, id: str, s3_path: str, custom_metadata_fn: Optional[Callable[[str], str]] = None, profile: Optional[str] = None, ): self.id = id self.path = s3_path self.custom_metadata_fn = custom_metadata_fn self.profile = profile self.urls = [] def load_data_urls(self): self.urls = get_all_s3_urls( names=[self.path], s3_url_prefix=None, recursive=True, profiles={self.path: self.profile} if self.profile else {}, ) return self.urls class LocalWebDatasetConfig: def __init__( self, id: str, path: str, custom_metadata_fn: Optional[Callable[[str], str]] = None, profile: Optional[str] = None, ): self.id = id self.path = path self.custom_metadata_fn = custom_metadata_fn self.urls = [] def load_data_urls(self): self.urls = fast_scandir(self.path, ["tar"])[1] return self.urls def audio_decoder(key, value): # Get file extension from key ext = key.split(".")[-1] if ext in AUDIO_KEYS: return torchaudio.load(io.BytesIO(value)) else: return None def collation_fn(samples): batched = list(zip(*samples)) result = [] for b in batched: if isinstance(b[0], (int, float)): b = np.array(b) elif isinstance(b[0], torch.Tensor): b = torch.stack(b) elif isinstance(b[0], np.ndarray): b = np.array(b) else: b = b result.append(b) return result class WebDatasetDataLoader(): def __init__( self, datasets: List[S3DatasetConfig], batch_size, sample_size, sample_rate=48000, num_workers=8, epoch_steps=1000, random_crop=True, force_channels="stereo", augment_phase=True, **data_loader_kwargs ): self.datasets = datasets self.sample_size = sample_size self.sample_rate = sample_rate self.random_crop = random_crop self.force_channels = force_channels self.augment_phase = augment_phase urls = [dataset.load_data_urls() for dataset in datasets] # Flatten the list of lists of URLs urls = [url for dataset_urls in urls for url in dataset_urls] # Shuffle the urls random.shuffle(urls) self.dataset = wds.DataPipeline( wds.ResampledShards(urls), wds.tarfile_to_samples(handler=log_and_continue), wds.decode(audio_decoder, handler=log_and_continue), wds.map(self.wds_preprocess, handler=log_and_continue), wds.select(is_valid_sample), wds.to_tuple("audio", "json", handler=log_and_continue), #wds.shuffle(bufsize=1000, initial=5000), wds.batched(batch_size, partial=False, collation_fn=collation_fn), ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps) self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs) def wds_preprocess(self, sample): found_key, rewrite_key = '', '' for k, v in sample.items(): # print the all entries in dict for akey in AUDIO_KEYS: if k.endswith(akey): # to rename long/weird key with its simpler counterpart found_key, rewrite_key = k, akey break if '' != found_key: break if '' == found_key: # got no audio! return None # try returning None to tell WebDataset to skip this one audio, in_sr = sample[found_key] if in_sr != self.sample_rate: resample_tf = T.Resample(in_sr, self.sample_rate) audio = resample_tf(audio) if self.sample_size is not None: # Pad/crop and get the relative timestamp pad_crop = PadCrop_Normalized_T( self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate) audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop( audio) sample["json"]["seconds_start"] = seconds_start sample["json"]["seconds_total"] = seconds_total sample["json"]["padding_mask"] = padding_mask else: t_start, t_end = 0, 1 # Check if audio is length zero, initialize to a single zero if so if audio.shape[-1] == 0: audio = torch.zeros(1, 1) # Make the audio stereo and augment by randomly inverting phase augs = torch.nn.Sequential( Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), Mono() if self.force_channels == "mono" else torch.nn.Identity(), PhaseFlipper() if self.augment_phase else torch.nn.Identity() ) audio = augs(audio) sample["json"]["timestamps"] = (t_start, t_end) if "text" in sample["json"]: sample["json"]["prompt"] = sample["json"]["text"] # Check for custom metadata functions for dataset in self.datasets: if dataset.custom_metadata_fn is None: continue if dataset.path in sample["__url__"]: custom_metadata = dataset.custom_metadata_fn(sample["json"], audio) sample["json"].update(custom_metadata) if found_key != rewrite_key: # rename long/weird key with its simpler counterpart del sample[found_key] sample["audio"] = audio # Add audio to the metadata as well for conditioning sample["json"]["audio"] = audio return sample def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4): dataset_type = dataset_config.get("dataset_type", None) assert dataset_type is not None, "Dataset type must be specified in dataset config" if audio_channels == 1: force_channels = "mono" elif audio_channels == 2: force_channels = "stereo" else: force_channels = "foa" if dataset_type == "audio_dir": audio_dir_configs = dataset_config.get("datasets", None) assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" configs = [] for audio_dir_config in audio_dir_configs: audio_dir_path = audio_dir_config.get("path", None) split_path = audio_dir_config.get("split_path", None) assert audio_dir_path is not None, "Path must be set for local audio directory configuration" custom_metadata_fn = None custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None) if custom_metadata_module_path is not None: spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) metadata_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(metadata_module) custom_metadata_fn = metadata_module.get_custom_metadata configs.append( LocalDatasetConfig( id=audio_dir_config["id"], path=audio_dir_path, split_path=split_path, custom_metadata_fn=custom_metadata_fn ) ) train_set = SampleDataset( configs, sample_rate=sample_rate, sample_size=sample_size, random_crop=dataset_config.get("random_crop", True), input_type=dataset_config.get("input_type", "video"), fps=dataset_config.get("fps", 4), force_channels=force_channels ) return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn) elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility wds_configs = [] for wds_config in dataset_config["datasets"]: custom_metadata_fn = None custom_metadata_module_path = wds_config.get("custom_metadata_module", None) if custom_metadata_module_path is not None: spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) metadata_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(metadata_module) custom_metadata_fn = metadata_module.get_custom_metadata if "s3_path" in wds_config: wds_configs.append( S3DatasetConfig( id=wds_config["id"], s3_path=wds_config["s3_path"], custom_metadata_fn=custom_metadata_fn, profile=wds_config.get("profile", None), ) ) elif "path" in wds_config: wds_configs.append( LocalWebDatasetConfig( id=wds_config["id"], path=wds_config["path"], custom_metadata_fn=custom_metadata_fn ) ) return WebDatasetDataLoader( wds_configs, sample_rate=sample_rate, sample_size=sample_size, batch_size=batch_size, random_crop=dataset_config.get("random_crop", True), num_workers=num_workers, persistent_workers=True, force_channels=force_channels, epoch_steps=dataset_config.get("epoch_steps", 2000) ).data_loader elif dataset_type == "latent_dir": audio_dir_configs = dataset_config.get("datasets", None) assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" configs = [] for audio_dir_config in audio_dir_configs: audio_dir_path = audio_dir_config.get("path", None) split_path = audio_dir_config.get("split_path", None) assert audio_dir_path is not None, "Path must be set for local audio directory configuration" configs.append( LocalDatasetConfig( id=audio_dir_config["id"], path=audio_dir_path, split_path=split_path, ) ) train_set = LatentDataset( configs, sample_rate=sample_rate, sample_size=sample_size, random_crop=dataset_config.get("random_crop", True), input_type=dataset_config.get("input_type", "video"), fps=dataset_config.get("fps", 4), force_channels=force_channels ) return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn) elif dataset_type == 'multimodal_dir': audio_dir_configs = dataset_config.get("datasets", None) assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" configs = [] for audio_dir_config in audio_dir_configs: audio_dir_path = audio_dir_config.get("path", None) split_path = audio_dir_config.get("split_path", None) assert audio_dir_path is not None, "Path must be set for local audio directory configuration" custom_metadata_fn = None custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None) if custom_metadata_module_path is not None: spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) metadata_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(metadata_module) custom_metadata_fn = metadata_module.get_custom_metadata configs.append( LocalDatasetConfig( id=audio_dir_config["id"], path=audio_dir_path, split_path=split_path, custom_metadata_fn=custom_metadata_fn ) ) train_set = MultiModalDataset( configs, sample_rate=sample_rate, sample_size=sample_size, random_crop=dataset_config.get("random_crop", True), input_type=dataset_config.get("input_type", "video"), fps=dataset_config.get("fps", 4), force_channels=force_channels ) return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)