# Spaces:
# Build error
# Build error
import os
import platform
import sys

import onnx
import torch
import onnx2torch
import numpy as np
import onnxruntime as ort
from tqdm import tqdm

# Put the current working directory (presumably the project root) on sys.path
# so that the `main.*` imports below can be resolved.
now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
from main.library.uvr5_separator import spec_utils
from main.library.uvr5_separator.common_separator import CommonSeparator

# Message table indexed by key for every log/error string in this module.
translations = Config().translations
class MDXSeparator(CommonSeparator):
    """Two-stem separator driven by an ONNX model described by ``mdx_*`` model data.

    The model runs through ONNX Runtime when the requested segment size equals
    the model's ``dim_t``; otherwise the ONNX graph is converted to a PyTorch
    module with ``onnx2torch``. Audio is cut into (optionally Hann-windowed)
    overlapping chunks. Each chunk goes STFT -> model -> inverse STFT. The
    outputs are overlap-added and normalised by the accumulated window weight.
    """

    def __init__(self, common_config, arch_config):
        """Read architecture and model settings, then load the model.

        Args:
            common_config: forwarded unchanged to ``CommonSeparator.__init__``.
            arch_config: mapping providing ``segment_size``, ``overlap``,
                ``batch_size`` (defaults to 1), ``hop_length`` and
                ``enable_denoise``.
        """
        super().__init__(config=common_config)

        self.segment_size = arch_config.get("segment_size")
        self.overlap = arch_config.get("overlap")
        self.batch_size = arch_config.get("batch_size", 1)
        self.hop_length = arch_config.get("hop_length")
        self.enable_denoise = arch_config.get("enable_denoise")
        self.logger.debug(translations["mdx_info"].format(batch_size=self.batch_size, segment_size=self.segment_size))
        self.logger.debug(translations["mdx_info_2"].format(overlap=self.overlap, hop_length=self.hop_length, enable_denoise=self.enable_denoise))

        # model_data is not assigned in this class; presumably CommonSeparator populates it.
        self.compensate = self.model_data["compensate"]
        self.dim_f = self.model_data["mdx_dim_f_set"]
        # mdx_dim_t_set is stored as a power-of-two exponent.
        self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
        self.n_fft = self.model_data["mdx_n_fft_scale_set"]
        self.config_yaml = self.model_data.get("config_yaml", None)
        self.logger.debug(f"{translations['mdx_info_3']}: compensate = {self.compensate}, dim_f = {self.dim_f}, dim_t = {self.dim_t}, n_fft = {self.n_fft}")
        self.logger.debug(f"{translations['mdx_info_3']}: config_yaml = {self.config_yaml}")

        self.load_model()

        # Derived by initialize_model_settings(), which demix() calls before every pass.
        self.n_bins = 0
        self.trim = 0
        self.chunk_size = 0
        self.gen_size = 0
        self.stft = None

        # Per-file state filled in by separate().
        self.primary_source = None
        self.secondary_source = None
        self.audio_file_path = None
        self.audio_file_base = None

    def load_model(self):
        """Set ``self.model_run`` to a callable mapping a spectrogram batch to model output.

        When ``segment_size == dim_t`` an ONNX Runtime session is used and the
        callable returns a NumPy array. Otherwise the ONNX graph is converted
        to a PyTorch module (returning tensors), moved to ``torch_device`` and
        switched to eval mode.
        """
        self.logger.debug(translations["load_model_onnx"])
        if self.segment_size == self.dim_t:
            ort_session_options = ort.SessionOptions()
            # ONNX Runtime severity 3 = errors only, 0 = verbose. log_level is
            # presumably a `logging` level (10 == DEBUG).
            ort_session_options.log_severity_level = 3 if self.log_level > 10 else 0
            ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
            self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
            self.logger.debug(translations["load_model_onnx_success"])
        else:
            if platform.system() == "Windows":
                # On Windows the graph is loaded with onnx first instead of
                # passing the file path straight to onnx2torch.
                onnx_model = onnx.load(self.model_path)
                self.model_run = onnx2torch.convert(onnx_model)
            else:
                self.model_run = onnx2torch.convert(self.model_path)
            self.model_run.to(self.torch_device).eval()
            self.logger.warning(translations["onnx_to_pytorch"])

    def separate(self, audio_file_path):
        """Separate one audio file into its two stems and hand them to ``final_process``.

        Args:
            audio_file_path: path of the input audio file.

        Returns:
            list of output paths written (secondary stem first), restricted to
            ``output_single_stem`` when that option is set.
        """
        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
        self.logger.debug(translations["mix"].format(audio_file_path=self.audio_file_path))
        mix = self.prepare_mix(self.audio_file_path)
        self.logger.debug(translations["normalization_demix"])
        mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold)

        source = self.demix(mix)
        self.logger.debug(translations["mix_success"])

        output_files = []
        self.logger.debug(translations["process_output_file"])

        # NOTE(review): the sources are only computed when they are not already
        # ndarrays, and nothing in this class resets them between calls; reusing
        # one instance for several files relies on the caller/base class clearing
        # them — confirm.
        if not isinstance(self.primary_source, np.ndarray):
            self.logger.debug(translations["primary_source"])
            self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold).T

        if not isinstance(self.secondary_source, np.ndarray):
            self.logger.debug(translations["secondary_source"])
            # Second pass that bypasses the model (see run_model): the mix itself
            # passed through the same chunking and STFT round-trip.
            raw_mix = self.demix(mix, is_match_mix=True)
            if self.invert_using_spec:
                self.logger.debug(translations["invert_using_spec"])
                self.secondary_source = spec_utils.invert_stem(raw_mix, source)
            else:
                self.logger.debug(translations["invert_using_spec_2"])
                # Waveform-domain residual: mix minus the demixed source.
                self.secondary_source = mix.T - source.T

        if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
            self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
            self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.secondary_stem_name, stem_output_path=self.secondary_stem_output_path))
            self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
            output_files.append(self.secondary_stem_output_path)

        if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
            self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
            if not isinstance(self.primary_source, np.ndarray):
                self.primary_source = source.T
            # NOTE(review): reuses the "save_secondary_stem_output_path" message
            # for the primary stem; the stem name is passed in, so it may be generic — confirm.
            self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.primary_stem_name, stem_output_path=self.primary_stem_output_path))
            self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
            output_files.append(self.primary_stem_output_path)

        return output_files

    def initialize_model_settings(self):
        """Derive chunk geometry from ``n_fft``/``hop_length``/``segment_size`` and build the STFT helper."""
        self.logger.debug(translations["starting_model"])
        self.n_bins = self.n_fft // 2 + 1
        # Zero padding (in samples) added in front of the audio and trimmed off afterwards.
        self.trim = self.n_fft // 2
        # With center=True, chunk_size samples yield segment_size STFT frames.
        self.chunk_size = self.hop_length * (self.segment_size - 1)
        # Samples per chunk left once both trims are removed.
        self.gen_size = self.chunk_size - 2 * self.trim
        self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)
        self.logger.debug(f"{translations['input_info']}: n_fft = {self.n_fft} hop_length = {self.hop_length} dim_f = {self.dim_f}")
        self.logger.debug(f"{translations['model_settings']}: n_bins = {self.n_bins}, Trim = {self.trim}, chunk_size = {self.chunk_size}, gen_size = {self.gen_size}")

    def initialize_mix(self, mix, is_ckpt=False):
        """Zero-pad ``mix`` and cut it into ``chunk_size``-long slices.

        Args:
            mix: waveform array with exactly two rows.
            is_ckpt: selects the padding scheme. When True: ``trim`` leading zeros
                and tail padding to a ``gen_size`` multiple plus ``trim``.
                Otherwise: ``trim`` zeros on both ends plus a ``gen_size``-aligned tail pad.

        Returns:
            ``(chunks, pad)``: a float32 tensor on ``torch_device`` stacking the
            slices, and the number of tail-padding samples used.

        Raises:
            ValueError: if ``mix`` does not have exactly two rows.
        """
        self.logger.debug(translations["initialize_mix"].format(is_ckpt=is_ckpt, shape=mix.shape))
        if mix.shape[0] != 2:
            error_message = translations["!=2"].format(shape=mix.shape[0])
            self.logger.error(error_message)
            raise ValueError(error_message)

        if is_ckpt:
            self.logger.debug(translations["process_check"])
            pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
            self.logger.debug(f"{translations['cache']}: {pad}")
            mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
            num_chunks = mixture.shape[-1] // self.gen_size
            self.logger.debug(translations["shape"].format(shape=mixture.shape, num_chunks=num_chunks))
            mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
        else:
            self.logger.debug(translations["process_no_check"])
            mix_waves = []
            n_sample = mix.shape[1]
            pad = self.gen_size - n_sample % self.gen_size
            self.logger.debug(translations["n_sample_or_pad"].format(n_sample=n_sample, pad=pad))
            mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
            self.logger.debug(f"{translations['shape_2']}: {mix_p.shape}")
            for i in range(0, n_sample + pad, self.gen_size):
                mix_waves.append(mix_p[:, i : i + self.chunk_size])
                self.logger.debug(translations["process_part"].format(mix_waves=len(mix_waves), i=i, ii=i + self.chunk_size))

        # Stack into one ndarray first: torch.tensor on a list of ndarrays is very slow.
        mix_waves_tensor = torch.tensor(np.array(mix_waves), dtype=torch.float32).to(self.torch_device)
        self.logger.debug(translations["mix_waves_to_tensor"].format(shape=mix_waves_tensor.shape))
        return mix_waves_tensor, pad

    def demix(self, mix, is_match_mix=False):
        """Run the model over ``mix`` chunk by chunk and overlap-add the outputs.

        Args:
            mix: waveform array with two rows.
            is_match_mix: when True the model is bypassed (see ``run_model``) and a
                fixed 0.02 overlap is used. The result is then ``mix`` passed
                through the same STFT / inverse-STFT path.

        Returns:
            float32 array of shape ``(2, mix.shape[-1])``, scaled by
            ``compensate`` unless ``is_match_mix``.

        Raises:
            ValueError: if the overlap leaves a non-positive chunk step.
        """
        self.logger.debug(f"{translations['demix_is_match_mix']}: {is_match_mix}...")
        self.initialize_model_settings()
        self.logger.debug(f"{translations['mix_shape']}: {mix.shape}")

        if is_match_mix:
            chunk_size = self.hop_length * (self.segment_size - 1)
            overlap = 0.02
            self.logger.debug(translations["chunk_size_or_overlap"].format(chunk_size=chunk_size, overlap=overlap))
        else:
            chunk_size = self.chunk_size
            overlap = self.overlap
            self.logger.debug(translations["chunk_size_or_overlap_standard"].format(chunk_size=chunk_size, overlap=overlap))

        gen_size = chunk_size - 2 * self.trim
        self.logger.debug(f"{translations['calc_size']}: {gen_size}")
        pad = gen_size + self.trim - (mix.shape[-1] % gen_size)
        mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
        mixture_length = mixture.shape[-1]
        self.logger.debug(f"{translations['mix_cache']}: {mixture.shape}")

        step = int((1 - overlap) * chunk_size)
        self.logger.debug(translations["step_or_overlap"].format(step=step, overlap=overlap))
        if step <= 0:
            # overlap >= 1 would otherwise fail inside range() or skip every chunk.
            raise ValueError(f"Chunk step must be positive (chunk_size={chunk_size}, overlap={overlap}).")

        result = np.zeros((1, 2, mixture_length), dtype=np.float32)
        divider = np.zeros((1, 2, mixture_length), dtype=np.float32)
        total_chunks = (mixture_length + step - 1) // step
        self.logger.debug(f"{translations['all_process_part']}: {total_chunks}")

        for total, start in enumerate(tqdm(range(0, mixture_length, step)), start=1):
            end = min(start + chunk_size, mixture_length)
            self.logger.debug(translations["process_part_2"].format(total=total, total_chunks=total_chunks, start=start, end=end))
            chunk_size_actual = end - start

            window = None
            if overlap != 0:
                window = np.tile(np.hanning(chunk_size_actual)[None, None, :], (1, 2, 1))
                self.logger.debug(translations["window"])

            mix_part_ = mixture[:, start:end]
            if end != start + chunk_size:
                # The last chunk is zero-padded up to the full chunk size the model expects.
                mix_part_ = np.concatenate((mix_part_, np.zeros((2, start + chunk_size - end), dtype="float32")), axis=-1)
            # Add a leading batch axis; converting the ndarray directly avoids
            # torch's slow list-of-ndarrays path.
            mix_part = torch.tensor(mix_part_[None], dtype=torch.float32).to(self.torch_device)
            mix_waves = mix_part.split(self.batch_size)
            total_batches = len(mix_waves)
            self.logger.debug(f"{translations['mix_or_batch']}: {total_batches}")

            with torch.no_grad():
                for batches_processed, mix_wave in enumerate(mix_waves, start=1):
                    self.logger.debug(f"{translations['mix_wave']} {batches_processed}/{total_batches}")
                    tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
                    if window is not None:
                        tar_waves[..., :chunk_size_actual] *= window
                        divider[..., start:end] += window
                    else:
                        divider[..., start:end] += 1
                    result[..., start:end] += tar_waves[..., : end - start]

        self.logger.debug(translations["normalization_2"])
        # Hann windows are zero at their endpoints, so a few samples (normally in
        # the padded margins) can have zero accumulated weight; leave them at 0
        # instead of producing NaN from 0/0.
        tar_waves = np.divide(result, divider, out=np.zeros_like(result), where=divider != 0)

        # Drop the `trim` padding on both sides, then cut the tail padding back to
        # the input length. The explicit end index stays correct when trim == 0
        # (a `-0` slice end would yield an empty array).
        source = tar_waves[0, :, self.trim : mixture_length - self.trim][:, : mix.shape[-1]]
        self.logger.debug(f"{translations['tar_waves']}: {source.shape}")

        if not is_match_mix:
            source *= self.compensate
            self.logger.debug(translations["mix_match"])

        self.logger.debug(translations["mix_success"])
        return source

    def run_model(self, mix, is_match_mix=False):
        """Run one chunk batch through STFT -> model -> inverse STFT.

        Args:
            mix: waveform tensor batch; moved to ``torch_device`` here.
            is_match_mix: if True, skip the model and invert the (low-bin-zeroed)
                spectrogram directly.

        Returns:
            NumPy array holding the inverse-STFT waveforms.
        """
        spek = self.stft(mix.to(self.torch_device))
        self.logger.debug(translations["stft_2"].format(shape=spek.shape))
        # Zero the three lowest frequency bins before inference.
        spek[:, :, :3, :] *= 0

        if is_match_mix:
            spec_pred = spek.cpu().numpy()
            self.logger.debug(translations["is_match_mix"])
        elif self.enable_denoise:
            # Average the output for spek with the negated output for -spek.
            spec_pred_neg = self.model_run(-spek)
            spec_pred_pos = self.model_run(spek)
            spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)
            self.logger.debug(translations["enable_denoise"])
        else:
            spec_pred = self.model_run(spek)
            self.logger.debug(translations["no_denoise"])

        # model_run returns an ndarray (ONNX Runtime) or a tensor (onnx2torch).
        # torch.as_tensor accepts both without torch.tensor's copy-construct
        # warning and extra copy for tensor inputs.
        result = self.stft.inverse(torch.as_tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
        self.logger.debug(f"{translations['stft']}: {result.shape}")
        return result
class STFT:
    """Batched STFT / inverse STFT that packs real and imaginary parts as channels.

    The forward transform maps a waveform tensor of shape ``(*batch, C, T)`` to
    ``(*batch, 2 * C, dim_f, frames)``: planes ``2c`` and ``2c + 1`` hold the
    real and imaginary parts of channel ``c``. Only the lowest ``dim_f``
    frequency bins are kept. ``inverse`` zero-pads the dropped bins and
    reverses the packing.

    Inputs on devices other than CUDA/CPU are transformed on the CPU, and the
    result is moved to ``device``.
    """

    def __init__(self, logger, n_fft, hop_length, dim_f, device):
        """
        Args:
            logger: stored as ``self.logger``; not used by this class itself.
            n_fft: FFT size, also the Hann window length.
            hop_length: hop between frames, in samples.
            dim_f: number of low-frequency bins kept by the forward transform.
            device: where results are moved back to after a CPU fallback.
        """
        self.logger = logger
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.dim_f = dim_f
        self.device = device
        # Periodic Hann window, created on the CPU and moved to the input's device per call.
        self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)

    @staticmethod
    def _needs_cpu_fallback(tensor):
        """Return True when ``tensor`` lives on a device other than CUDA or CPU."""
        return tensor.device.type not in ("cuda", "cpu")

    def __call__(self, input_tensor):
        """Forward STFT.

        Args:
            input_tensor: waveform tensor of shape ``(*batch, channels, time)``.

        Returns:
            Tensor of shape ``(*batch, 2 * channels, dim_f, frames)``.
        """
        fallback = self._needs_cpu_fallback(input_tensor)
        if fallback:
            input_tensor = input_tensor.cpu()
        stft_window = self.hann_window.to(input_tensor.device)

        batch_dimensions = input_tensor.shape[:-2]
        channel_dim, time_dim = input_tensor.shape[-2:]
        reshaped_tensor = input_tensor.reshape([-1, time_dim])

        # return_complex=False is deprecated; viewing the complex result as real
        # yields the same (..., freq, frames, 2) layout.
        stft_output = torch.view_as_real(
            torch.stft(reshaped_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True, return_complex=True)
        )
        # (B*C, freq, frames, 2) -> (B*C, 2, freq, frames) -> (*batch, 2*C, freq, frames)
        permuted_stft_output = stft_output.permute([0, 3, 1, 2])
        final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]])

        if fallback:
            final_output = final_output.to(self.device)
        return final_output[..., : self.dim_f, :]

    def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
        """Append zero bins on the frequency axis (-2) so it has ``num_freq_bins`` entries."""
        # Allocate directly on the input's device instead of allocating on the CPU and copying.
        freq_padding = torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim], device=input_tensor.device)
        return torch.cat([input_tensor, freq_padding], -2)

    def calculate_inverse_dimensions(self, input_tensor):
        """Split the shape of a packed spectrogram and compute the full bin count.

        Returns:
            ``(batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins)``
            with ``num_freq_bins = n_fft // 2 + 1``.
        """
        batch_dimensions = input_tensor.shape[:-3]
        channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]
        num_freq_bins = self.n_fft // 2 + 1
        return batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins

    def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
        """Recombine (real, imag) plane pairs into a ``(N, freq, frames)`` complex tensor."""
        reshaped_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim])
        flattened_tensor = reshaped_tensor.reshape([-1, 2, num_freq_bins, time_dim])
        permuted_tensor = flattened_tensor.permute([0, 2, 3, 1])
        return permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j

    def inverse(self, input_tensor):
        """Inverse of ``__call__``.

        Args:
            input_tensor: tensor of shape ``(*batch, 2 * channels, freq, frames)``
                with ``freq <= n_fft // 2 + 1``; missing high bins are treated as zero.

        Returns:
            Waveform tensor of shape ``(*batch, channels, samples)``.
        """
        fallback = self._needs_cpu_fallback(input_tensor)
        if fallback:
            input_tensor = input_tensor.cpu()
        stft_window = self.hann_window.to(input_tensor.device)

        batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)
        padded_tensor = self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins)
        complex_tensor = self.prepare_for_istft(padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim)
        istft_result = torch.istft(complex_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True)

        # One waveform per (real, imag) plane pair: channel_dim // 2 rather than a
        # fixed 2, so mono and multichannel inputs round-trip correctly.
        final_output = istft_result.reshape([*batch_dimensions, channel_dim // 2, -1])
        if fallback:
            final_output = final_output.to(self.device)
        return final_output