import os
import torch
import torchaudio
import torchaudio.transforms as T
from tqdm import tqdm
import numpy as np
from omegaconf import DictConfig
import hydra
from hydra.utils import to_absolute_path
from transformers import Wav2Vec2FeatureExtractor, AutoModel

from encoder.mert import FeatureExtractorMERT
from encoder.music2latent import FeatureExtractorM2L


class AudioProcessor:
    """Extracts encoder features from audio files and saves them as per-segment .npy files."""

    def __init__(self, cfg: DictConfig):
        self.input_directory = cfg.dataset.input_dir
        self.output_directory = cfg.dataset.output_dir
        self.segment_duration = cfg.segment_duration
        self.resample_rate = cfg.model.sr
        self.device_id = cfg.device_id
        self.feature_extractor = self._initialize_extractor(cfg.model.name)
        self.is_split = cfg.is_split

    def _initialize_extractor(self, model_name: str):
        # Route to the matching encoder wrapper based on the configured model name.
        if "MERT" in model_name:
            return FeatureExtractorMERT(model_name=model_name, device_id=self.device_id, sr=self.resample_rate)
        elif model_name == "music2latent":
            return FeatureExtractorM2L(device_id=self.device_id, sr=self.resample_rate)
        else:
            raise NotImplementedError(f"Feature extraction for model {model_name} is not implemented.")

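    # Note: both encoder wrappers are assumed to expose
    # extract_features_from_segment(waveform, sample_rate, save_path); that is the only
    # interface this class relies on (see process_audio_file below).
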
    def resample_waveform(self, waveform, original_sample_rate, target_sample_rate):
        # Resample only when the source rate differs from the target rate.
        if original_sample_rate != target_sample_rate:
            resampler = T.Resample(original_sample_rate, target_sample_rate)
            return resampler(waveform), target_sample_rate
        return waveform, original_sample_rate

    def split_audio(self, waveform, sample_rate):
        # int() guards against a float segment_duration in the config, which would break range().
        segment_samples = int(self.segment_duration * sample_rate)
        total_samples = waveform.size(0)

        # Keep only full-length segments; a trailing remainder shorter than
        # segment_duration is dropped.
        segments = []
        for start in range(0, total_samples, segment_samples):
            end = start + segment_samples
            if end <= total_samples:
                segments.append(waveform[start:end])

        # If the file is shorter than one segment, fall back to the whole waveform.
        if len(segments) == 0:
            segments.append(waveform)

        return segments

    def process_audio_file(self, file_path, output_dir):
        print(f"Processing {file_path}")
        waveform, sample_rate = torchaudio.load(file_path)

        # Down-mix multi-channel audio to mono, drop the channel dimension,
        # and resample to the encoder's expected rate.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0)
        waveform = waveform.squeeze()
        waveform, sample_rate = self.resample_waveform(waveform, sample_rate, self.resample_rate)

        if self.is_split:
            segments = self.split_audio(waveform, sample_rate)
            for i, segment in enumerate(segments):
                segment_save_path = os.path.join(output_dir, f"segment_{i}.npy")
                # Skip segments that were already extracted so interrupted runs can resume.
                if os.path.exists(segment_save_path):
                    continue
                self.feature_extractor.extract_features_from_segment(segment, sample_rate, segment_save_path)
        else:
            segment_save_path = os.path.join(output_dir, "segment_0.npy")
            if not os.path.exists(segment_save_path):
                self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)

    def process_directory(self):
        # Walk the input tree and mirror its structure under the output directory,
        # creating one feature sub-directory per .mp3 file.
        for root, _, files in os.walk(self.input_directory):
            for file in files:
                if file.endswith('.mp3'):
                    file_path = os.path.join(root, file)
                    relative_path = os.path.relpath(file_path, self.input_directory)
                    output_file_dir = os.path.join(self.output_directory, os.path.splitext(relative_path)[0])
                    os.makedirs(output_file_dir, exist_ok=True)
                    self.process_audio_file(file_path, output_file_dir)

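# The Hydra config resolved by main() below (../config/prep_config.yaml relative to this
# file) is expected to provide the fields read in AudioProcessor.__init__. A minimal
# sketch; the key names come from the code above, the values are purely illustrative:
#
#   dataset:
#     input_dir: /path/to/mp3s
#     output_dir: /path/to/features
#   model:
#     name: MERT-v1-95M      # any name containing "MERT", or exactly "music2latent"
#     sr: 24000
#   segment_duration: 30
#   device_id: 0
#   is_split: true
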
@hydra.main(version_base=None, config_path="../config", config_name="prep_config")
def main(cfg: DictConfig):
    processor = AudioProcessor(cfg)
    processor.process_directory()


if __name__ == "__main__":
    main()
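
# Example invocation (the script name is whatever this file is saved as); any of the
# config fields sketched above can be overridden on the command line with Hydra's
# key=value syntax:
#
#   python process_audio.py dataset.input_dir=/data/mp3 dataset.output_dir=/data/features is_split=true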