import importlib
import io
import os
import posixpath
import random
import re
import subprocess
import time
from os import path
from typing import Optional, Callable, List

import numpy as np
import pandas as pd
import torch
import torchaudio
import webdataset as wds
from aeiou.core import is_silence
from pedalboard.io import AudioFile
from torchaudio import transforms as T
from transformers import AutoTokenizer

from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T

# Tar-member extensions treated as audio payloads by the WebDataset pipeline.
AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")


# fast_scandir implementation by Scott Hawley originally in
# https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
def fast_scandir(
    dir: str,   # top-level directory at which to begin scanning
    ext: list,  # list of allowed file extensions
):
    """Very fast `glob` alternative; from https://stackoverflow.com/a/59803793/4259243

    Returns (subfolders, files): all directories under `dir` (recursively) and
    all non-hidden files whose extension is in `ext`.
    """
    subfolders, files = [], []
    # Normalize extensions so each starts with a period.
    ext = ['.' + x if x[0] != '.' else x for x in ext]
    try:  # guard against 'permission denied' on the directory itself
        for entry in os.scandir(dir):
            try:  # guard against 'too many levels of symbolic links'
                if entry.is_dir():
                    subfolders.append(entry.path)
                elif entry.is_file():
                    file_ext = os.path.splitext(entry.name)[1].lower()
                    is_hidden = os.path.basename(entry.path).startswith(".")
                    if file_ext in ext and not is_hidden:
                        files.append(entry.path)
            except OSError:
                pass
    except OSError:
        pass
    # Recurse into each subfolder found at this level.
    for subfolder in list(subfolders):
        sf, f = fast_scandir(subfolder, ext)
        subfolders.extend(sf)
        files.extend(f)
    return subfolders, files


def keyword_scandir(
    dir: str,        # top-level directory at which to begin scanning
    ext: list,       # list of allowed file extensions
    keywords: list,  # list of keywords to search for in the file name
):
    """Very fast `glob` alternative; from https://stackoverflow.com/a/59803793/4259243

    Like `fast_scandir`, but additionally requires at least one of `keywords`
    (case-insensitive) in the file name, and rejects names containing banned
    words (tar/macOS metadata artifacts).
    """
    subfolders, files = [], []
    # Case-insensitive keyword matching.
    keywords = [keyword.lower() for keyword in keywords]
    # Normalize extensions so each starts with a period.
    ext = ['.' + x if x[0] != '.' else x for x in ext]
    banned_words = ["paxheader", "__macosx"]
    try:  # guard against 'permission denied' on the directory itself
        for entry in os.scandir(dir):
            try:  # guard against 'too many levels of symbolic links'
                if entry.is_dir():
                    subfolders.append(entry.path)
                elif entry.is_file():
                    # scandir names never contain a path separator, so this is
                    # equivalent to the old split("/")[-1][0] == '.' check.
                    is_hidden = entry.name.startswith('.')
                    has_ext = os.path.splitext(entry.name)[1].lower() in ext
                    name_lower = entry.name.lower()
                    has_keyword = any(keyword in name_lower for keyword in keywords)
                    has_banned = any(banned_word in name_lower for banned_word in banned_words)
                    if (has_ext and has_keyword and not has_banned and not is_hidden
                            and not os.path.basename(entry.path).startswith("._")):
                        files.append(entry.path)
            except OSError:
                pass
    except OSError:
        pass
    # Recurse into each subfolder found at this level.
    for subfolder in list(subfolders):
        sf, f = keyword_scandir(subfolder, ext, keywords)
        subfolders.extend(sf)
        files.extend(f)
    return subfolders, files


def get_audio_filenames(
    paths: list,    # directories in which to search
    keywords=None,  # optional keywords that must appear in file names
    exts=('.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus'),
):
    """Recursively collect audio filenames under each directory in `paths`."""
    filenames = []
    if isinstance(paths, str):
        paths = [paths]
    for search_path in paths:
        if keywords is not None:
            subfolders, files = keyword_scandir(search_path, list(exts), keywords)
        else:
            subfolders, files = fast_scandir(search_path, list(exts))
        filenames.extend(files)
    return filenames


class LocalDatasetConfig:
    """Configuration for one local (CSV-backed) dataset entry."""

    def __init__(
        self,
        id: str,
        path: str,
        custom_metadata_fn: Optional[Callable[[str], str]] = None
    ):
        self.id = id
        self.path = path
        self.custom_metadata_fn = custom_metadata_fn


