# VOICEVN/main/library/architectures/mdx_separator.py
# AnhP's picture
# Upload 65 files
# 98bb602 verified
# raw
# history blame
# 16.7 kB
import os
import sys
import onnx
import torch
import platform
import onnx2torch
import numpy as np
import onnxruntime as ort
from tqdm import tqdm
# Append the current working directory to sys.path before the project-local
# imports below (presumably so the `main.*` packages resolve when the process
# is launched from the project root -- confirm against the launcher).
now_dir = os.getcwd()
sys.path.append(now_dir)
from main.configs.config import Config
from main.library.uvr5_separator import spec_utils
from main.library.uvr5_separator.common_separator import CommonSeparator
# Table of log/error message templates, looked up by key throughout this module.
translations = Config().translations
class MDXSeparator(CommonSeparator):
    """Two-stem separator driven by an MDX model stored as ONNX.

    The model is run either through an ONNX Runtime session or, when the
    requested segment size differs from the model's ``dim_t``, through a
    PyTorch module produced by ``onnx2torch``.

    Attributes such as ``logger``, ``model_data``, ``model_path``,
    ``torch_device``, ``onnx_execution_provider``, ``log_level``,
    ``normalization_threshold``, ``invert_using_spec``, the stem names and
    ``final_process`` are not defined here; they are expected to come from
    ``CommonSeparator``.
    """

    def __init__(self, common_config, arch_config):
        """Read the configuration, load the model and reset per-file state.

        Args:
            common_config: Passed unchanged to ``CommonSeparator.__init__``
                as ``config``.
            arch_config: Mapping read with ``.get`` for ``segment_size``,
                ``overlap``, ``batch_size`` (default 1), ``hop_length`` and
                ``enable_denoise``.
        """
        super().__init__(config=common_config)

        self.segment_size = arch_config.get("segment_size")
        self.overlap = arch_config.get("overlap")
        self.batch_size = arch_config.get("batch_size", 1)
        self.hop_length = arch_config.get("hop_length")
        self.enable_denoise = arch_config.get("enable_denoise")

        self.logger.debug(translations["mdx_info"].format(batch_size=self.batch_size, segment_size=self.segment_size))
        self.logger.debug(translations["mdx_info_2"].format(overlap=self.overlap, hop_length=self.hop_length, enable_denoise=self.enable_denoise))

        self.compensate = self.model_data["compensate"]
        self.dim_f = self.model_data["mdx_dim_f_set"]
        # The model data stores dim_t as an exponent of two.
        self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
        self.n_fft = self.model_data["mdx_n_fft_scale_set"]
        self.config_yaml = self.model_data.get("config_yaml", None)

        self.logger.debug(f"{translations['mdx_info_3']}: compensate = {self.compensate}, dim_f = {self.dim_f}, dim_t = {self.dim_t}, n_fft = {self.n_fft}")
        self.logger.debug(f"{translations['mdx_info_3']}: config_yaml = {self.config_yaml}")

        self.load_model()

        # Derived STFT/chunking settings; filled in by initialize_model_settings().
        self.n_bins = 0
        self.trim = 0
        self.chunk_size = 0
        self.gen_size = 0
        self.stft = None

        # Per-file state, populated by separate().
        self.primary_source = None
        self.secondary_source = None
        self.audio_file_path = None
        self.audio_file_base = None

    def load_model(self):
        """Create ``self.model_run``, the callable that maps a spectrogram to a prediction.

        When ``segment_size`` equals ``dim_t`` an ONNX Runtime session is used
        and ``model_run`` returns a NumPy array. Otherwise the ONNX model is
        converted with ``onnx2torch`` and moved to ``torch_device`` in eval
        mode, and a warning is logged.
        """
        self.logger.debug(translations["load_model_onnx"])

        if self.segment_size == self.dim_t:
            ort_session_options = ort.SessionOptions()
            # Verbose ONNX Runtime logging only when our own log level is 10 or lower.
            ort_session_options.log_severity_level = 3 if self.log_level > 10 else 0

            ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
            self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
            self.logger.debug(translations["load_model_onnx_success"])
        else:
            if platform.system() == 'Windows':
                # On Windows the model is loaded with onnx first and the
                # in-memory model is converted, rather than passing the path.
                onnx_model = onnx.load(self.model_path)
                self.model_run = onnx2torch.convert(onnx_model)
            else:
                self.model_run = onnx2torch.convert(self.model_path)

            self.model_run.to(self.torch_device).eval()
            self.logger.warning(translations["onnx_to_pytorch"])

    def separate(self, audio_file_path):
        """Separate one audio file and write the requested stems.

        Args:
            audio_file_path: Path of the input file; its base name (without
                extension) is used to build the output file names.

        Returns:
            list: Output paths of the stems that were written, secondary
            stem first, then primary.
        """
        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

        self.logger.debug(translations["mix"].format(audio_file_path=self.audio_file_path))
        mix = self.prepare_mix(self.audio_file_path)

        self.logger.debug(translations["normalization_demix"])
        mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold)

        source = self.demix(mix)
        self.logger.debug(translations["mix_success"])

        output_files = []
        self.logger.debug(translations["process_output_file"])

        if not isinstance(self.primary_source, np.ndarray):
            self.logger.debug(translations["primary_source"])
            self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold).T

        if not isinstance(self.secondary_source, np.ndarray):
            self.logger.debug(translations["secondary_source"])

            if self.invert_using_spec:
                # The match-mix pass is a full STFT/ISTFT run over the file and
                # is only consumed by the spectral inversion, so it is run
                # only in this branch.
                raw_mix = self.demix(mix, is_match_mix=True)
                self.logger.debug(translations["invert_using_spec"])
                self.secondary_source = spec_utils.invert_stem(raw_mix, source)
            else:
                self.logger.debug(translations["invert_using_spec_2"])
                self.secondary_source = mix.T - source.T

        if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
            self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")

            self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.secondary_stem_name, stem_output_path=self.secondary_stem_output_path))
            self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
            output_files.append(self.secondary_stem_output_path)

        if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
            self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")

            # Fallback kept from the original code in case the normalised
            # primary source above is not an ndarray.
            if not isinstance(self.primary_source, np.ndarray):
                self.primary_source = source.T

            # The same message template is used for both stems; the stem name is a format field.
            self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.primary_stem_name, stem_output_path=self.primary_stem_output_path))
            self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
            output_files.append(self.primary_stem_output_path)

        return output_files

    def initialize_model_settings(self):
        """Derive bin count, trim, chunk sizes and the STFT helper from the configuration."""
        self.logger.debug(translations["starting_model"])

        self.n_bins = self.n_fft // 2 + 1
        self.trim = self.n_fft // 2
        self.chunk_size = self.hop_length * (self.segment_size - 1)
        # Samples kept from each chunk once `trim` samples are dropped on both sides.
        self.gen_size = self.chunk_size - 2 * self.trim
        self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)

        self.logger.debug(f"{translations['input_info']}: n_fft = {self.n_fft} hop_length = {self.hop_length} dim_f = {self.dim_f}")
        self.logger.debug(f"{translations['model_settings']}: n_bins = {self.n_bins}, Trim = {self.trim}, chunk_size = {self.chunk_size}, gen_size = {self.gen_size}")

    def initialize_mix(self, mix, is_ckpt=False):
        """Pad a two-channel mix and cut it into ``chunk_size`` windows.

        Args:
            mix: Array whose first dimension must be 2; indexed as
                (channel, sample).
            is_ckpt: Selects one of the two padding/chunking schemes.

        Returns:
            tuple: ``(mix_waves_tensor, pad)`` where the tensor is float32 on
            ``torch_device`` and ``pad`` is the number of padding samples
            computed for the tail.

        Raises:
            ValueError: If ``mix.shape[0]`` is not 2.
        """
        self.logger.debug(translations["initialize_mix"].format(is_ckpt=is_ckpt, shape=mix.shape))

        if mix.shape[0] != 2:
            error_message = translations["!=2"].format(shape=mix.shape[0])
            self.logger.error(error_message)
            raise ValueError(error_message)

        if is_ckpt:
            self.logger.debug(translations["process_check"])
            pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
            self.logger.debug(f"{translations['cache']}: {pad}")

            mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
            num_chunks = mixture.shape[-1] // self.gen_size
            self.logger.debug(translations["shape"].format(shape=mixture.shape, num_chunks=num_chunks))

            mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
        else:
            self.logger.debug(translations["process_no_check"])
            mix_waves = []
            n_sample = mix.shape[1]
            pad = self.gen_size - n_sample % self.gen_size
            self.logger.debug(translations["n_sample_or_pad"].format(n_sample=n_sample, pad=pad))

            mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
            self.logger.debug(f"{translations['shape_2']}: {mix_p.shape}")

            i = 0
            while i < n_sample + pad:
                waves = np.array(mix_p[:, i : i + self.chunk_size])
                mix_waves.append(waves)
                self.logger.debug(translations["process_part"].format(mix_waves=len(mix_waves), i=i, ii=i + self.chunk_size))
                i += self.gen_size

        # Build one ndarray first: torch.tensor() on a list of ndarrays takes
        # PyTorch's slow element-wise path and emits a warning.
        mix_waves_tensor = torch.tensor(np.array(mix_waves), dtype=torch.float32).to(self.torch_device)
        self.logger.debug(translations["mix_waves_to_tensor"].format(shape=mix_waves_tensor.shape))

        return mix_waves_tensor, pad

    def demix(self, mix, is_match_mix=False):
        """Run the model over the mix chunk by chunk and overlap-add the results.

        Args:
            mix: Two-channel array indexed as (channel, sample).
            is_match_mix: When true, ``run_model`` skips the network (the
                audio only goes through STFT and inverse STFT), a fixed
                overlap of 0.02 is used and no compensation gain is applied.

        Returns:
            numpy.ndarray: Array of shape ``(2, mix.shape[-1])``.
        """
        self.logger.debug(f"{translations['demix_is_match_mix']}: {is_match_mix}...")
        self.initialize_model_settings()

        self.logger.debug(f"{translations['mix_shape']}: {mix.shape}")

        if is_match_mix:
            chunk_size = self.hop_length * (self.segment_size - 1)
            overlap = 0.02
            self.logger.debug(translations["chunk_size_or_overlap"].format(chunk_size=chunk_size, overlap=overlap))
        else:
            chunk_size = self.chunk_size
            overlap = self.overlap
            self.logger.debug(translations["chunk_size_or_overlap_standard"].format(chunk_size=chunk_size, overlap=overlap))

        gen_size = chunk_size - 2 * self.trim
        self.logger.debug(f"{translations['calc_size']}: {gen_size}")

        # `trim` zeros in front, and enough zeros behind to reach a multiple
        # of gen_size plus another `trim`.
        pad = gen_size + self.trim - ((mix.shape[-1]) % gen_size)
        mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
        self.logger.debug(f"{translations['mix_cache']}: {mixture.shape}")

        step = int((1 - overlap) * chunk_size)
        self.logger.debug(translations["step_or_overlap"].format(step=step, overlap=overlap))

        # Overlap-add accumulators: weighted sum of outputs and sum of weights.
        result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
        divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)

        total = 0
        total_chunks = (mixture.shape[-1] + step - 1) // step
        self.logger.debug(f"{translations['all_process_part']}: {total_chunks}")

        for i in tqdm(range(0, mixture.shape[-1], step)):
            total += 1
            start = i
            end = min(i + chunk_size, mixture.shape[-1])
            self.logger.debug(translations["process_part_2"].format(total=total, total_chunks=total_chunks, start=start, end=end))

            chunk_size_actual = end - start

            window = None
            if overlap != 0:
                window = np.hanning(chunk_size_actual)
                window = np.tile(window[None, None, :], (1, 2, 1))
                self.logger.debug(translations["window"])

            mix_part_ = mixture[:, start:end]
            if end != i + chunk_size:
                # The last chunk is short: zero-pad it to the full chunk size for the model.
                pad_size = (i + chunk_size) - end
                mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)

            # Add the batch axis in NumPy so torch.tensor() receives a single
            # ndarray rather than a list of ndarrays.
            mix_part = torch.tensor(np.expand_dims(mix_part_, 0), dtype=torch.float32).to(self.torch_device)
            mix_waves = mix_part.split(self.batch_size)
            total_batches = len(mix_waves)
            self.logger.debug(f"{translations['mix_or_batch']}: {total_batches}")

            with torch.no_grad():
                batches_processed = 0
                for mix_wave in mix_waves:
                    batches_processed += 1
                    self.logger.debug(f"{translations['mix_wave']} {batches_processed}/{total_batches}")

                    tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)

                    if window is not None:
                        tar_waves[..., :chunk_size_actual] *= window
                        divider[..., start:end] += window
                    else:
                        divider[..., start:end] += 1

                    result[..., start:end] += tar_waves[..., : end - start]

        self.logger.debug(translations["normalization_2"])
        # The Hann window is zero at its end points, so samples covered only
        # by a window edge have divider == 0. Guard the division so those
        # samples become 0 instead of NaN (and no RuntimeWarning is raised).
        normalized = np.divide(result, divider, out=np.zeros_like(result), where=divider != 0)

        # Drop the batch axis and the `trim` padding on both sides, then cut
        # back to the input length.
        trimmed = normalized[0][:, self.trim : -self.trim]
        tar_waves = np.ascontiguousarray(trimmed[:, : mix.shape[-1]])

        source = tar_waves[:, 0:None]
        self.logger.debug(f"{translations['tar_waves']}: {tar_waves.shape}")

        if not is_match_mix:
            source *= self.compensate
            self.logger.debug(translations["mix_match"])

        self.logger.debug(translations["mix_success"])
        return source

    def run_model(self, mix, is_match_mix=False):
        """Transform one batch to the spectral domain, apply the model, transform back.

        Args:
            mix: Tensor of waveform chunks; moved to ``torch_device`` before
                the STFT.
            is_match_mix: When true the network is skipped and the (partially
                zeroed) spectrogram is inverted directly.

        Returns:
            numpy.ndarray: Output of ``self.stft.inverse`` as a NumPy array.
        """
        spek = self.stft(mix.to(self.torch_device))
        self.logger.debug(translations["stft_2"].format(shape=spek.shape))

        # Zero the first three frequency bins before prediction.
        spek[:, :, :3, :] *= 0

        if is_match_mix:
            spec_pred = spek.cpu().numpy()
            self.logger.debug(translations["is_match_mix"])
        else:
            if self.enable_denoise:
                # Average the prediction for the input and the negated
                # prediction for the negated input.
                spec_pred_neg = self.model_run(-spek)
                spec_pred_pos = self.model_run(spek)
                spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)
                self.logger.debug(translations["enable_denoise"])
            else:
                spec_pred = self.model_run(spek)
                self.logger.debug(translations["no_denoise"])

        # model_run returns an ndarray (ONNX Runtime) or a tensor (onnx2torch);
        # as_tensor accepts both without the copy-construct warning that
        # torch.tensor() raises for tensor input.
        result = self.stft.inverse(torch.as_tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
        self.logger.debug(f"{translations['stft']}: {result.shape}")

        return result
class STFT:
    """Forward and inverse short-time Fourier transform with a periodic Hann window.

    The forward transform stores real and imaginary parts as separate
    channels: an input of shape ``(*batch, channels, time)`` becomes
    ``(*batch, channels * 2, dim_f, frames)``, ordered channel by channel as
    (real, imaginary). ``inverse`` undoes this layout, zero-filling the
    frequency bins above ``dim_f`` up to ``n_fft // 2 + 1``.
    """

    def __init__(self, logger, n_fft, hop_length, dim_f, device):
        """Store the transform parameters and build the analysis window.

        Args:
            logger: Stored on the instance; not used by the methods here.
            n_fft: FFT size, also the window length.
            hop_length: Hop between successive frames.
            dim_f: Number of frequency bins kept by the forward transform.
            device: Device that results are moved back to when the input
                lives on a device other than "cuda" or "cpu".
        """
        self.logger = logger
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.dim_f = dim_f
        self.device = device
        self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)

    def __call__(self, input_tensor):
        """Compute the STFT of ``input_tensor``.

        Args:
            input_tensor: Tensor of shape ``(*batch, channels, time)``.

        Returns:
            torch.Tensor: Shape ``(*batch, channels * 2, dim_f, frames)``.
        """
        # Devices other than CUDA/CPU are processed on the CPU and the result
        # is moved back to self.device afterwards.
        is_non_standard_device = input_tensor.device.type not in ["cuda", "cpu"]
        if is_non_standard_device:
            input_tensor = input_tensor.cpu()

        stft_window = self.hann_window.to(input_tensor.device)

        batch_dimensions = input_tensor.shape[:-2]
        channel_dim, time_dim = input_tensor.shape[-2:]
        # Fold every leading dimension and the channels into one batch axis.
        reshaped_tensor = input_tensor.reshape([-1, time_dim])

        # return_complex=False is deprecated; request the complex result and
        # view it as a trailing (real, imag) pair, which is the same layout.
        complex_stft = torch.stft(reshaped_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True, return_complex=True)
        stft_output = torch.view_as_real(complex_stft)

        # (N, freq, frames, 2) -> (N, 2, freq, frames)
        permuted_stft_output = stft_output.permute([0, 3, 1, 2])
        frame_count = permuted_stft_output.shape[-1]

        final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, frame_count])
        final_output = final_output.reshape([*batch_dimensions, channel_dim * 2, -1, frame_count])

        if is_non_standard_device:
            final_output = final_output.to(self.device)

        return final_output[..., : self.dim_f, :]

    def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
        """Append zero bins so the frequency axis has ``num_freq_bins`` entries.

        Returns:
            torch.Tensor: ``input_tensor`` concatenated along axis -2 with
            ``num_freq_bins - freq_dim`` rows of zeros.
        """
        # Allocate the zeros directly on the input's device.
        freq_padding = torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim], device=input_tensor.device)
        return torch.cat([input_tensor, freq_padding], -2)

    def calculate_inverse_dimensions(self, input_tensor):
        """Split the shape of a spectrogram tensor into its named parts.

        Returns:
            tuple: ``(batch_dimensions, channel_dim, freq_dim, time_dim,
            num_freq_bins)`` with ``num_freq_bins = n_fft // 2 + 1``.
        """
        batch_dimensions = input_tensor.shape[:-3]
        channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]
        num_freq_bins = self.n_fft // 2 + 1

        return batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins

    def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
        """Rebuild a complex spectrogram from the (real, imag) channel layout.

        Returns:
            torch.Tensor: Complex tensor of shape ``(N, num_freq_bins,
            time_dim)`` where ``N`` folds the batch dimensions and audio
            channels together.
        """
        reshaped_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim])
        flattened_tensor = reshaped_tensor.reshape([-1, 2, num_freq_bins, time_dim])
        # (N, 2, freq, frames) -> (N, freq, frames, 2)
        permuted_tensor = flattened_tensor.permute([0, 2, 3, 1])

        return permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j

    def inverse(self, input_tensor):
        """Invert a spectrogram produced by ``__call__`` back to waveforms.

        Args:
            input_tensor: Tensor of shape ``(*batch, channels * 2, freq,
                frames)``.

        Returns:
            torch.Tensor: Shape ``(*batch, channels, samples)``.
        """
        is_non_standard_device = input_tensor.device.type not in ["cuda", "cpu"]
        if is_non_standard_device:
            input_tensor = input_tensor.cpu()

        stft_window = self.hann_window.to(input_tensor.device)

        batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)
        padded_tensor = self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins)
        complex_tensor = self.prepare_for_istft(padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim)

        istft_result = torch.istft(complex_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True)

        # Each audio channel occupies two spectrogram channels (real, imag),
        # so the waveform has channel_dim // 2 channels. For stereo this is
        # the value 2 that used to be hard-coded.
        final_output = istft_result.reshape([*batch_dimensions, channel_dim // 2, -1])

        if is_non_standard_device:
            final_output = final_output.to(self.device)

        return final_output