|
import os
|
|
import sys
|
|
import onnx
|
|
import torch
|
|
import platform
|
|
import onnx2torch
|
|
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
|
|
from tqdm import tqdm
|
|
|
|
sys.path.append(os.getcwd())
|
|
|
|
from main.configs.config import Config
|
|
from main.library.uvr5_separator import spec_utils
|
|
from main.library.uvr5_separator.common_separator import CommonSeparator
|
|
|
|
translations = Config().translations
|
|
|
|
class MDXSeparator(CommonSeparator):
|
|
def __init__(self, common_config, arch_config):
|
|
super().__init__(config=common_config)
|
|
self.segment_size = arch_config.get("segment_size")
|
|
self.overlap = arch_config.get("overlap")
|
|
self.batch_size = arch_config.get("batch_size", 1)
|
|
self.hop_length = arch_config.get("hop_length")
|
|
self.enable_denoise = arch_config.get("enable_denoise")
|
|
self.logger.debug(translations["mdx_info"].format(batch_size=self.batch_size, segment_size=self.segment_size))
|
|
self.logger.debug(translations["mdx_info_2"].format(overlap=self.overlap, hop_length=self.hop_length, enable_denoise=self.enable_denoise))
|
|
self.compensate = self.model_data["compensate"]
|
|
self.dim_f = self.model_data["mdx_dim_f_set"]
|
|
self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
|
|
self.n_fft = self.model_data["mdx_n_fft_scale_set"]
|
|
self.config_yaml = self.model_data.get("config_yaml", None)
|
|
self.logger.debug(f"{translations['mdx_info_3']}: compensate = {self.compensate}, dim_f = {self.dim_f}, dim_t = {self.dim_t}, n_fft = {self.n_fft}")
|
|
self.logger.debug(f"{translations['mdx_info_3']}: config_yaml = {self.config_yaml}")
|
|
self.load_model()
|
|
self.n_bins = 0
|
|
self.trim = 0
|
|
self.chunk_size = 0
|
|
self.gen_size = 0
|
|
self.stft = None
|
|
self.primary_source = None
|
|
self.secondary_source = None
|
|
self.audio_file_path = None
|
|
self.audio_file_base = None
|
|
|
|
def load_model(self):
|
|
self.logger.debug(translations["load_model_onnx"])
|
|
|
|
if self.segment_size == self.dim_t:
|
|
ort_session_options = ort.SessionOptions()
|
|
ort_session_options.log_severity_level = 3 if self.log_level > 10 else 0
|
|
ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
|
|
self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
|
|
self.logger.debug(translations["load_model_onnx_success"])
|
|
else:
|
|
self.model_run = onnx2torch.convert(onnx.load(self.model_path)) if platform.system() == 'Windows' else onnx2torch.convert(self.model_path)
|
|
self.model_run.to(self.torch_device).eval()
|
|
self.logger.debug(translations["onnx_to_pytorch"])
|
|
|
|
def separate(self, audio_file_path):
|
|
self.audio_file_path = audio_file_path
|
|
self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
|
|
self.logger.debug(translations["mix"].format(audio_file_path=self.audio_file_path))
|
|
mix = self.prepare_mix(self.audio_file_path)
|
|
self.logger.debug(translations["normalization_demix"])
|
|
mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold)
|
|
source = self.demix(mix)
|
|
self.logger.debug(translations["mix_success"])
|
|
output_files = []
|
|
self.logger.debug(translations["process_output_file"])
|
|
|
|
if not isinstance(self.primary_source, np.ndarray):
|
|
self.logger.debug(translations["primary_source"])
|
|
self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold).T
|
|
|
|
if not isinstance(self.secondary_source, np.ndarray):
|
|
self.logger.debug(translations["secondary_source"])
|
|
raw_mix = self.demix(mix, is_match_mix=True)
|
|
|
|
if self.invert_using_spec:
|
|
self.logger.debug(translations["invert_using_spec"])
|
|
self.secondary_source = spec_utils.invert_stem(raw_mix, source)
|
|
else:
|
|
self.logger.debug(translations["invert_using_spec_2"])
|
|
self.secondary_source = mix.T - source.T
|
|
|
|
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
|
|
self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
|
|
self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.secondary_stem_name, stem_output_path=self.secondary_stem_output_path))
|
|
self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
|
|
output_files.append(self.secondary_stem_output_path)
|
|
|
|
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
|
|
self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
|
|
if not isinstance(self.primary_source, np.ndarray): self.primary_source = source.T
|
|
|
|
self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.primary_stem_name, stem_output_path=self.primary_stem_output_path))
|
|
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
|
|
output_files.append(self.primary_stem_output_path)
|
|
|
|
return output_files
|
|
|
|
def initialize_model_settings(self):
|
|
self.logger.debug(translations["starting_model"])
|
|
|
|
self.n_bins = self.n_fft // 2 + 1
|
|
self.trim = self.n_fft // 2
|
|
|
|
self.chunk_size = self.hop_length * (self.segment_size - 1)
|
|
self.gen_size = self.chunk_size - 2 * self.trim
|
|
|
|
self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)
|
|
|
|
self.logger.debug(f"{translations['input_info']}: n_fft = {self.n_fft} hop_length = {self.hop_length} dim_f = {self.dim_f}")
|
|
self.logger.debug(f"{translations['model_settings']}: n_bins = {self.n_bins}, Trim = {self.trim}, chunk_size = {self.chunk_size}, gen_size = {self.gen_size}")
|
|
|
|
def initialize_mix(self, mix, is_ckpt=False):
|
|
self.logger.debug(translations["initialize_mix"].format(is_ckpt=is_ckpt, shape=mix.shape))
|
|
|
|
if mix.shape[0] != 2:
|
|
error_message = translations["!=2"].format(shape=mix.shape[0])
|
|
self.logger.error(error_message)
|
|
raise ValueError(error_message)
|
|
|
|
if is_ckpt:
|
|
self.logger.debug(translations["process_check"])
|
|
pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
|
|
self.logger.debug(f"{translations['cache']}: {pad}")
|
|
|
|
mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
|
|
|
|
num_chunks = mixture.shape[-1] // self.gen_size
|
|
self.logger.debug(translations["shape"].format(shape=mixture.shape, num_chunks=num_chunks))
|
|
|
|
mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
|
|
else:
|
|
self.logger.debug(translations["process_no_check"])
|
|
mix_waves = []
|
|
n_sample = mix.shape[1]
|
|
|
|
pad = self.gen_size - n_sample % self.gen_size
|
|
self.logger.debug(translations["n_sample_or_pad"].format(n_sample=n_sample, pad=pad))
|
|
|
|
mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
|
|
self.logger.debug(f"{translations['shape_2']}: {mix_p.shape}")
|
|
|
|
i = 0
|
|
while i < n_sample + pad:
|
|
mix_waves.append(np.array(mix_p[:, i : i + self.chunk_size]))
|
|
|
|
self.logger.debug(translations["process_part"].format(mix_waves=len(mix_waves), i=i, ii=i + self.chunk_size))
|
|
i += self.gen_size
|
|
|
|
mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
|
|
self.logger.debug(translations["mix_waves_to_tensor"].format(shape=mix_waves_tensor.shape))
|
|
|
|
return mix_waves_tensor, pad
|
|
|
|
def demix(self, mix, is_match_mix=False):
|
|
self.logger.debug(f"{translations['demix_is_match_mix']}: {is_match_mix}...")
|
|
self.initialize_model_settings()
|
|
self.logger.debug(f"{translations['mix_shape']}: {mix.shape}")
|
|
tar_waves_ = []
|
|
|
|
if is_match_mix:
|
|
chunk_size = self.hop_length * (self.segment_size - 1)
|
|
overlap = 0.02
|
|
self.logger.debug(translations["chunk_size_or_overlap"].format(chunk_size=chunk_size, overlap=overlap))
|
|
else:
|
|
chunk_size = self.chunk_size
|
|
overlap = self.overlap
|
|
self.logger.debug(translations["chunk_size_or_overlap_standard"].format(chunk_size=chunk_size, overlap=overlap))
|
|
|
|
gen_size = chunk_size - 2 * self.trim
|
|
self.logger.debug(f"{translations['calc_size']}: {gen_size}")
|
|
|
|
mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, gen_size + self.trim - ((mix.shape[-1]) % gen_size)), dtype="float32")), 1)
|
|
self.logger.debug(f"{translations['mix_cache']}: {mixture.shape}")
|
|
|
|
step = int((1 - overlap) * chunk_size)
|
|
self.logger.debug(translations["step_or_overlap"].format(step=step, overlap=overlap))
|
|
|
|
result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
|
|
divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
|
|
|
|
total = 0
|
|
total_chunks = (mixture.shape[-1] + step - 1) // step
|
|
self.logger.debug(f"{translations['all_process_part']}: {total_chunks}")
|
|
|
|
for i in tqdm(range(0, mixture.shape[-1], step), ncols=100, unit="f"):
|
|
total += 1
|
|
start = i
|
|
end = min(i + chunk_size, mixture.shape[-1])
|
|
self.logger.debug(translations["process_part_2"].format(total=total, total_chunks=total_chunks, start=start, end=end))
|
|
|
|
chunk_size_actual = end - start
|
|
window = None
|
|
|
|
if overlap != 0:
|
|
window = np.hanning(chunk_size_actual)
|
|
window = np.tile(window[None, None, :], (1, 2, 1))
|
|
self.logger.debug(translations["window"])
|
|
|
|
mix_part_ = mixture[:, start:end]
|
|
|
|
if end != i + chunk_size:
|
|
pad_size = (i + chunk_size) - end
|
|
mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)
|
|
|
|
mix_waves = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device).split(self.batch_size)
|
|
|
|
total_batches = len(mix_waves)
|
|
self.logger.debug(f"{translations['mix_or_batch']}: {total_batches}")
|
|
|
|
with torch.no_grad():
|
|
batches_processed = 0
|
|
|
|
for mix_wave in mix_waves:
|
|
batches_processed += 1
|
|
self.logger.debug(f"{translations['mix_wave']} {batches_processed}/{total_batches}")
|
|
|
|
tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
|
|
|
|
if window is not None:
|
|
tar_waves[..., :chunk_size_actual] *= window
|
|
divider[..., start:end] += window
|
|
else: divider[..., start:end] += 1
|
|
|
|
result[..., start:end] += tar_waves[..., : end - start]
|
|
|
|
|
|
self.logger.debug(translations["normalization_2"])
|
|
tar_waves = result / divider
|
|
tar_waves_.append(tar_waves)
|
|
|
|
tar_waves = np.concatenate(np.vstack(tar_waves_)[:, :, self.trim : -self.trim], axis=-1)[:, : mix.shape[-1]]
|
|
|
|
source = tar_waves[:, 0:None]
|
|
self.logger.debug(f"{translations['tar_waves']}: {tar_waves.shape}")
|
|
|
|
if not is_match_mix:
|
|
source *= self.compensate
|
|
self.logger.debug(translations["mix_match"])
|
|
|
|
self.logger.debug(translations["mix_success"])
|
|
return source
|
|
|
|
def run_model(self, mix, is_match_mix=False):
|
|
spek = self.stft(mix.to(self.torch_device))
|
|
self.logger.debug(translations["stft_2"].format(shape=spek.shape))
|
|
|
|
spek[:, :, :3, :] *= 0
|
|
|
|
if is_match_mix:
|
|
spec_pred = spek.cpu().numpy()
|
|
self.logger.debug(translations["is_match_mix"])
|
|
else:
|
|
if self.enable_denoise:
|
|
spec_pred_neg = self.model_run(-spek)
|
|
spec_pred_pos = self.model_run(spek)
|
|
spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)
|
|
self.logger.debug(translations["enable_denoise"])
|
|
else:
|
|
spec_pred = self.model_run(spek)
|
|
self.logger.debug(translations["no_denoise"])
|
|
|
|
result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
|
|
self.logger.debug(f"{translations['stft']}: {result.shape}")
|
|
|
|
return result
|
|
|
|
class STFT:
|
|
def __init__(self, logger, n_fft, hop_length, dim_f, device):
|
|
self.logger = logger
|
|
self.n_fft = n_fft
|
|
self.hop_length = hop_length
|
|
self.dim_f = dim_f
|
|
self.device = device
|
|
self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)
|
|
|
|
def __call__(self, input_tensor):
|
|
is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
|
|
|
|
if is_non_standard_device: input_tensor = input_tensor.cpu()
|
|
|
|
batch_dimensions = input_tensor.shape[:-2]
|
|
channel_dim, time_dim = input_tensor.shape[-2:]
|
|
|
|
permuted_stft_output = torch.stft(input_tensor.reshape([-1, time_dim]), n_fft=self.n_fft, hop_length=self.hop_length, window=self.hann_window.to(input_tensor.device), center=True, return_complex=False).permute([0, 3, 1, 2])
|
|
final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape([*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]])
|
|
|
|
if is_non_standard_device: final_output = final_output.to(self.device)
|
|
return final_output[..., : self.dim_f, :]
|
|
|
|
def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
|
|
return torch.cat([input_tensor, torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)], -2)
|
|
|
|
def calculate_inverse_dimensions(self, input_tensor):
|
|
channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]
|
|
|
|
return input_tensor.shape[:-3], channel_dim, freq_dim, time_dim, self.n_fft // 2 + 1
|
|
|
|
def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
|
|
permuted_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim]).reshape([-1, 2, num_freq_bins, time_dim]).permute([0, 2, 3, 1])
|
|
|
|
return permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j
|
|
|
|
def inverse(self, input_tensor):
|
|
is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
|
|
if is_non_standard_device: input_tensor = input_tensor.cpu()
|
|
|
|
batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)
|
|
final_output = torch.istft(self.prepare_for_istft(self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins), batch_dimensions, channel_dim, num_freq_bins, time_dim), n_fft=self.n_fft, hop_length=self.hop_length, window=self.hann_window.to(input_tensor.device), center=True).reshape([*batch_dimensions, 2, -1])
|
|
|
|
if is_non_standard_device: final_output = final_output.to(self.device)
|
|
|
|
return final_output |