class SampleDataset(torch.utils.data.Dataset):
    """Map-style dataset reading (path, caption) rows from a CSV.

    Loads each audio file, resamples to `sample_rate`, pad/crops to
    `sample_size`, applies phase-flip augmentation and channel forcing, and
    returns (audio, info) where info carries timing metadata and the caption.
    """

    def __init__(
        self,
        configs,
        sample_size=65536,
        sample_rate=48000,
        keywords=None,
        random_crop=True,
        force_channels="stereo"
    ):
        super().__init__()
        # Only the first config's CSV is read, capped at 1000 rows.
        # NOTE(review): the [:1000] cap looks like a debugging limit — confirm.
        self.csv_dataset = pd.read_csv(configs[0].path)[:1000]
        self.filenames = list(self.csv_dataset['path'])
        self.captions = list(self.csv_dataset['caption'])

        self.augs = torch.nn.Sequential(
            PhaseFlipper(),
        )

        self.root_paths = []

        self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)

        self.force_channels = force_channels
        self.encoding = torch.nn.Sequential(
            Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
            Mono() if self.force_channels == "mono" else torch.nn.Identity(),
        )

        self.sr = sample_rate

        # --- LLM tokenizer setup -------------------------------------------
        # 0. Initialize the LLM tokenizer.
        self.llm_model_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
        self.num_new_tokens = 64
        # NOTE(review): every added token string below is empty (""). The
        # surrounding comments reference explicit token ids (32000, 32003-32034),
        # so these were almost certainly angle-bracket special tokens (e.g.
        # "<ad_start>", "<ad_0>"...) stripped from the source. Adding "" has no
        # effect on the vocabulary — confirm and restore the original strings.
        # 1. Special token marking "generate audio" in the system message.
        self.llm_model_tokenizer.add_tokens([""], special_tokens=False)
        # 2. Tokens that summarize audio/text information for the conversation system.
        new_token_list = [f"" for i in range(self.num_new_tokens)]
        self.llm_model_tokenizer.add_tokens(new_token_list, special_tokens=False)
        # 3. Count new tokens (the 64 summary tokens plus the start token).
        self.num_new_tokens = self.num_new_tokens + 1
        # NOTE(review): `ad_start_token_id` is a custom attribute stashed on the
        # tokenizer; with an empty token string convert_tokens_to_ids returns
        # the unk id — confirm intended token string.
        self.llm_model_tokenizer.ad_start_token_id = self.llm_model_tokenizer.convert_tokens_to_ids("")

        print(f'Found {len(self.filenames)} files')

    def load_file(self, filename):
        """Load `filename` as a tensor, resampling to `self.sr` if needed."""
        ext = filename.split(".")[-1]

        if ext == "mp3":
            # pedalboard handles mp3 decoding more robustly than torchaudio here.
            with AudioFile(filename) as f:
                audio = f.read(f.frames)
                audio = torch.from_numpy(audio)
                in_sr = f.samplerate
        else:
            audio, in_sr = torchaudio.load(filename, format=ext)

        if in_sr != self.sr:
            resample_tf = T.Resample(in_sr, self.sr)
            audio = resample_tf(audio)

        return audio

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        audio_filename = self.filenames[idx]
        try:
            start_time = time.time()
            audio = self.load_file(audio_filename)

            audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)

            # Run augmentations on this sample (including random phase flip).
            if self.augs is not None:
                audio = self.augs(audio)

            audio = audio.clamp(-1, 1)

            # Force channel count (stereo/mono) as configured.
            if self.encoding is not None:
                audio = self.encoding(audio)

            info = {}
            info["path"] = audio_filename

            # root_paths is always empty as initialized; kept for interface parity.
            for root_path in self.root_paths:
                if root_path in audio_filename:
                    info["relpath"] = path.relpath(audio_filename, root_path)

            info["timestamps"] = (t_start, t_end)
            info["seconds_start"] = seconds_start
            info["seconds_total"] = seconds_total
            info["padding_mask"] = padding_mask

            end_time = time.time()
            info["load_time"] = end_time - start_time
            info['prompt'] = self.captions[idx]

            # NOTE: an experimental commented-out block that built an LLM prompt
            # (system message + caption + generated-token placeholders, tokenized
            # with targets/attention masks) was removed here; it referenced a
            # nonexistent `self.llm_tokenizer`. See VCS history to restore it.

            return (audio, info)
        except Exception as e:
            print(f'Couldn\'t load file {audio_filename}: {e}')
            # Fall back to a random other sample rather than crashing the loader.
            return self[random.randrange(len(self))]


