import json
from typing import Iterator

import librosa
import numpy as np
import pandas as pd
import torch
from sklearn.base import BaseEstimator, ClassifierMixin
from torch import nn
from tqdm import tqdm

DANCE_INFO_FILE = "data/dance_info.csv"
dance_info_df = pd.read_csv(
    DANCE_INFO_FILE,
    converters={"tempoRange": lambda s: json.loads(s.replace("'", '"'))},
)
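
# Illustrative note (an assumption about the data format, not taken from the dataset itself):
# the converter above expects each "tempoRange" cell to hold a single-quoted dict string,
# which it turns into a Python dict by swapping quotes before json.loads, e.g.
#
#     json.loads("{'min': 84, 'max': 90}".replace("'", '"'))  # -> {"min": 84, "max": 90}
#
# so every row's tempoRange becomes a dict with "min"/"max" BPM bounds.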


class DanceTreeClassifier(BaseEstimator, ClassifierMixin):
    """
    Trains one binary classifier per dance and only consults a dance's classifier
    when a song's BPM falls inside that dance's tempo range.

    Features:
        - Spectrogram
        - BPM
    """

    def __init__(self, device="cpu", lr=1e-4, epochs=5, verbose=True) -> None:
        self.device = device
        self.epochs = epochs
        self.verbose = verbose
        self.lr = lr
        self.classifiers = {}
        self.optimizers = {}
        self.criterion = nn.BCELoss()

    def get_valid_dances_from_bpm(self, bpm: float) -> list[str]:
        """Return the ids of all dances whose tempo range contains the given bpm."""
        mask = dance_info_df["tempoRange"].apply(
            lambda interval: interval["min"] <= bpm <= interval["max"]
        )
        return list(dance_info_df["id"][mask])
    def fit(self, x, y):
        """
        x: iterable of (spectrogram, bpm) pairs. Each spectrogram has shape
           (channel, freq_bins, sr * time).
        y: iterable of dance label ids (strings), one per sample.
        """
        progress_bar = tqdm(range(self.epochs))
        for _ in progress_bar:
            # TODO: Introduce batches
            epoch_loss = 0
            pred_count = 0
            step = 0
            for (spec, bpm), label in zip(x, y):
                step += 1
                # Find all dances whose tempo range contains this song's bpm.
                matching_dances = self.get_valid_dances_from_bpm(bpm)
                spec = torch.from_numpy(spec).to(self.device)
                # Lazily create a classifier and optimizer for any dance seen for the first time.
                for dance in matching_dances:
                    if dance not in self.classifiers or dance not in self.optimizers:
                        classifier = DanceCNN().to(self.device)
                        self.classifiers[dance] = classifier
                        self.optimizers[dance] = torch.optim.Adam(
                            classifier.parameters(), lr=self.lr
                        )
                models = [
                    (dance, model, self.optimizers[dance])
                    for dance, model in self.classifiers.items()
                    if dance in matching_dances
                ]
                # Each matching classifier gets a positive target only if it is the true dance.
                for model_i, (dance, model, opt) in enumerate(models):
                    opt.zero_grad()
                    output = model(spec)
                    target = torch.tensor([float(dance == label)], device=self.device)
                    loss = self.criterion(output, target)
                    epoch_loss += loss.item()
                    pred_count += 1
                    loss.backward()
                    opt.step()
                    progress_bar.set_description(
                        f"Loss: {epoch_loss / pred_count}, Step: {step}, Model: {model_i + 1}/{len(models)}"
                    )
    def predict(self, x) -> list[str]:
        """
        x: (specs, bpms) — a sequence of spectrograms and a sequence of bpms.
        Returns the highest-scoring matching dance id for each sample.
        """
        results = []
        with torch.no_grad():
            for spec, bpm in zip(*x):
                matching_dances = self.get_valid_dances_from_bpm(bpm)
                spec = torch.from_numpy(spec).to(self.device)
                # Score the sample with every classifier in the bpm range and keep the best one.
                scores = torch.cat(
                    [self.classifiers[dance](spec) for dance in matching_dances]
                )
                results.append(matching_dances[int(scores.argmax())])
        return results


class DanceCNN(nn.Module):
    def __init__(self, sr=16000, freq_bins=20, duration=6, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        kernel_size = (3, 9)
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
            nn.Conv2d(16, 32, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
            nn.Conv2d(32, 32, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
            nn.Conv2d(32, 16, kernel_size=kernel_size),
            nn.ReLU(),
            nn.MaxPool2d((2, 10)),
        )

        # Flattened size of the CNN output (16 channels x 6 x 8) for the spectrogram
        # windows produced by features_from_path below.
        embedding_dimension = 16 * 6 * 8
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dimension, 200),
            nn.ReLU(),
            nn.Linear(200, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.cnn(x)
        # Flatten to a single feature vector (unbatched input) or to (batch, features).
        x = x.flatten() if len(x.shape) == 3 else x.flatten(1)
        return self.classifier(x)
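

# A minimal, optional sanity check (not part of the original module): it assumes a spectrogram
# window shaped like the output of features_from_path below, i.e. (1, 128, 96000) for
# (channel, mel_bins, 6 s * 16 kHz), and verifies that DanceCNN maps it to a single sigmoid
# score. Call it manually; nothing invokes it.
def _check_dance_cnn_shapes() -> None:
    model = DanceCNN()
    dummy_window = torch.zeros(1, 128, 96000)  # assumed window shape; adjust if features change
    with torch.no_grad():
        score = model(dummy_window)
    assert score.shape == (1,), f"unexpected output shape: {tuple(score.shape)}"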


def features_from_path(
    paths: list[str], audio_window_duration=6, audio_duration=30, resample_freq=16000
) -> Iterator[tuple[np.ndarray, float]]:
    """
    Loads each audio file and yields (spectrogram window, bpm) pairs,
    one pair per audio_window_duration-second window.
    """
    for path in paths:
        waveform, sr = librosa.load(path, mono=True, sr=resample_freq)
        num_frames = audio_window_duration * sr
        tempo, _ = librosa.beat.beat_track(y=waveform, sr=sr)
        spec = librosa.feature.melspectrogram(y=waveform, sr=sr)
        mfccs = librosa.feature.mfcc(y=waveform, sr=sr, n_mfcc=20)  # currently unused
        # Normalize, then pad the spectrogram so every clip covers audio_duration seconds.
        spec_normalized = (spec - spec.mean()) / spec.std()
        spec_padded = librosa.util.fix_length(
            spec_normalized, size=sr * audio_duration, axis=1
        )
        batched_spec = np.expand_dims(spec_padded, axis=0)
        # Slice the padded spectrogram into fixed-length windows, one sample each.
        for i in range(audio_duration // audio_window_duration):
            spec_window = batched_spec[:, :, i * num_frames : (i + 1) * num_frames]
            yield (spec_window, tempo)
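

# A rough usage sketch (an assumption, not part of the original module): the audio path and the
# "waltz" label below are hypothetical placeholders. Labels must be dance ids that appear in
# data/dance_info.csv, repeated once per 6-second window yielded by features_from_path.
if __name__ == "__main__":
    train_paths = ["data/audio/example_song.wav"]  # hypothetical path
    labels = ["waltz"] * 5  # hypothetical dance id, one label per 30 s // 6 s window
    features = list(features_from_path(train_paths))

    model = DanceTreeClassifier(device="cpu", epochs=1)
    model.fit(features, labels)
    print(model.predict(([spec for spec, _ in features], [bpm for _, bpm in features])))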