Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import numpy as np | |
| import re | |
| import json | |
| from pathlib import Path | |
| import glob | |
| import os | |
| import shutil | |
| import torchaudio | |
| import torch | |
| from tqdm import tqdm | |
def url_to_filename(url: str) -> str:
    """Map an audio URL to its local filename: the last ``/``-separated segment plus ``.wav``."""
    *_, last_segment = url.split("/")
    return last_segment + ".wav"
def has_valid_audio(audio_urls: pd.Series, audio_dir: str) -> pd.Series:
    """
    Flag which audio URLs have a matching downloaded file in ``audio_dir``.

    Only file existence is checked (see ``validate_audio`` for content checks).

    Args:
        audio_urls: URLs, possibly containing missing values or the "." placeholder.
        audio_dir: Directory holding the downloaded ``<last-url-segment>.wav`` files.

    Returns:
        Boolean Series aligned with ``audio_urls``. An entry is True when the URL
        is a real string and ``url_to_filename(url)`` exists in ``audio_dir``.
    """
    audio_files = {entry.name for entry in Path(audio_dir).iterdir()}

    def _has_file(url) -> bool:
        # Identity checks against np.nan miss NaNs that are distinct float objects
        # and crash on None, so test the type instead. "." marks a missing URL.
        if not isinstance(url, str) or url == ".":
            return False
        return url_to_filename(url) in audio_files

    return audio_urls.apply(_has_file)
def validate_audio(audio_urls: pd.Series, audio_dir: str) -> pd.Series:
    """
    Test audio URLs to ensure their file exists and the contents are valid.

    A URL is valid when all of the following hold:
      * it is a string containing "http";
      * ``url_to_filename(url)`` exists in ``audio_dir``;
      * ``torchaudio.load`` can decode the file;
      * every sample is finite (no NaN/Inf);
      * the waveform has more than two distinct sample values.

    Args:
        audio_urls: URLs to check; non-string entries are reported as invalid.
        audio_dir: Directory holding the downloaded ``.wav`` files.

    Returns:
        Boolean Series with the same index as ``audio_urls``.
    """
    audio_files = {entry.name for entry in Path(audio_dir).iterdir()}

    def is_valid(url) -> bool:
        if not (isinstance(url, str) and "http" in url):
            return False
        filename = url_to_filename(url)
        if filename not in audio_files:
            return False
        try:
            waveform, _ = torchaudio.load(os.path.join(audio_dir, filename))
        except Exception:
            # Undecodable/corrupt files count as invalid. Catching Exception (not a
            # bare except) lets KeyboardInterrupt/SystemExit still stop the scan.
            return False
        if not torch.isfinite(waveform).all():
            return False
        # <= 2 distinct values presumably means silence or a degenerate signal.
        return len(torch.unique(waveform)) > 2

    validations = [
        is_valid(url)
        for url in tqdm(audio_urls, total=len(audio_urls), desc="Audio URLs Validated")
    ]
    return pd.Series(validations, index=audio_urls.index)
def fix_dance_rating_counts(dance_ratings: pd.Series) -> pd.Series:
    """
    Parse dance-rating strings and fold signed tag modifiers into plain counts.

    Each value is a dict literal written with single quotes, e.g.
    ``"{'CHA+2': 1, 'SWZ': 3}"``.
      * A key containing ``<letters><+|-><digits>`` contributes
        ``count * digits * sign`` to ``<letters>``.
      * Any other key has its count added unchanged.

    Args:
        dance_ratings: Series of single-quoted dict strings.

    Returns:
        Series of ``{label: count}`` dicts. Rows where no label ends up with a
        positive count become NaN.
    """
    # Raw string: "\+" / "\d" in a normal literal are invalid escape sequences.
    tag_pattern = re.compile(r"([A-Za-z]+)(\+|-)(\d+)")
    # NOTE(review): quote swapping breaks if a key contains an apostrophe.
    parsed = dance_ratings.apply(lambda raw: json.loads(raw.replace("'", '"')))

    def fix_labels(labels: dict) -> "dict | float":
        new_labels: dict = {}
        for key, count in labels.items():
            match = tag_pattern.search(key)
            if match is None:
                new_labels[key] = new_labels.get(key, 0) + count
                continue
            dance = match[1]
            sign = 1 if match[2] == "+" else -1
            scale = int(match[3])
            new_labels[dance] = new_labels.get(dance, 0) + count * scale * sign
        has_positive = any(total > 0 for total in new_labels.values())
        return new_labels if has_positive else np.nan

    return parsed.apply(fix_labels)
def get_unique_labels(dance_labels: pd.Series) -> list:
    """Return the sorted union of the labels (dict keys) found in every entry of ``dance_labels``."""
    return sorted(set().union(*dance_labels))