def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if wds.tariterators.trace:
            print(
                prefix,
                suffix,
                current_sample.keys() if isinstance(current_sample, dict) else None,
            )
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        if current_sample is None or prefix != current_sample["__key__"]:
            if wds.tariterators.valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffix in current_sample:
            print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if wds.tariterators.valid_sample(current_sample):
        yield current_sample


# Patch webdataset's grouping so duplicate names warn instead of aborting.
wds.tariterators.group_by_keys = group_by_keys


# S3 code and WDS preprocessing code based on implementation by Scott Hawley
# originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
    """Return a list of full S3 paths to files in a given S3 bucket and directory path."""
    # Ensure dataset_path ends with a trailing slash.
    if dataset_path != '' and not dataset_path.endswith('/'):
        dataset_path += '/'
    # Use posixpath to construct the S3 URL path.
    bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
    # Construct the `aws s3 ls` command (list form, no shell).
    cmd = ['aws', 's3', 'ls', bucket_path]

    if profile is not None:
        cmd.extend(['--profile', profile])

    if recursive:
        cmd.append('--recursive')

    # Run the `aws s3 ls` command and capture the output.
    run_ls = subprocess.run(cmd, capture_output=True, check=True)
    # Split the output into lines and strip whitespace from each line.
    contents = run_ls.stdout.decode('utf-8').split('\n')
    contents = [x.strip() for x in contents if x]
    # Remove the "date time size" prefix from lines that begin with one.
    ts_pattern = re.compile(r'^\S+\s+\S+\s+\d+\s+')
    contents = [ts_pattern.sub('', x) if ts_pattern.match(x) else x for x in contents]
    # Construct a full S3 path for each file in the contents list.
    contents = [posixpath.join(s3_url_prefix or '', x) for x in contents if not x.endswith('/')]
    # Apply the substring filter, if specified.
    if filter:
        contents = [x for x in contents if filter in x]
    # Remove redundant directory names in the S3 URL.
    if recursive:
        # Get the main directory name from the S3 URL.
        main_dir = "/".join(bucket_path.split('/')[3:])
        # Remove the redundant directory names from each file path.
        contents = [x.replace(f'{main_dir}', '').replace('//', '/') for x in contents]
    if debug:
        print("contents = \n", contents)
    return contents


def get_all_s3_urls(
    names=None,          # list of all valid [LAION AudioDataset] dataset names
    subsets=None,        # list of subsets you want from those datasets, e.g. ['train','valid']
    s3_url_prefix=None,  # prefix for those dataset names
    recursive=True,      # recursively list all tar files in all subdirs
    filter_str='tar',    # only grab files with this substring
    debug=False,         # print debugging info
    profiles=None,       # dict of AWS profiles per name, e.g. {'dataset1': 'profile1'}
):
    """Get urls of shards (tar files) for multiple datasets in one s3 bucket."""
    # None-sentinels replace the old mutable default arguments ([], [''], {}).
    names = [] if names is None else names
    subsets = [''] if subsets is None else subsets
    profiles = {} if profiles is None else profiles
    urls = []
    for name in names:
        # If s3_url_prefix is not specified, assume the full S3 path is
        # included in each element of the names list.
        if s3_url_prefix is None:
            contents_str = name
        else:
            contents_str = posixpath.join(s3_url_prefix, name)
        if debug:
            print(f"get_all_s3_urls: {contents_str}:")
        for subset in subsets:
            subset_str = posixpath.join(contents_str, subset)
            if debug:
                print(f"subset_str = {subset_str}")
            # List the tar files in the current subset directory.
            profile = profiles.get(name, None)
            tar_list = get_s3_contents(
                subset_str, s3_url_prefix=None, recursive=recursive,
                filter=filter_str, debug=debug, profile=profile)
            for tar in tar_list:
                # Escape spaces and parentheses in the tar filename for use in
                # the shell command below ("\\ " is a literal backslash-space).
                tar = tar.replace(" ", "\\ ").replace("(", "\\(").replace(")", "\\)")
                # " -" makes `aws s3 cp` stream the object to stdout for the pipe.
                s3_path = posixpath.join(name, subset, tar) + " -"
                if s3_url_prefix is None:
                    request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
                else:
                    request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
                if profiles.get(name):
                    request_str += f" --profile {profiles.get(name)}"
                if debug:
                    print("request_str = ", request_str)
                urls.append(request_str)
    return urls


def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    print(f"Handling webdataset error ({repr(exn)}). Ignoring.")
    return True


