|
import os |
|
import sys |
|
import time |
|
import torchaudio |
|
import torch |
|
from torch import nn |
|
from scipy import signal |
|
from scipy.io import wavfile |
|
import numpy as np |
|
import multiprocessing |
|
from pydub import AudioSegment |
|
|
|
# Force the "spawn" start method process-wide; force=True overrides any start
# method that was already selected before this module was imported.
multiprocessing.set_start_method("spawn", force=True)

# Add the current working directory to the import path so the project-local
# `rvc` package imported right after this resolves (presumably the script is
# launched from the repository root -- confirm against the launcher).
now_directory = os.getcwd()
sys.path.append(now_directory)
|
|
|
from rvc.lib.utils import load_audio |
|
from rvc.train.slicer import Slicer |
|
|
|
|
|
# Seconds shared by consecutive clips cut from one slice: the window start
# advances by (per - OVERLAP) seconds on every step.
OVERLAP = 0.3
# Peak level targeted by the normalised component of a clip.
MAX_AMPLITUDE = 0.9
# Blend weight in normalisation: ALPHA of the peak-normalised signal plus
# (1 - ALPHA) of the untouched signal.
ALPHA = 0.75
# Cutoff (Hz) of the high-pass Butterworth filter applied to every input file.
HIGH_PASS_CUTOFF = 48
# Target sample rate (Hz) of the resampled copy of every clip.
SAMPLE_RATE_16K = 16000
|
|
|
|
|
class PreProcess:
    """Cut training audio into fixed-length clips and save them at two sample rates.

    Every input file is high-pass filtered, split by ``Slicer``, cut into
    overlapping windows of ``per`` seconds, normalised, and written as float32
    WAV to ``<exp_dir>/sliced_audios`` (at ``sr``) and, resampled to 16 kHz, to
    ``<exp_dir>/sliced_audios_16k``.
    """

    def __init__(self, sr: int, exp_dir: str, per: float):
        """Configure the pipeline and create both output directories.

        Args:
            sr: Sample rate (Hz) audio is loaded at and full-rate clips are
                written at.
            exp_dir: Experiment directory that receives the output folders.
            per: Length of each clip in seconds.

        Raises:
            ValueError: If ``per`` is not positive or not greater than
                ``OVERLAP``; the window start would then never advance and
                the chunking loop could not terminate.
        """
        # Validate before touching the filesystem or building helpers.
        if per <= 0:
            raise ValueError(f"per must be a positive number of seconds, got {per}")
        if per <= OVERLAP:
            raise ValueError(
                f"per must be greater than the {OVERLAP}s overlap, got {per}"
            )

        self.slicer = Slicer(
            sr=sr,
            threshold=-42,
            min_length=1500,
            min_interval=400,
            hop_size=15,
            max_sil_kept=500,
        )
        self.sr = sr
        # 5th-order Butterworth high-pass coefficients, applied in process_audio.
        self.b_high, self.a_high = signal.butter(
            N=5, Wn=HIGH_PASS_CUTOFF, btype="high", fs=self.sr
        )
        self.per = per
        self.exp_dir = exp_dir
        self.device = "cpu"
        self.gt_wavs_dir = os.path.join(exp_dir, "sliced_audios")
        self.wavs16k_dir = os.path.join(exp_dir, "sliced_audios_16k")
        os.makedirs(self.gt_wavs_dir, exist_ok=True)
        os.makedirs(self.wavs16k_dir, exist_ok=True)

    def _normalize_audio(self, audio: torch.Tensor):
        """Blend the peak-normalised clip with the original clip.

        Returns:
            The normalised tensor, or ``None`` when the clip must be skipped:
            it is empty, its peak is zero (dividing would produce NaN), its
            peak is not finite, or its peak exceeds 2.5.
        """
        if audio.numel() == 0:
            return None
        peak = torch.abs(audio).max()
        if not torch.isfinite(peak) or peak == 0 or peak > 2.5:
            return None
        return (audio / peak * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio

    def _write_audio(self, audio: torch.Tensor, filename: str, sr: int):
        """Write ``audio`` to ``filename`` as a float32 WAV at ``sr`` Hz."""
        samples = audio.cpu().numpy()
        wavfile.write(filename, sr, samples.astype(np.float32))

    def process_audio_segment(self, audio_segment: torch.Tensor, idx0: int, idx1: int):
        """Normalise one clip and save it at the native rate and at 16 kHz.

        Args:
            audio_segment: Samples of the clip.
            idx0: Index of the source file, first part of the output name.
            idx1: Index of the clip within the file, second part of the name.
        """
        normalized_audio = self._normalize_audio(audio_segment)
        if normalized_audio is None:
            print(f"{idx0}-{idx1}-filtered")
            return

        clip_name = f"{idx0}_{idx1}.wav"
        self._write_audio(
            normalized_audio, os.path.join(self.gt_wavs_dir, clip_name), self.sr
        )

        resampler = torchaudio.transforms.Resample(
            orig_freq=self.sr, new_freq=SAMPLE_RATE_16K
        ).to(self.device)
        audio_16k = resampler(normalized_audio.float())
        self._write_audio(
            audio_16k, os.path.join(self.wavs16k_dir, clip_name), SAMPLE_RATE_16K
        )

    def _iter_chunks(self, audio_segment: torch.Tensor):
        """Yield overlapping windows of ``per`` seconds covering ``audio_segment``.

        The final window is whatever remains once the remainder is no longer
        than ``per + OVERLAP`` seconds, so it can be shorter or longer than
        ``per``.
        """
        step = 0
        while True:
            start = int(self.sr * (self.per - OVERLAP) * step)
            step += 1
            remainder = audio_segment[start:]
            if len(remainder) > (self.per + OVERLAP) * self.sr:
                yield remainder[: int(self.per * self.sr)]
            else:
                yield remainder
                return

    def process_audio(self, path: str, idx0: int):
        """Load, filter, slice and save every clip of the file at ``path``.

        Failures are reported on stdout and swallowed so that one bad file
        does not abort the rest of the batch.

        Args:
            path: Audio file to load.
            idx0: Index of the file, used as the prefix of its clip names.
        """
        try:
            audio = load_audio(path, self.sr)
            # Cast to float32 so the slicer sees the same precision as the
            # float tensors the clips are converted to afterwards.
            filtered = signal.lfilter(self.b_high, self.a_high, audio).astype(
                np.float32
            )

            idx1 = 0
            for sliced in self.slicer.slice(filtered):
                segment = torch.tensor(sliced, device=self.device).float()
                for chunk in self._iter_chunks(segment):
                    self.process_audio_segment(chunk, idx0, idx1)
                    idx1 += 1
        except Exception as error:
            print(f"An error occurred on {path} path: {error}")

    def process_audio_file(self, file_path_idx):
        """Process one ``(file_path, idx0)`` work item.

        Files that are not ``.wav`` are first converted to WAV inside a
        private temporary directory, which is removed once processing ends.
        """
        import tempfile

        file_path, idx0 = file_path_idx
        if os.path.splitext(file_path)[1].lower() == ".wav":
            self.process_audio(file_path, idx0)
            return

        with tempfile.TemporaryDirectory() as tmp_dir:
            wav_path = os.path.join(tmp_dir, f"{idx0}.wav")
            # export() hands back the open output file; close it so the
            # temporary directory can be deleted on every platform.
            AudioSegment.from_file(file_path).export(wav_path, format="wav").close()
            self.process_audio(wav_path, idx0)
|
|
|
|
|
def preprocess_training_set(
    input_root: str,
    sr: int,
    num_processes: int,
    exp_dir: str,
    per: float,
):
    """Preprocess every supported audio file in ``input_root`` using a process pool.

    Args:
        input_root: Directory scanned (non-recursively) for ``.wav``, ``.mp3``,
            ``.flac`` and ``.ogg`` files.
        sr: Sample rate (Hz) passed to ``PreProcess``.
        num_processes: Number of worker processes in the pool.
        exp_dir: Experiment directory passed to ``PreProcess``.
        per: Clip length in seconds passed to ``PreProcess``.
    """
    start_time = time.perf_counter()

    pp = PreProcess(sr, exp_dir, per)
    print(f"Starting preprocess with {num_processes} processes...")

    # os.listdir returns entries in arbitrary order; sort them so that each
    # file gets the same index (and so the same output names) on every run.
    files = [
        (os.path.join(input_root, name), idx)
        for idx, name in enumerate(sorted(os.listdir(input_root)))
        if name.lower().endswith((".wav", ".mp3", ".flac", ".ogg"))
    ]

    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(processes=num_processes) as pool:
        pool.map(pp.process_audio_file, files)

    elapsed_time = time.perf_counter() - start_time
    print(f"Preprocess completed in {elapsed_time:.2f} seconds.")
|
|
|
|
|
if __name__ == "__main__":
    # Positional CLI arguments:
    #   experiment_directory input_root sample_rate segment_seconds [num_processes]
    if len(sys.argv) < 5:
        # Exit with a usage message instead of an IndexError traceback.
        sys.exit(
            f"Usage: {sys.argv[0]} <experiment_directory> <input_root> "
            "<sample_rate> <segment_seconds> [num_processes]"
        )

    experiment_directory = sys.argv[1]
    input_root = sys.argv[2]
    sample_rate = int(sys.argv[3])
    # Despite its name, this value is forwarded as `per`, the clip length in seconds.
    percentage = float(sys.argv[4])
    # Default to one worker per CPU when no count is given.
    num_processes = (
        int(sys.argv[5]) if len(sys.argv) > 5 else multiprocessing.cpu_count()
    )

    preprocess_training_set(
        input_root,
        sample_rate,
        num_processes,
        experiment_directory,
        percentage,
    )
|
|