Commit ba35f85
Parent(s): 9f53273

updated production weights

Changed files:
- app.py +27 -15
- assets/song-samples/besame_mucho.wav +3 -0
- models/config/production.yaml +20 -0
- models/config/train_local.yaml +6 -6
- models/residual.py +2 -2
- models/weights/ResidualDancer/weights.ckpt +2 -2
- preprocessing/dataset.py +6 -2
- preprocessing/pipelines.py +15 -0
app.py
CHANGED
@@ -2,18 +2,21 @@ from pathlib import Path
 import gradio as gr
 import numpy as np
 import os
+import pandas as pd
 from functools import cache
 from pathlib import Path
-from models.
+from models.residual import ResidualDancer
 from models.training_environment import TrainingEnvironment
+from preprocessing.pipelines import SpectrogramProductionPipeline
 import torch
 from torch import nn
 import yaml
 import torchaudio
 
-CONFIG_FILE = Path("models/config/
-MODEL_CLS =
-
+CONFIG_FILE = Path("models/config/production.yaml")
+MODEL_CLS = ResidualDancer
+
+DANCE_MAPPING_FILE = Path("data/dance_mapping.csv")
 
 
 class DancePredictor:
@@ -22,7 +25,7 @@ class DancePredictor:
         weight_path: str,
         labels: list[str],
         expected_duration=6,
-        threshold=0.
+        threshold=0.1,
         resample_frequency=16000,
         device="cpu",
     ):
@@ -35,11 +38,13 @@ class DancePredictor:
         self.labels = np.array(labels)
         self.device = device
         self.model = self.get_model(weight_path)
-        self.extractor =
+        self.extractor = SpectrogramProductionPipeline()
 
     def get_model(self, weight_path: str) -> nn.Module:
         weights = torch.load(weight_path, map_location=self.device)["state_dict"]
-
+        n_classes = len(self.labels)
+        # NOTE: Channels are not taken into account
+        model = ResidualDancer(n_classes=n_classes).to(self.device)
         for key in list(weights):
             weights[
                 key.replace(
@@ -56,10 +61,12 @@ class DancePredictor:
            config = yaml.safe_load(f)
         weight_path = config["checkpoint"]
         labels = sorted(config["dance_ids"])
-
-
-
-
+        dance_mapping = get_dance_mapping(DANCE_MAPPING_FILE)
+        labels = [dance_mapping[label] for label in labels]
+        expected_duration = config.get("expected_duration", 6)
+        threshold = config.get("threshold", 0.1)
+        resample_frequency = config.get("resample_frequency", 16000)
+        device = config.get("device", "cpu")
         return DancePredictor(
             weight_path,
             labels,
@@ -81,9 +88,6 @@ class DancePredictor:
         waveform = torchaudio.functional.resample(
             waveform, sample_rate, self.resample_frequency
         )
-        waveform = waveform[
-            :, : self.resample_frequency * self.expected_duration
-        ]  # TODO PAD
         features = self.extractor(waveform)
         features = features.unsqueeze(0).to(self.device)
         results = self.model(features)
@@ -103,7 +107,15 @@ def get_model(config_path: str) -> DancePredictor:
    return model
 
 
+@cache
+def get_dance_mapping(mapping_file: str) -> dict[str, str]:
+    mapping_df = pd.read_csv(mapping_file)
+    return {row["id"]: row["name"] for _, row in mapping_df.iterrows()}
+
+
 def predict(audio: tuple[int, np.ndarray]) -> list[str]:
+    if audio is None:
+        return "Dance Not Found"
     sample_rate, waveform = audio
 
     model = get_model(CONFIG_FILE)
@@ -116,7 +128,7 @@ def demo():
    description = "What should I dance to this song? Pass some audio to the Dance Classifier find out!"
    song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
    example_audio = [
-        str(song) for song in song_samples.iterdir() if song.name
+        str(song) for song in song_samples.iterdir() if not song.name.startswith(".")
    ]
    all_dances = get_model(CONFIG_FILE).labels
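Note on the new label mapping: the get_dance_mapping helper above implies a CSV with "id" and "name" columns keyed by the short dance codes used in the configs. A minimal sketch of that lookup (the rows here are illustrative only, not taken from data/dance_mapping.csv):

from io import StringIO
import pandas as pd

# Illustrative rows; the real file lives at data/dance_mapping.csv.
mapping_csv = StringIO("id,name\nCHA,Cha Cha\nSWZ,Slow Waltz\n")
mapping_df = pd.read_csv(mapping_csv)
mapping = {row["id"]: row["name"] for _, row in mapping_df.iterrows()}
print(mapping)  # {'CHA': 'Cha Cha', 'SWZ': 'Slow Waltz'}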
assets/song-samples/besame_mucho.wav
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14ccffab50d9119ec5250fc84e09542dbbf350450102c108ab61846a3c3031c8
+size 5290062
models/config/production.yaml
ADDED
@@ -0,0 +1,20 @@
+checkpoint: models/weights/ResidualDancer/weights.ckpt
+device: cpu
+seed: 42
+dance_ids: &dance_ids
+  - BCH
+  - CHA
+  - JIV
+  - ECS
+  - QST
+  - RMB
+  - SFT
+  - SLS
+  - SMB
+  - SWZ
+  - TGO
+  - VWZ
+  - WCS
+
+model:
+  n_channels: 128
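The predictor's config loader (see the from-config hunk in app.py above) reads this file with yaml.safe_load and falls back to defaults for keys the file does not define, so this minimal production.yaml is enough. A small sketch of that fallback behaviour, assuming the file contents added in this commit:

import yaml

with open("models/config/production.yaml") as f:
    config = yaml.safe_load(f)

# Keys present in the file are used directly...
device = config.get("device", "cpu")                           # "cpu"
# ...while missing keys fall back to the constructor defaults.
expected_duration = config.get("expected_duration", 6)         # 6
threshold = config.get("threshold", 0.1)                       # 0.1
resample_frequency = config.get("resample_frequency", 16000)   # 16000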
models/config/train_local.yaml
CHANGED
@@ -1,5 +1,5 @@
-training_fn:
-checkpoint: lightning_logs/
+training_fn: residual.train_residual_dancer
+checkpoint: lightning_logs/version_176/checkpoints/epoch=12-step=40404.ckpt
 device: mps
 seed: 42
 dance_ids: &dance_ids
@@ -24,10 +24,10 @@ data_module:
   test_proportion: 0.2
 
 datasets:
-
-
-
-
+  preprocessing.dataset.BestBallroomDataset:
+    audio_dir: data/ballroom-songs
+    class_list: *dance_ids
+    audio_window_jitter: 0.7
 
   preprocessing.dataset.Music4DanceDataset:
     song_data_path: data/songs_cleaned.csv
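The *dance_ids alias in the datasets block resolves to the list defined under dance_ids: &dance_ids, so every dataset entry receives the same class list. A self-contained sketch of how that alias behaves under yaml.safe_load (an inline snippet, not the full file):

import yaml

snippet = """
dance_ids: &dance_ids
  - BCH
  - CHA
datasets:
  preprocessing.dataset.BestBallroomDataset:
    class_list: *dance_ids
"""
parsed = yaml.safe_load(snippet)
# The alias is a reference to the same parsed list.
assert parsed["datasets"]["preprocessing.dataset.BestBallroomDataset"]["class_list"] == parsed["dance_ids"]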
models/residual.py
CHANGED
@@ -110,7 +110,7 @@ def train_residual_dancer(config: dict):
     TARGET_CLASSES = config["dance_ids"]
     DEVICE = config["device"]
     SEED = config["seed"]
-    torch.set_float32_matmul_precision(
+    torch.set_float32_matmul_precision("medium")
     pl.seed_everything(SEED, workers=True)
     feature_extractor = SpectrogramTrainingPipeline(**config["feature_extractor"])
     dataset = get_datasets(config["datasets"], feature_extractor)
@@ -123,7 +123,7 @@ def train_residual_dancer(config: dict):
     train_env = TrainingEnvironment(model, criterion, config)
     callbacks = [
         # cb.LearningRateFinder(update_attr=True),
-        cb.EarlyStopping("val/loss", patience=
+        cb.EarlyStopping("val/loss", patience=1),
         cb.StochasticWeightAveraging(1e-2),
         cb.RichProgressBar(),
         cb.DeviceStatsMonitor(),
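For context on the two values filled in above (a note, not part of the commit): torch.set_float32_matmul_precision accepts "highest", "high", or "medium", and "medium" allows lower-precision matmul accumulation on supported hardware in exchange for speed, while EarlyStopping("val/loss", patience=1) stops training after one validation check with no improvement in val/loss.

import torch

torch.set_float32_matmul_precision("medium")
print(torch.get_float32_matmul_precision())  # "medium"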
models/weights/ResidualDancer/weights.ckpt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:90a58841ce4f40f2981227b63dd848e474e8868795a57da84053e3281c4889c7
+size 193643085
preprocessing/dataset.py
CHANGED
@@ -78,8 +78,8 @@ class SongDataset(Dataset):
             return waveform, dance_labels
         else:
             # WARNING: Could cause train/test split leak
-
-
+            print("Invalid output, trying next index...")
+            return self[idx - 1]
 
     def _idx2audio_idx(self, idx: int) -> int:
         return self._get_audio_loc_from_idx(idx)[0]
@@ -424,3 +424,7 @@ def record_audio_durations(folder: str):
 
     with open(os.path.join(folder, "audio_durations.json"), "w") as f:
         json.dump(durations, f)
+
+
+class GTZAN:
+    pass
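The fallback added to SongDataset retries a neighbouring index when a sample is unusable instead of raising. An isolated sketch of the same pattern (class and field names here are illustrative, not from the repo):

from torch.utils.data import Dataset

class FallbackDataset(Dataset):  # hypothetical stand-in for SongDataset
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        if item is not None:
            return item
        # WARNING: Could cause train/test split leak (the same caveat noted in the diff)
        print("Invalid output, trying next index...")
        return self[idx - 1]

Note that the message says "next index" while the code actually steps backwards to idx - 1, and a dataset with no valid items would recurse indefinitely.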
preprocessing/pipelines.py
CHANGED
@@ -74,6 +74,21 @@ class SpectrogramTrainingPipeline(WaveformTrainingPipeline):
         return spec
 
 
+class SpectrogramProductionPipeline(torch.nn.Module):
+    def __init__(self, sample_rate=16000, expected_duration=6, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.preprocess_waveform = WaveformPreprocessing(
+            sample_rate * expected_duration
+        )
+        self.audio_to_spectrogram = AudioToSpectrogram(
+            sample_rate=sample_rate,
+        )
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        waveform = self.preprocess_waveform(waveform)
+        return self.audio_to_spectrogram(waveform)
+
+
 class WaveformPreprocessing(torch.nn.Module):
     def __init__(self, expected_sample_length: int):
         super().__init__()
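Usage sketch for the new SpectrogramProductionPipeline, assuming the defaults shown above (16 kHz audio, 6-second windows); the input shape below is an assumption, and the output shape depends on AudioToSpectrogram, which is defined elsewhere in pipelines.py:

import torch
from preprocessing.pipelines import SpectrogramProductionPipeline

pipeline = SpectrogramProductionPipeline(sample_rate=16000, expected_duration=6)
waveform = torch.randn(1, 16000 * 6)  # assumed: mono audio already resampled to 16 kHz
spectrogram = pipeline(waveform)      # padded/cropped to 6 s, then converted to a spectrogram
print(spectrogram.shape)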