def is_valid_sample(sample):
    """Keep a WDS sample only if it has json+audio, is not silent, and is not rejected."""
    has_json = "json" in sample
    has_audio = "audio" in sample
    is_silent = is_silence(sample["audio"])
    is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"]

    return has_json and has_audio and not is_silent and not is_rejected


class S3DatasetConfig:
    """Configuration for one S3-hosted WebDataset shard collection."""

    def __init__(
        self,
        id: str,
        s3_path: str,
        custom_metadata_fn: Optional[Callable[[str], str]] = None,
        profile: Optional[str] = None,
    ):
        self.id = id
        self.path = s3_path
        self.custom_metadata_fn = custom_metadata_fn
        self.profile = profile
        self.urls = []

    def load_data_urls(self):
        """Resolve and cache the pipe: shard URLs for this S3 path."""
        self.urls = get_all_s3_urls(
            names=[self.path],
            s3_url_prefix=None,
            recursive=True,
            profiles={self.path: self.profile} if self.profile else {},
        )

        return self.urls


class LocalWebDatasetConfig:
    """Configuration for one local directory of WebDataset tar shards."""

    def __init__(
        self,
        id: str,
        path: str,
        custom_metadata_fn: Optional[Callable[[str], str]] = None,
        profile: Optional[str] = None,
    ):
        self.id = id
        self.path = path
        self.custom_metadata_fn = custom_metadata_fn
        self.urls = []

    def load_data_urls(self):
        """Scan the local path for .tar shard files and cache their paths."""
        self.urls = fast_scandir(self.path, ["tar"])[1]

        return self.urls


def audio_decoder(key, value):
    """WDS decoder: load tar members with an audio extension via torchaudio."""
    ext = key.split(".")[-1]

    if ext in AUDIO_KEYS:
        return torchaudio.load(io.BytesIO(value))
    else:
        return None


def collation_fn(samples):
    """Collate a list of (field, ...) tuples into batched arrays/tensors per field."""
    batched = list(zip(*samples))
    result = []
    for b in batched:
        if isinstance(b[0], (int, float)):
            b = np.array(b)
        elif isinstance(b[0], torch.Tensor):
            b = torch.stack(b)
        elif isinstance(b[0], np.ndarray):
            b = np.array(b)
        # Anything else (dicts, strings, ...) is left as a tuple.
        result.append(b)
    return result


class WebDatasetDataLoader():
    """Builds a webdataset pipeline + WebLoader over one or more shard configs."""

    def __init__(
        self,
        datasets: List[S3DatasetConfig],
        batch_size,
        sample_size,
        sample_rate=48000,
        num_workers=8,
        epoch_steps=1000,
        random_crop=True,
        force_channels="stereo",
        augment_phase=True,
        **data_loader_kwargs
    ):

        self.datasets = datasets

        self.sample_size = sample_size
        self.sample_rate = sample_rate
        self.random_crop = random_crop
        self.force_channels = force_channels
        self.augment_phase = augment_phase

        urls = [dataset.load_data_urls() for dataset in datasets]

        # Flatten the list of lists of URLs and shuffle them.
        urls = [url for dataset_urls in urls for url in dataset_urls]
        random.shuffle(urls)

        self.dataset = wds.DataPipeline(
            wds.ResampledShards(urls),
            wds.tarfile_to_samples(handler=log_and_continue),
            wds.decode(audio_decoder, handler=log_and_continue),
            wds.map(self.wds_preprocess, handler=log_and_continue),
            wds.select(is_valid_sample),
            wds.to_tuple("audio", "json", handler=log_and_continue),
            #wds.shuffle(bufsize=1000, initial=5000),
            wds.batched(batch_size, partial=False, collation_fn=collation_fn),
        ).with_epoch(epoch_steps // num_workers if num_workers > 0 else epoch_steps)

        self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs)

    def wds_preprocess(self, sample):
        """Decode/resample/pad-crop/augment one WDS sample in place; None skips it."""
        found_key, rewrite_key = '', ''
        for k, v in sample.items():  # find the (possibly long/weird) audio key
            for akey in AUDIO_KEYS:
                if k.endswith(akey):
                    found_key, rewrite_key = k, akey
                    break
            if '' != found_key:
                break
        if '' == found_key:  # got no audio!
            return None  # try returning None to tell WebDataset to skip this one

        audio, in_sr = sample[found_key]
        if in_sr != self.sample_rate:
            resample_tf = T.Resample(in_sr, self.sample_rate)
            audio = resample_tf(audio)

        if self.sample_size is not None:
            # Pad/crop and record the relative timestamp of the crop.
            pad_crop = PadCrop_Normalized_T(
                self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate)
            audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(audio)
            sample["json"]["seconds_start"] = seconds_start
            sample["json"]["seconds_total"] = seconds_total
            sample["json"]["padding_mask"] = padding_mask
        else:
            t_start, t_end = 0, 1

        # Check if audio is length zero, initialize to a single zero if so.
        if audio.shape[-1] == 0:
            audio = torch.zeros(1, 1)

        # Make the audio stereo/mono and augment by randomly inverting phase.
        augs = torch.nn.Sequential(
            Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
            Mono() if self.force_channels == "mono" else torch.nn.Identity(),
            PhaseFlipper() if self.augment_phase else torch.nn.Identity()
        )

        audio = augs(audio)

        sample["json"]["timestamps"] = (t_start, t_end)

        if "text" in sample["json"]:
            sample["json"]["prompt"] = sample["json"]["text"]

        # Apply any per-dataset custom metadata function.
        for dataset in self.datasets:
            if dataset.custom_metadata_fn is None:
                continue

            if dataset.path in sample["__url__"]:
                custom_metadata = dataset.custom_metadata_fn(sample["json"], audio)
                sample["json"].update(custom_metadata)

        if found_key != rewrite_key:  # drop the long/weird key; "audio" is set below
            del sample[found_key]

        sample["audio"] = audio
        # Add audio to the metadata as well for conditioning.
        sample["json"]["audio"] = audio

        return sample


