import pytorch_lightning as pl
from pytorch_lightning import callbacks as cb
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchaudio
import yaml

from models.training_environment import TrainingEnvironment
from preprocessing.dataset import DanceDataModule, get_datasets
from preprocessing.pipelines import (
    SpectrogramTrainingPipeline,
    WaveformPreprocessing,
)


# Architecture based on: https://github.com/minzwon/sota-music-tagging-models/blob/36aa13b7205ff156cf4dcab60fd69957da453151/training/model.py
class ResidualDancer(nn.Module):
    def __init__(self, n_channels=128, n_classes=50):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Spectrogram
        self.spec_bn = nn.BatchNorm2d(1)

        # CNN
        self.res_layers = nn.Sequential(
            ResBlock(1, n_channels, stride=2),
            ResBlock(n_channels, n_channels, stride=2),
            ResBlock(n_channels, n_channels * 2, stride=2),
            ResBlock(n_channels * 2, n_channels * 2, stride=2),
            ResBlock(n_channels * 2, n_channels * 2, stride=2),
            ResBlock(n_channels * 2, n_channels * 2, stride=2),
            ResBlock(n_channels * 2, n_channels * 4, stride=2),
        )

        # Dense
        self.dense1 = nn.Linear(n_channels * 4, n_channels * 4)
        self.bn = nn.BatchNorm1d(n_channels * 4)
        self.dense2 = nn.Linear(n_channels * 4, n_classes)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        # x: (batch, 1, freq, time) spectrogram
        x = self.spec_bn(x)

        # CNN: the seven stride-2 blocks are expected to collapse the
        # frequency axis to 1
        x = self.res_layers(x)
        x = x.squeeze(2)

        # Global max pooling over the remaining time axis
        if x.size(-1) != 1:
            x = nn.MaxPool1d(x.size(-1))(x)
        x = x.squeeze(2)

        # Dense
        x = self.dense1(x)
        x = self.bn(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.dense2(x)
        # x = nn.Sigmoid()(x)  # left to the loss in training and to the predictor at inference

        return x
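
# Shape sanity check (illustrative sketch, not from the original source;
# assumes 128 mel bins and a clip long enough for seven stride-2 halvings):
#
#     model = ResidualDancer(n_channels=128, n_classes=10)
#     logits = model(torch.randn(8, 1, 128, 768))  # -> (8, 10)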


class ResBlock(nn.Module):
    def __init__(self, input_channels, output_channels, shape=3, stride=2):
        super().__init__()
        # convolution
        self.conv_1 = nn.Conv2d(
            input_channels, output_channels, shape, stride=stride, padding=shape // 2
        )
        self.bn_1 = nn.BatchNorm2d(output_channels)
        self.conv_2 = nn.Conv2d(
            output_channels, output_channels, shape, padding=shape // 2
        )
        self.bn_2 = nn.BatchNorm2d(output_channels)

        # residual
        self.diff = False
        if (stride != 1) or (input_channels != output_channels):
            self.conv_3 = nn.Conv2d(
                input_channels,
                output_channels,
                shape,
                stride=stride,
                padding=shape // 2,
            )
            self.bn_3 = nn.BatchNorm2d(output_channels)
            self.diff = True
        self.relu = nn.ReLU()

    def forward(self, x):
        # convolution
        out = self.relu(self.bn_1(self.conv_1(x)))
        out = self.bn_2(self.conv_2(out))

        # residual: project the input when its shape differs from the output
        if self.diff:
            x = self.bn_3(self.conv_3(x))
        out = x + out
        out = self.relu(out)
        return out


class DancePredictor:
    def __init__(
        self,
        weight_path: str,
        labels: list[str],
        expected_duration=6,
        threshold=0.5,
        resample_frequency=16000,
        device="cpu",
    ):
        super().__init__()
        self.expected_duration = expected_duration
        self.threshold = threshold
        self.resample_frequency = resample_frequency
        self.preprocess_waveform = WaveformPreprocessing(
            resample_frequency * expected_duration
        )
        # Identity placeholder until the real spectrogram transform is wired in
        self.audio_to_spectrogram = lambda x: x  # TODO: Fix
        self.labels = np.array(labels)
        self.device = device
        self.model = self.get_model(weight_path)

    def get_model(self, weight_path: str) -> nn.Module:
        weights = torch.load(weight_path, map_location=self.device)["state_dict"]
        model = ResidualDancer(n_classes=len(self.labels))
        # Strip the "model." prefix that the Lightning training wrapper
        # prepends to checkpoint keys
        for key in list(weights):
            weights[key.replace("model.", "")] = weights.pop(key)
        model.load_state_dict(weights)
        return model.to(self.device).eval()

    @classmethod
    def from_config(cls, config_path: str) -> "DancePredictor":
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        return cls(**config)
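
    # Illustrative config file for from_config (keys mirror __init__; the
    # values below are assumptions, not shipped defaults):
    #
    #     weight_path: weights/residual_dancer.ckpt
    #     labels: [salsa, tango, waltz]
    #     expected_duration: 6
    #     threshold: 0.5
    #     resample_frequency: 16000
    #     device: cpu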

    def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
        # Coerce the waveform into channels-first (channels, samples) layout
        if len(waveform.shape) > 1 and waveform.shape[1] < waveform.shape[0]:
            waveform = waveform.transpose(1, 0)
        elif len(waveform.shape) == 1:
            waveform = np.expand_dims(waveform, 0)
        waveform = torch.from_numpy(waveform.astype("int16"))
        waveform = torchaudio.functional.apply_codec(
            waveform, sample_rate, "wav", channels_first=True
        )
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, self.resample_frequency
        )
        waveform = self.preprocess_waveform(waveform)
        spectrogram = self.audio_to_spectrogram(waveform)
        spectrogram = spectrogram.unsqueeze(0).to(self.device)

        results = self.model(spectrogram)
        # The model outputs logits (the sigmoid in forward is commented out),
        # so squash them before comparing against the probability threshold
        results = torch.sigmoid(results)
        results = results.squeeze(0).detach().cpu().numpy()
        result_mask = results > self.threshold
        probs = results[result_mask]
        dances = self.labels[result_mask]
        return {dance: float(prob) for dance, prob in zip(dances, probs)}
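
# Example usage (illustrative; the config path and result values are
# assumptions, not files or outputs shipped with this module):
#
#     predictor = DancePredictor.from_config("config/predictor.yaml")
#     results = predictor(waveform, sample_rate)  # e.g. {"salsa": 0.91}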


def train_residual_dancer(config: dict):
    TARGET_CLASSES = config["dance_ids"]
    DEVICE = config["device"]
    SEED = config["seed"]
    pl.seed_everything(SEED, workers=True)

    feature_extractor = SpectrogramTrainingPipeline(**config["feature_extractor"])
    dataset = get_datasets(config["datasets"], feature_extractor)
    data = DanceDataModule(dataset, **config["data_module"])
    model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])

    label_weights = data.get_label_weights().to(DEVICE)
    criterion = nn.CrossEntropyLoss(label_weights)
    train_env = TrainingEnvironment(model, criterion, config)

    callbacks = [
        # cb.LearningRateFinder(update_attr=True),
        cb.EarlyStopping("val/loss", patience=5),
        cb.StochasticWeightAveraging(1e-2),
        cb.RichProgressBar(),
        cb.DeviceStatsMonitor(),
    ]
    trainer = pl.Trainer(callbacks=callbacks, **config["trainer"])
    trainer.fit(train_env, datamodule=data)
    trainer.test(train_env, datamodule=data)
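

if __name__ == "__main__":
    # Minimal entry-point sketch (not part of the original module); the
    # config path is an assumption, so point it at the real training config.
    with open("config/train.yaml") as f:
        train_residual_dancer(yaml.safe_load(f))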