import os

import hydra
import torchaudio
import torchaudio.transforms as T
from omegaconf import DictConfig

from encoder.mert import FeatureExtractorMERT
from encoder.music2latent import FeatureExtractorM2L
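
# Both encoder wrappers are assumed to expose the same minimal interface used
# below (a sketch, not the actual code in encoder/mert.py or encoder/music2latent.py):
#
#     class FeatureExtractorMERT:
#         def __init__(self, model_name: str, device_id: int, sr: int): ...
#         def extract_features_from_segment(self, waveform, sample_rate, save_path):
#             """Encode one mono waveform and save the features to save_path (.npy)."""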

class AudioProcessor:
    def __init__(self, cfg: DictConfig):
        self.input_directory = cfg.dataset.input_dir
        self.output_directory = cfg.dataset.output_dir
        self.segment_duration = cfg.segment_duration  # seconds per segment
        self.resample_rate = cfg.model.sr             # target sample rate expected by the encoder
        self.device_id = cfg.device_id
        self.feature_extractor = self._initialize_extractor(cfg.model.name)
        self.is_split = cfg.is_split                  # split into fixed-length segments vs. encode whole files

    def _initialize_extractor(self, model_name: str):
        # Route on the configured model name; any name containing "MERT" uses the MERT encoder.
        if "MERT" in model_name:
            return FeatureExtractorMERT(model_name=model_name, device_id=self.device_id, sr=self.resample_rate)
        elif model_name == "music2latent":
            return FeatureExtractorM2L(device_id=self.device_id, sr=self.resample_rate)
        else:
            raise NotImplementedError(f"Feature extraction for model {model_name} is not implemented.")

    def resample_waveform(self, waveform, original_sample_rate, target_sample_rate):
        if original_sample_rate != target_sample_rate:
            resampler = T.Resample(original_sample_rate, target_sample_rate)
            return resampler(waveform), target_sample_rate
        return waveform, original_sample_rate

    def split_audio(self, waveform, sample_rate):
        # Samples per fixed-length segment; cast in case segment_duration is a float,
        # since range() below requires an integer step.
        segment_samples = int(self.segment_duration * sample_rate)
        total_samples = waveform.size(0)

        # Keep only full-length segments; a trailing partial chunk is dropped.
        segments = []
        for start in range(0, total_samples, segment_samples):
            end = start + segment_samples
            if end <= total_samples:
                segments.append(waveform[start:end])

        # In case the audio is shorter than one segment, keep the whole waveform.
        if len(segments) == 0:
            segments.append(waveform)

        return segments
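
    # Worked example (illustrative numbers): with segment_duration=30 and a
    # 24_000 Hz sample rate, segment_samples = 720_000, so a 95-second clip
    # yields three full segments and the trailing 5 s are dropped by the
    # end <= total_samples check above.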

    def process_audio_file(self, file_path, output_dir):
        print(f"Processing {file_path}")
        waveform, sample_rate = torchaudio.load(file_path)

        # Down-mix multi-channel audio to mono, then drop the channel dimension
        # so downstream code sees a 1-D waveform.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0)
        waveform = waveform.squeeze()
        waveform, sample_rate = self.resample_waveform(waveform, sample_rate, self.resample_rate)

        if self.is_split:
            segments = self.split_audio(waveform, sample_rate)
            for i, segment in enumerate(segments):
                segment_save_path = os.path.join(output_dir, f"segment_{i}.npy")
                # Skip segments whose features were already extracted.
                if os.path.exists(segment_save_path):
                    continue
                self.feature_extractor.extract_features_from_segment(segment, sample_rate, segment_save_path)
        else:
            # Encode the whole (un-split) waveform as a single segment.
            segment_save_path = os.path.join(output_dir, "segment_0.npy")
            if not os.path.exists(segment_save_path):
                self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)

    def process_directory(self):
        # Walk the input tree and mirror its structure under the output directory:
        # one sub-directory of .npy segment files per input .mp3.
        for root, _, files in os.walk(self.input_directory):
            for file in files:
                if file.endswith('.mp3'):
                    file_path = os.path.join(root, file)
                    relative_path = os.path.relpath(file_path, self.input_directory)
                    output_file_dir = os.path.join(self.output_directory, os.path.splitext(relative_path)[0])
                    os.makedirs(output_file_dir, exist_ok=True)
                    self.process_audio_file(file_path, output_file_dir)
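
# A minimal ../config/prep_config.yaml consistent with the fields read above
# might look like this (an assumption; the values are illustrative):
#
#     dataset:
#       input_dir: /path/to/mp3s
#       output_dir: /path/to/features
#     model:
#       name: m-a-p/MERT-v1-95M   # or "music2latent"
#       sr: 24000
#     segment_duration: 30
#     device_id: 0
#     is_split: true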

@hydra.main(version_base=None, config_path="../config", config_name="prep_config")
def main(cfg: DictConfig):
    processor = AudioProcessor(cfg)
    processor.process_directory()

if __name__ == "__main__":
    main()
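
# Example invocation (a sketch; the script filename and override values are illustrative):
#
#     python preprocess.py dataset.input_dir=/data/mp3 dataset.output_dir=/data/features \
#         model.name=m-a-p/MERT-v1-95M model.sr=24000 segment_duration=30 \
#         is_split=true device_id=0
#
# Hydra resolves ../config/prep_config.yaml relative to this file and applies
# any key=value overrides given on the command line.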