# music2emo/preprocess/feature_extractor.py
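"""Preprocessing script for music2emo.

Walks a directory tree of .mp3 files, loads each file, downmixes to mono,
resamples to the model's sample rate, and extracts features with either a
MERT or a music2latent encoder (both provided by the local ``encoder``
package). Audio can optionally be split into fixed-length segments first
(``cfg.is_split``); each segment's features are saved as ``segment_<i>.npy``
in a per-track output folder that mirrors the input directory layout.
"""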
import os

import torchaudio
import torchaudio.transforms as T
import hydra
from omegaconf import DictConfig

from encoder.mert import FeatureExtractorMERT
from encoder.music2latent import FeatureExtractorM2L

class AudioProcessor:
    def __init__(self, cfg: DictConfig):
        self.input_directory = cfg.dataset.input_dir
        self.output_directory = cfg.dataset.output_dir
        self.segment_duration = cfg.segment_duration
        self.resample_rate = cfg.model.sr
        self.device_id = cfg.device_id
        self.feature_extractor = self._initialize_extractor(cfg.model.name)
        self.is_split = cfg.is_split

    def _initialize_extractor(self, model_name: str):
        if "MERT" in model_name:
            return FeatureExtractorMERT(model_name=model_name, device_id=self.device_id, sr=self.resample_rate)
        elif model_name == "music2latent":
            return FeatureExtractorM2L(device_id=self.device_id, sr=self.resample_rate)
        else:
            raise NotImplementedError(f"Feature extraction for model {model_name} is not implemented.")

    def resample_waveform(self, waveform, original_sample_rate, target_sample_rate):
        if original_sample_rate != target_sample_rate:
            resampler = T.Resample(original_sample_rate, target_sample_rate)
            return resampler(waveform), target_sample_rate
        return waveform, original_sample_rate

    def split_audio(self, waveform, sample_rate):
        # Cast to int so range() accepts the step even when
        # segment_duration is given as a float in the config.
        segment_samples = int(self.segment_duration * sample_rate)
        total_samples = waveform.size(0)
        segments = []
        for start in range(0, total_samples, segment_samples):
            end = start + segment_samples
            # Keep only full-length segments; a trailing remainder is dropped.
            if end <= total_samples:
                segments.append(waveform[start:end])
        # If the audio is shorter than one segment, keep the whole waveform.
        if len(segments) == 0:
            segments.append(waveform)
        return segments

    def process_audio_file(self, file_path, output_dir):
        print(f"Processing {file_path}")
        waveform, sample_rate = torchaudio.load(file_path)
        # Downmix multi-channel audio to mono, then drop the channel dim.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze()
        waveform, sample_rate = self.resample_waveform(waveform, sample_rate, self.resample_rate)
        if self.is_split:
            segments = self.split_audio(waveform, sample_rate)
            for i, segment in enumerate(segments):
                segment_save_path = os.path.join(output_dir, f"segment_{i}.npy")
                # Skip segments whose features were already extracted.
                if os.path.exists(segment_save_path):
                    continue
                self.feature_extractor.extract_features_from_segment(segment, sample_rate, segment_save_path)
        else:
            segment_save_path = os.path.join(output_dir, "segment_0.npy")
            if not os.path.exists(segment_save_path):
                self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)

    def process_directory(self):
        for root, _, files in os.walk(self.input_directory):
            for file in files:
                if file.endswith('.mp3'):
                    file_path = os.path.join(root, file)
                    # Mirror the input tree under the output directory,
                    # one folder per track (extension stripped).
                    relative_path = os.path.relpath(file_path, self.input_directory)
                    output_file_dir = os.path.join(self.output_directory, os.path.splitext(relative_path)[0])
                    os.makedirs(output_file_dir, exist_ok=True)
                    self.process_audio_file(file_path, output_file_dir)


@hydra.main(version_base=None, config_path="../config", config_name="prep_config")
def main(cfg: DictConfig):
    processor = AudioProcessor(cfg)
    processor.process_directory()


if __name__ == "__main__":
    main()
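
# Example invocation (paths are placeholders; the actual defaults live in
# ../config/prep_config.yaml and can be overridden on the command line with
# standard Hydra key=value syntax):
#
#   python feature_extractor.py dataset.input_dir=/data/mp3s \
#       dataset.output_dir=/data/features is_split=true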