def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):
    """Build a DataLoader from a config dict with dataset_type 'audio_dir', 's3', or 'wds'."""
    dataset_type = dataset_config.get("dataset_type", None)

    assert dataset_type is not None, "Dataset type must be specified in dataset config"

    if audio_channels == 1:
        force_channels = "mono"
    else:
        force_channels = "stereo"

    if dataset_type == "audio_dir":

        audio_dir_configs = dataset_config.get("datasets", None)

        assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"

        configs = []

        for audio_dir_config in audio_dir_configs:
            audio_dir_path = audio_dir_config.get("path", None)
            assert audio_dir_path is not None, "Path must be set for local audio directory configuration"

            custom_metadata_fn = None
            custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)

            if custom_metadata_module_path is not None:
                # Dynamically load a user-supplied module exposing get_custom_metadata.
                spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
                metadata_module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(metadata_module)

                custom_metadata_fn = metadata_module.get_custom_metadata

            configs.append(
                LocalDatasetConfig(
                    id=audio_dir_config["id"],
                    path=audio_dir_path,
                    custom_metadata_fn=custom_metadata_fn
                )
            )

        train_set = SampleDataset(
            configs,
            sample_rate=sample_rate,
            sample_size=sample_size,
            random_crop=dataset_config.get("random_crop", True),
            force_channels=force_channels
        )

        return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                           num_workers=num_workers, persistent_workers=True,
                                           pin_memory=True, drop_last=True, collate_fn=collation_fn)

    elif dataset_type in ["s3", "wds"]:  # Support "s3" type for backwards compatibility

        wds_configs = []

        for wds_config in dataset_config["datasets"]:

            custom_metadata_fn = None
            custom_metadata_module_path = wds_config.get("custom_metadata_module", None)

            if custom_metadata_module_path is not None:
                # Dynamically load a user-supplied module exposing get_custom_metadata.
                spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
                metadata_module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(metadata_module)

                custom_metadata_fn = metadata_module.get_custom_metadata

            if "s3_path" in wds_config:
                wds_configs.append(
                    S3DatasetConfig(
                        id=wds_config["id"],
                        s3_path=wds_config["s3_path"],
                        custom_metadata_fn=custom_metadata_fn,
                        profile=wds_config.get("profile", None),
                    )
                )
            elif "path" in wds_config:
                wds_configs.append(
                    LocalWebDatasetConfig(
                        id=wds_config["id"],
                        path=wds_config["path"],
                        custom_metadata_fn=custom_metadata_fn
                    )
                )

        return WebDatasetDataLoader(
            wds_configs,
            sample_rate=sample_rate,
            sample_size=sample_size,
            batch_size=batch_size,
            random_crop=dataset_config.get("random_crop", True),
            num_workers=num_workers,
            persistent_workers=True,
            force_channels=force_channels,
            epoch_steps=dataset_config.get("epoch_steps", 2000)
        ).data_loader