import json
import os
from os import path
from typing import Iterator

import librosa
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from joblib import dump, load
from sklearn.base import BaseEstimator, ClassifierMixin
from torch import nn
from torch.utils.data import random_split
from tqdm import tqdm

from preprocessing.dataset import get_music4dance_examples

DANCE_INFO_FILE = "data/dance_info.csv"
dance_info_df = pd.read_csv(
    DANCE_INFO_FILE,
    converters={"tempoRange": lambda s: json.loads(s.replace("'", '"'))},
)
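# The converter above assumes each `tempoRange` cell is stored as a Python-style dict
# string, e.g. "{'min': 80, 'max': 120}" (an assumption based on how the column is used
# below); swapping the quotes makes it valid JSON so it parses into a dict.
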

class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
    """
    Trains a separate binary classifier per dance and only applies a dance's
    classifier to songs whose BPM falls inside that dance's tempo range.

    Features:
        - Spectrogram
        - BPM
    """

    def __init__(self, device="cpu", lr=1e-4, verbose=True) -> None:
        self.device = device
        self.verbose = verbose
        self.lr = lr
        self.classifiers = {}
        self.optimizers = {}
        self.criterion = nn.BCELoss()

    def get_valid_dances_from_bpm(self, bpm: float) -> list[str]:
        mask = dance_info_df["tempoRange"].apply(
            lambda interval: interval["min"] <= bpm <= interval["max"]
        )
        return list(dance_info_df["id"][mask])
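    # Illustration only (hypothetical id and range): if dance_info.csv contained a row
    # with id="waltz" and tempoRange={"min": 84, "max": 90}, then
    # get_valid_dances_from_bpm(87.0) would include "waltz" in its result.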

    def fit(self, x, y):
        """
        x: iterable of (spec, bpm) pairs, e.g. the output of features_from_path.
           spec shape should be (channel, freq_bins, sr * time).
        y: sequence of string dance ids, one label per example.
        """
        epoch_loss = 0
        pred_count = 0
        data_loader = zip(x, y)
        if self.verbose:
            data_loader = tqdm(data_loader, total=len(y))
        for (spec, bpm), label in data_loader:
            # find all models that are in the bpm range
            matching_dances = self.get_valid_dances_from_bpm(bpm)
            spec = torch.from_numpy(spec).to(self.device)
            for dance in matching_dances:
                # lazily create a classifier/optimizer pair the first time a dance is seen
                if dance not in self.classifiers or dance not in self.optimizers:
                    classifier = DanceCNN().to(self.device)
                    self.classifiers[dance] = classifier
                    self.optimizers[dance] = torch.optim.Adam(
                        classifier.parameters(), lr=self.lr
                    )
            models = [
                (dance, model, self.optimizers[dance])
                for dance, model in self.classifiers.items()
                if dance in matching_dances
            ]
            for model_i, (dance, model, opt) in enumerate(models, start=1):
                opt.zero_grad()
                output = model(spec)
                target = torch.tensor([float(dance == label)], device=self.device)
                loss = self.criterion(output, target)
                epoch_loss += loss.item()
                pred_count += 1
                loss.backward()
                if self.verbose:
                    data_loader.set_description(
                        f"model: {model_i}/{len(models)}, loss: {loss.item()}"
                    )
                opt.step()
        return self

    def predict(self, x) -> list[str]:
        """x: iterable of (spec, bpm) pairs, as in fit. Returns one dance id per example."""
        results = []
        for spec, bpm in x:
            # only score dances whose tempo range matches and that were actually trained
            matching_dances = [
                d for d in self.get_valid_dances_from_bpm(bpm) if d in self.classifiers
            ]
            spec = torch.from_numpy(spec).to(self.device)
            with torch.no_grad():
                scores = torch.tensor(
                    [self.classifiers[d](spec) for d in matching_dances]
                )
            results.append(matching_dances[scores.argmax()])
        return results

    def save(self, folder: str):
        # Create the output folder
        classifier_path = path.join(folder, "classifier")
        os.makedirs(classifier_path, exist_ok=True)
        # Detach the torch objects so the sklearn wrapper can be pickled without them
        classifiers = self.classifiers
        optimizers = self.optimizers
        criterion = self.criterion
        self.classifiers = None
        self.optimizers = None
        self.criterion = None
        # Save the PyTorch models
        for dance, classifier in classifiers.items():
            torch.save(
                classifier.state_dict(), path.join(classifier_path, dance + ".pth")
            )
        # Save the sklearn wrapper
        dump(self, path.join(folder, "sklearn.joblib"))
        # Reattach the torch objects
        self.classifiers = classifiers
        self.optimizers = optimizers
        self.criterion = criterion

    @staticmethod
    def from_config(folder: str, device="cpu") -> "DanceTreeClassifier":
        # load in the per-dance weights
        model_paths = (
            p for p in os.listdir(path.join(folder, "classifier")) if p.endswith("pth")
        )
        classifiers = {}
        for model_path in model_paths:
            dance = model_path.split(".")[0]
            model = DanceCNN().to(device)
            model.load_state_dict(
                torch.load(path.join(folder, "classifier", model_path), map_location=device)
            )
            classifiers[dance] = model
        wrapper = load(path.join(folder, "sklearn.joblib"))
        wrapper.classifiers = classifiers
        return wrapper
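    # Round-trip sketch (the folder path is a placeholder): after
    # model.save("models/weights/decision_tree"), calling
    # DanceTreeClassifier.from_config("models/weights/decision_tree") reloads the pickled
    # wrapper and re-attaches the per-dance CNN weights.
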

class DanceCNN(nn.Module):
    def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        kernel_size = (3, 9)
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
            nn.Conv2d(16, 32, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
            nn.Conv2d(32, 32, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
            nn.Conv2d(32, 16, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
        )
        embedding_dimension = 16 * 6 * 8
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dimension, 200),
            nn.ReLU(),
            nn.Linear(200, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.cnn(x)
        x = x.flatten() if len(x.shape) == 3 else x.flatten(1)
        return self.classifier(x)
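
# Shape note (assuming the feature pipeline below, not verified against the original
# training run): with librosa's default n_mels=128 and the sample-rate-based padding in
# features_from_path, each 6 s window at sr=16000 enters DanceCNN as a (1, 128, 96000)
# array. Tracing that shape through the four conv(3x9) + maxpool(2x10) stages yields a
# (16, 6, 8) activation map, which matches embedding_dimension = 16 * 6 * 8.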

def features_from_path(
    paths: list[str], audio_window_duration=6, audio_duration=30, resample_freq=16000
) -> Iterator[tuple[np.array, float]]:
    """
    Loads each audio file, estimates its BPM, and yields one normalized
    spectrogram window (with that BPM) per `audio_window_duration` seconds.
    """
    for audio_path in paths:
        waveform, sr = librosa.load(audio_path, mono=True, sr=resample_freq)
        num_frames = audio_window_duration * sr
        tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
        spec = librosa.feature.melspectrogram(y=waveform, sr=sr)
        spec_normalized = (spec - spec.mean()) / spec.std()
        spec_padded = librosa.util.fix_length(
            spec_normalized, size=sr * audio_duration, axis=1
        )
        batched_spec = np.expand_dims(spec_padded, axis=0)
        for i in range(audio_duration // audio_window_duration):
            spec_window = batched_spec[:, :, i * num_frames : (i + 1) * num_frames]
            yield (spec_window, tempo)

def train_decision_tree(config: dict):
    TARGET_CLASSES = config["global"]["dance_ids"]
    DEVICE = config["global"]["device"]
    SEED = config["global"]["seed"]
    EPOCHS = config["trainer"]["min_epochs"]
    song_data_path = config["data_module"]["song_data_path"]
    song_audio_path = config["data_module"]["song_audio_path"]
    pl.seed_everything(SEED, workers=True)
    df = pd.read_csv(song_data_path)
    x, y = get_music4dance_examples(
        df, song_audio_path, class_list=TARGET_CLASSES, multi_label=True
    )
    # Convert y back to string classes
    y = np.array(TARGET_CLASSES)[y.argmax(-1)]
    train_i, test_i = random_split(
        np.arange(len(x)), [0.1, 0.9]
    )  # Temporary split to test efficacy
    train_i, test_i = list(train_i), list(test_i)
    train_paths, train_y = x[train_i], y[train_i]
    model = DanceTreeClassifier(device=DEVICE)
    for epoch in tqdm(range(1, EPOCHS + 1)):
        # Shuffle the training data each epoch
        i = np.arange(len(train_paths))
        np.random.shuffle(i)
        train_paths = train_paths[i]
        train_y = train_y[i]
        train_x = features_from_path(train_paths)
        model.fit(train_x, train_y)
    # Evaluate the model on the held-out examples
    preds = model.predict(features_from_path(x[test_i]))
    accuracy = (np.array(preds) == y[test_i]).mean()
    print(f"{accuracy=}")
    model.save("models/weights/decision_tree")
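
# A minimal usage sketch (not part of the original module): the config keys mirror the
# ones read in train_decision_tree above, but every value here is a placeholder
# assumption rather than the project's real configuration.
if __name__ == "__main__":
    example_config = {
        "global": {
            "dance_ids": ["chacha", "waltz"],  # hypothetical ids from data/dance_info.csv
            "device": "cpu",
            "seed": 42,
        },
        "trainer": {"min_epochs": 1},
        "data_module": {
            "song_data_path": "data/songs.csv",  # assumed location
            "song_audio_path": "data/audio",  # assumed location
        },
    }
    train_decision_tree(example_config)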