def vectorize_label_probs(
    labels: dict[str, int], unique_labels: np.ndarray
) -> np.ndarray:
    """
    Turn a label-count dict into a probability vector over ``unique_labels``.

    Steps: place each count at its label's position, clip negative totals to 0,
    then normalize so the vector sums to 1. Keys absent from ``unique_labels``
    are ignored.

    Args:
        labels: Mapping of label -> count.
        unique_labels: Label vocabulary. Any array-like is accepted; e.g. the list
            returned by ``get_unique_labels`` is coerced to an array.

    Returns:
        float32 array of length ``len(unique_labels)`` summing to 1.

    Raises:
        AssertionError: If no label in the vocabulary has a positive (finite) total.
    """
    # With a plain list, `list == label` is a single False, not an elementwise mask.
    vocab = np.asarray(unique_labels)
    label_vec = np.zeros((len(vocab),), dtype="float32")
    for label, count in labels.items():
        label_vec += (vocab == label) * count
    label_vec[label_vec < 0] = 0
    total = label_vec.sum()
    # Validate before dividing (avoids a 0/0 RuntimeWarning). Raise explicitly so the
    # check survives `python -O`; AssertionError is kept for backward compatibility.
    if not (np.isfinite(total) and total > 0):
        raise AssertionError(f"Provided labels are invalid: {labels}")
    label_vec /= total
    return label_vec
def vectorize_multi_label(
    labels: dict[str, int], unique_labels: np.ndarray
) -> np.ndarray:
    """
    Turn a label-count dict into a binary vector for multi-label classification.

    Every position whose normalized probability (from ``vectorize_label_probs``)
    is positive becomes 1.0; all other entries keep their value (0.0).
    """
    probs = vectorize_label_probs(labels, unique_labels)
    return np.where(probs > 0.0, probs.dtype.type(1.0), probs)
def sort_yt_files(
    aliases_path="data/dance_aliases.json",
    all_dances_folder="data/best-ballroom-music",
    original_location="data/yt-ballroom-music/",
    dest_folder=os.path.join("data", "ballroom-songs"),
    bad_files_path="data/bad_files.json",
):
    """
    Copy each ``.wav`` in ``all_dances_folder`` into a per-dance subfolder.

    The dance is found by checking whether any normalized alias from the
    ``aliases_path`` JSON (``{dance_id: [alias, ...]}``) appears in the
    normalized file name. If nothing matches, the file's counterpart under
    ``original_location`` is located and its relative path (folder names
    included) is searched instead.

    Files that end up with zero or several dance ids are not copied. Their
    names are written as a JSON list to ``bad_files_path``.

    Args:
        aliases_path: JSON file mapping dance ids to alias lists.
        all_dances_folder: Folder with the ``.wav`` files to sort.
        original_location: Root of the original download tree used as fallback.
        dest_folder: Output root; files go to ``dest_folder/<DANCE_ID>/``.
        bad_files_path: Where the list of unsortable file names is written.
    """

    def normalize_string(s: str) -> str:
        # Lowercase string and remove non-word characters ("Cha-Cha" -> "chacha").
        return re.sub(r"\W+", "", s.lower())

    def find_dance_ids(normalized_name: str) -> list[str]:
        # Every dance id with at least one alias contained in the normalized name.
        return [
            dance_id
            for dance_id, aliases in normalized_dances.items()
            if any(alias in normalized_name for alias in aliases)
        ]

    with open(aliases_path, "r", encoding="utf-8") as f:
        dances = json.load(f)
    normalized_dances = {
        normalize_string(dance_id): [normalize_string(alias) for alias in aliases]
        for dance_id, aliases in dances.items()
    }

    bad_files = []
    progress_bar = tqdm(os.listdir(all_dances_folder), unit="files moved")
    for file_name in progress_bar:
        if not file_name.endswith(".wav"):
            continue
        matching_dance_ids = find_dance_ids(normalize_string(file_name))
        if not matching_dance_ids:
            # See if the dance is named somewhere in the original path. Escape the
            # name: YouTube titles often contain "[" / "]", which glob would
            # otherwise treat as a character class and never match.
            original_filename = file_name.removesuffix(".wav")
            matches = glob.glob(
                os.path.join(
                    glob.escape(original_location), "**", glob.escape(original_filename)
                ),
                recursive=True,
            )
            if len(matches) == 1:
                # Search only below original_location so the root folder's own
                # name cannot produce a spurious alias match for every file.
                relative_path = os.path.relpath(matches[0], original_location)
                matching_dance_ids = find_dance_ids(normalize_string(relative_path))
        # Disambiguation rules: when both "swz" and "vwz" match, keep "vwz";
        # "lhp" is dropped whenever it is not the only match.
        if "swz" in matching_dance_ids and "vwz" in matching_dance_ids:
            matching_dance_ids.remove("swz")
        if len(matching_dance_ids) > 1 and "lhp" in matching_dance_ids:
            matching_dance_ids.remove("lhp")
        if len(matching_dance_ids) != 1:
            bad_files.append(file_name)
            progress_bar.set_description(f"bad files: {len(bad_files)}")
            continue
        dst = os.path.join(dest_folder, matching_dance_ids[0].upper())
        os.makedirs(dst, exist_ok=True)
        shutil.copy(
            os.path.join(all_dances_folder, file_name), os.path.join(dst, file_name)
        )

    with open(bad_files_path, "w", encoding="utf-8") as f:
        json.dump(bad_files, f)
if __name__ == "__main__":
    # Script entry point: sort the downloaded audio using the default paths.
    sort_yt_files()