Commit 51f4763
Parent(s): 17a2a7d

updated production build to use multiple overlapping samples

Files changed:
  app.py                           +23 -5
  models/config/train_local.yaml    +8 -9
  models/residual.py                +2 -5
  preprocessing/dataset.py          +1 -5
  preprocessing/pipelines.py       +10 -6
app.py  CHANGED

@@ -7,7 +7,7 @@ from functools import cache
 from pathlib import Path
 from models.residual import ResidualDancer
 from models.training_environment import TrainingEnvironment
-from preprocessing.pipelines import SpectrogramProductionPipeline
+from preprocessing.pipelines import SpectrogramProductionPipeline, WaveformPreprocessing
 import torch
 from torch import nn
 import yaml

@@ -17,6 +17,8 @@ CONFIG_FILE = Path("models/weights/ResidualDancer/multilabel/config.yaml")
 
 DANCE_MAPPING_FILE = Path("data/dance_mapping.csv")
 
+MIN_DURATION = 3.0
+
 
 class DancePredictor:
     def __init__(

@@ -37,6 +39,9 @@ class DancePredictor:
         self.labels = np.array(labels)
         self.device = device
         self.model = self.get_model(weight_path)
+        self.process_waveform = WaveformPreprocessing(
+            resample_frequency * expected_duration
+        )
         self.extractor = SpectrogramProductionPipeline()
 
     def get_model(self, weight_path: str) -> nn.Module:

@@ -87,10 +92,21 @@ class DancePredictor:
         waveform = torchaudio.functional.resample(
             waveform, sample_rate, self.resample_frequency
         )
-
-
+        window_size = self.resample_frequency * self.expected_duration
+        n_preds = int(waveform.shape[1] // (window_size / 2))
+        step_size = int(waveform.shape[1] / n_preds)
+
+        inputs = [
+            waveform[:, i * step_size : i * step_size + window_size]
+            for i in range(n_preds)
+        ]
+        features = [self.extractor(window) for window in inputs]
+        features = torch.stack(features).to(self.device)
         results = self.model(features)
-
+        # Convert to probabilities
+        results = nn.functional.softmax(results, dim=1)
+        # Take average prediction over all of the windows
+        results = results.mean(dim=0)
         results = results.detach().cpu().numpy()
 
         result_mask = results > self.threshold

@@ -116,6 +132,9 @@ def predict(audio: tuple[int, np.ndarray]) -> list[str]:
     if audio is None:
         return "Dance Not Found"
     sample_rate, waveform = audio
+    duration = len(waveform) / sample_rate
+    if duration < MIN_DURATION:
+        return f"Please record at least {MIN_DURATION} seconds of audio"
 
     model = get_model(CONFIG_FILE)
     results = model(waveform, sample_rate)

@@ -133,7 +152,6 @@ def demo():
 
     recording_interface = gr.Interface(
         fn=predict,
-        description="Record at least **6 seconds** of the song.",
         inputs=gr.Audio(source="microphone", label="Song Recording"),
         outputs=gr.Label(label="Dances"),
         examples=example_audio,
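The core of this change is the overlapping-window inference above: rather than classifying a single clip, the predictor slices the resampled waveform into roughly half-overlapping windows, runs the model on each, and averages the softmax outputs. A minimal standalone sketch of that window arithmetic follows; the 16 kHz resample frequency and 6-second duration are assumed values, since neither appears in this diff.

import torch

resample_frequency = 16_000   # assumed; not shown in this diff
expected_duration = 6         # assumed; not shown in this diff
waveform = torch.randn(1, 20 * resample_frequency)  # 20 s of mono audio

window_size = resample_frequency * expected_duration
# One prediction per half-window of audio -> roughly 50% overlap.
n_preds = int(waveform.shape[1] // (window_size / 2))
step_size = int(waveform.shape[1] / n_preds)

windows = [
    waveform[:, i * step_size : i * step_size + window_size]
    for i in range(n_preds)
]
print(n_preds, step_size, [w.shape[1] for w in windows])
# -> 6 53333 [96000, 96000, 96000, 96000, 96000, 53335]

Note that the last window can run past the end of the waveform and come back short; the WaveformPreprocessing module constructed in __init__ pads short clips to the expected length, which is presumably why it was added in the same commit.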
    	
models/config/train_local.yaml  CHANGED

@@ -1,12 +1,15 @@
 training_fn: residual.train_residual_dancer
-checkpoint: lightning_logs/version_176/checkpoints/epoch=12-step=40404.ckpt
 device: mps
 seed: 42
 dance_ids: &dance_ids
   - BCH
+  - BOL
   - CHA
-  - JIV
   - ECS
+  - HST
+  - LHP
+  - NC2
+  - JIV
   - QST
   - RMB
   - SFT

@@ -20,8 +23,7 @@ dance_ids: &dance_ids
 data_module:
   batch_size: 128
   num_workers: 10
-
-  test_proportion: 0.001
+  test_proportion: 0.15
 
 datasets:
   preprocessing.dataset.BestBallroomDataset:

@@ -31,7 +33,7 @@ datasets:
 
   preprocessing.dataset.Music4DanceDataset:
     song_data_path: data/songs_cleaned.csv
-    song_audio_path: data/samples 
+    song_audio_path: data/samples
     class_list: *dance_ids
     multi_label: True
    min_votes: 1

@@ -56,7 +58,4 @@ trainer:
   # overfit_batches: 1
 
 training_environment:
-  learning_rate: 0.
-  # loggers:
-  #   models.training_environment.SpectrogramLogger:
-  #     frequency: 100
+  learning_rate: 0.00053
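The dance_ids list is declared with a YAML anchor (&dance_ids) and reused through aliases (*dance_ids), so adding BOL, HST, LHP, and NC2 in one place propagates to every dataset's class_list. A quick demonstration of that mechanism with PyYAML:

import yaml

config = yaml.safe_load("""
dance_ids: &dance_ids
  - BCH
  - BOL
  - CHA
class_list: *dance_ids
""")
# The alias resolves to the same list object parsed at the anchor.
assert config["class_list"] == ["BCH", "BOL", "CHA"]
assert config["class_list"] is config["dance_ids"]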
    	
models/residual.py  CHANGED

@@ -119,14 +119,11 @@ def train_residual_dancer(config: dict):
     data = DanceDataModule(dataset, **config["data_module"])
     model = ResidualDancer(n_classes=len(TARGET_CLASSES), **config["model"])
     label_weights = data.get_label_weights().to(DEVICE)
-    criterion = LabelWeightedBCELoss(
-        label_weights
-    )  # nn.CrossEntropyLoss(label_weights)
+    criterion = LabelWeightedBCELoss(label_weights)
 
     train_env = TrainingEnvironment(model, criterion, config)
     callbacks = [
-
-        cb.EarlyStopping("val/loss", patience=1),
+        cb.EarlyStopping("val/loss", patience=2),
         cb.StochasticWeightAveraging(1e-2),
         cb.RichProgressBar(),
     ]
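Besides the tidied criterion call, the EarlyStopping patience bump from 1 to 2 gives validation loss one extra epoch to recover before training halts. LabelWeightedBCELoss itself is not part of this commit; a common shape for a per-label weighted BCE in multi-label classification looks roughly like the sketch below, which is an assumption about the idea, not a copy of the repo's class.

import torch
from torch import nn

class WeightedBCESketch(nn.Module):
    """Hypothetical stand-in for LabelWeightedBCELoss (not this repo's code)."""

    def __init__(self, label_weights: torch.Tensor):
        super().__init__()
        # (n_classes,) weights, e.g. inverse class frequency
        self.label_weights = label_weights

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Element-wise BCE, then scale each class column by its weight.
        bce = nn.functional.binary_cross_entropy_with_logits(
            logits, targets, reduction="none"
        )
        return (bce * self.label_weights).mean()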
    	
preprocessing/dataset.py  CHANGED

@@ -424,11 +424,7 @@ def record_audio_durations(folder: str):
     music_files = iglob(os.path.join(folder, "**", "*.wav"), recursive=True)
     for file in music_files:
         meta = ta.info(file)
-        durations[file] = meta.num_frames / meta.sample_rate
+        durations[os.path.relpath(file, folder)] = meta.num_frames / meta.sample_rate
 
     with open(os.path.join(folder, "audio_durations.json"), "w") as f:
         json.dump(durations, f)
-
-
-class GTZAN:
-    pass
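The durations dict now keys on paths relative to the scanned folder instead of whatever prefix iglob returned, so audio_durations.json stays valid if the data directory moves. For example (paths hypothetical):

import os

folder = "data/audio"
file = os.path.join(folder, "ballroom", "waltz", "song.wav")
print(os.path.relpath(file, folder))  # ballroom/waltz/song.wav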
    	
preprocessing/pipelines.py  CHANGED

@@ -95,23 +95,27 @@ class WaveformPreprocessing(torch.nn.Module):
         self.expected_sample_length = expected_sample_length
 
     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        c_dim = 1 if len(waveform.shape) == 3 else 0
         # Take out extra channels
-        if waveform.shape[
-            waveform = waveform.mean(
+        if waveform.shape[c_dim] > 1:
+            waveform = waveform.mean(c_dim, keepdim=True)
 
         # ensure it is the correct length
-        waveform = self._rectify_duration(waveform)
+        waveform = self._rectify_duration(waveform, c_dim)
         return waveform
 
-    def _rectify_duration(self, waveform: torch.Tensor):
+    def _rectify_duration(self, waveform: torch.Tensor, channel_dim: int):
         expected_samples = self.expected_sample_length
-        sample_count = waveform.shape[1]
+        sample_count = waveform.shape[channel_dim + 1]
         if expected_samples == sample_count:
             return waveform
         elif expected_samples > sample_count:
             pad_amount = expected_samples - sample_count
             return torch.nn.functional.pad(
-                waveform,
+                waveform,
+                (channel_dim + 1) * [0] + [pad_amount],
+                mode="constant",
+                value=0.0,
             )
         else:
             return waveform[:, :expected_samples]
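The new padding call relies on torch.nn.functional.pad reading its pad list in (left, right) pairs starting from the last dimension, so [0, pad_amount] right-pads the sample axis of a (channels, samples) waveform:

import torch

waveform = torch.randn(1, 90_000)  # (channels, samples), channel_dim == 0
padded = torch.nn.functional.pad(
    waveform, [0, 6_000], mode="constant", value=0.0
)
print(padded.shape)  # torch.Size([1, 96000])

One caveat: F.pad requires an even-length pad list, and (channel_dim + 1) * [0] + [pad_amount] yields three elements when channel_dim == 1, so only the unbatched 2-D path appears to pad correctly; likewise the truncation branch always slices dimension 1.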