"""Command-line tool that downloads audio and turns it into dataset files.

Pipeline implemented by :func:`main`: download each input with yt-dlp,
optionally trim the start/end, optionally separate vocals and remove reverb,
strip silence, optionally resample / noise-reduce, and finally move the
results into the output dataset directory.
"""
import os
import re
import sys
import time
import shutil
import logging
import argparse
import warnings
import logging.handlers

import yt_dlp
import librosa
import soundfile as sf
import noisereduce as nr

from pydub import AudioSegment, silence

try:
    from distutils.util import strtobool
except ImportError:  # distutils was removed from the standard library in Python 3.12
    def strtobool(value):
        """Convert a truth-value string to 1 or 0 (distutils-compatible fallback)."""
        normalized = value.lower()
        if normalized in ("y", "yes", "t", "true", "on", "1"):
            return 1
        if normalized in ("n", "no", "f", "false", "off", "0"):
            return 0
        raise ValueError("invalid truth value %r" % (normalized,))

now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
from main.library.algorithm.separator import Separator

translations = Config().translations
log_file = os.path.join("assets", "logs", "create_dataset.log")
logger = logging.getLogger(__name__)

# The formatters are created unconditionally: separator_main() passes
# file_formatter to Separator even when no handler is attached below.
console_formatter = logging.Formatter(
    fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
file_formatter = logging.Formatter(
    fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

if logger.hasHandlers():
    logger.handlers.clear()
else:
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(logging.INFO)

    # The rotating handler opens the file immediately, so its directory must exist.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    file_handler = logging.handlers.RotatingFileHandler(
        log_file, maxBytes=5 * 1024 * 1024, backupCount=3, encoding='utf-8'
    )
    file_handler.setFormatter(file_formatter)
    file_handler.setLevel(logging.DEBUG)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)


def parse_arguments() -> argparse.Namespace:
    """Parse the command-line options and return them as a namespace."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_audio", type=str, required=True)
    parser.add_argument("--output_dataset", type=str, default="./dataset")
    parser.add_argument("--resample", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--resample_sr", type=int, default=44100)
    parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--clean_strength", type=float, default=0.7)
    parser.add_argument("--separator_music", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--separator_reverb", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--kim_vocal_version", type=int, default=2)
    parser.add_argument("--overlap", type=float, default=0.25)
    parser.add_argument("--segments_size", type=int, default=256)
    parser.add_argument("--mdx_hop_length", type=int, default=1024)
    parser.add_argument("--mdx_batch_size", type=int, default=1)
    parser.add_argument("--denoise_mdx", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--skip", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--skip_start_audios", type=str, default="0")
    parser.add_argument("--skip_end_audios", type=str, default="0")

    return parser.parse_args()


dataset_temp = os.path.join("dataset_temp")


def main():
    """Run the whole dataset-creation pipeline from the command-line options.

    Raises:
        ValueError: if ``--kim_vocal_version`` is not 1 or 2, or if reverb
            separation is requested without music separation.
        RuntimeError: if any processing step fails (the original exception
            is chained as the cause).
    """
    args = parse_arguments()

    input_audio = args.input_audio
    output_dataset = args.output_dataset
    resample = args.resample
    resample_sr = args.resample_sr
    clean_dataset = args.clean_dataset
    clean_strength = args.clean_strength
    separator_music = args.separator_music
    separator_reverb = args.separator_reverb
    kim_vocal_version = args.kim_vocal_version
    overlap = args.overlap
    segments_size = args.segments_size
    hop_length = args.mdx_hop_length
    batch_size = args.mdx_batch_size
    denoise_mdx = args.denoise_mdx
    skip = args.skip
    skip_start_audios = args.skip_start_audios
    skip_end_audios = args.skip_end_audios

    logger.debug(f"{translations['audio_path']}: {input_audio}")
    logger.debug(f"{translations['output_path']}: {output_dataset}")
    logger.debug(f"{translations['resample']}: {resample}")
    if resample:
        logger.debug(f"{translations['sample_rate']}: {resample_sr}")
    logger.debug(f"{translations['clear_dataset']}: {clean_dataset}")
    if clean_dataset:
        logger.debug(f"{translations['clean_strength']}: {clean_strength}")
    logger.debug(f"{translations['separator_music']}: {separator_music}")
    logger.debug(f"{translations['dereveb_audio']}: {separator_reverb}")
    if separator_music:
        logger.debug(f"{translations['training_version']}: {kim_vocal_version}")
    logger.debug(f"{translations['segments_size']}: {segments_size}")
    logger.debug(f"{translations['overlap']}: {overlap}")
    logger.debug(f"Hop length: {hop_length}")
    logger.debug(f"{translations['batch_size']}: {batch_size}")
    logger.debug(f"{translations['denoise_mdx']}: {denoise_mdx}")
    logger.debug(f"{translations['skip']}: {skip}")
    if skip:
        logger.debug(f"{translations['skip_start']}: {skip_start_audios}")
        logger.debug(f"{translations['skip_end']}: {skip_end_audios}")

    if kim_vocal_version not in (1, 2):
        raise ValueError(translations["version_not_valid"])
    if separator_reverb and not separator_music:
        raise ValueError(translations["create_dataset_value_not_valid"])

    start_time = time.time()

    try:
        os.makedirs(dataset_temp, exist_ok=True)

        # enumerate() gives every input its own index; list.index() would hand
        # duplicate URLs the same output name.
        urls = [url.strip() for url in input_audio.split(",") if url.strip()]
        paths = [downloader(url, index) for index, url in enumerate(urls)]

        if skip:
            skip_starts = [value.strip() for value in skip_start_audios.split(",")]
            skip_ends = [value.strip() for value in skip_end_audios.split(",")]

            if len(skip_starts) < len(paths) or len(skip_ends) < len(paths):
                logger.warning(translations["skip<audio"])
                sys.exit(1)
            elif len(skip_starts) > len(paths) or len(skip_ends) > len(paths):
                logger.warning(translations["skip>audio"])
                sys.exit(1)
            else:
                for audio, start_seconds, end_seconds in zip(paths, skip_starts, skip_ends):
                    skip_start(audio, float(start_seconds))
                    skip_end(audio, float(end_seconds))

        if separator_music:
            separated_paths = []
            for audio in paths:
                vocals = separator_music_main(
                    audio, dataset_temp, segments_size, overlap,
                    denoise_mdx, kim_vocal_version, hop_length, batch_size,
                )
                # The separators return None when they could not produce a
                # stem; fail here instead of passing None further down.
                if vocals is None:
                    raise FileNotFoundError(translations["not_found"].format(name=audio))

                if separator_reverb:
                    dereverbed = separator_reverb_audio(
                        vocals, dataset_temp, segments_size, overlap,
                        denoise_mdx, hop_length, batch_size,
                    )
                    if dereverbed is None:
                        raise FileNotFoundError(translations["not_found"].format(name=vocals))
                    vocals = dereverbed

                separated_paths.append(vocals)
            paths = separated_paths

        paths = [process_audio(audio) for audio in paths]

        for audio_path in paths:
            data, sample_rate = sf.read(audio_path)

            # soundfile returns multichannel audio as (frames, channels) while
            # librosa and noisereduce work along the last axis, so the array is
            # transposed around those calls (a no-op for 1-D mono data).
            if resample and resample_sr > 0 and resample_sr != sample_rate:
                data = librosa.resample(data.T, orig_sr=sample_rate, target_sr=resample_sr).T
                sample_rate = resample_sr

            if clean_dataset:
                data = nr.reduce_noise(y=data.T, sr=sample_rate, prop_decrease=clean_strength).T

            sf.write(audio_path, data, sample_rate)

        # Results are published only after every step succeeded, so a failed
        # run never leaves intermediate files in the dataset directory.
        os.makedirs(output_dataset, exist_ok=True)
        for audio in paths:
            shutil.move(audio, output_dataset)
    except Exception as e:
        raise RuntimeError(f"{translations['create_dataset_error']}: {e}") from e
    finally:
        shutil.rmtree(dataset_temp, ignore_errors=True)

    elapsed_time = time.time() - start_time
    logger.info(translations["create_dataset_success"].format(elapsed_time=f"{elapsed_time:.2f}"))


def downloader(url, name):
    """Download ``url`` with yt-dlp as WAV into the temporary directory.

    Args:
        url: Address handed to yt-dlp.
        name: Base name (without extension) of the output file.

    Returns:
        The path ``<dataset_temp>/<name>.wav``.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ydl_opts = {
            'format': 'bestaudio/best',
            'outtmpl': os.path.join(dataset_temp, f"{name}"),
            'postprocessors': [{
                'key': 'FFmpegExtractAudio',
                'preferredcodec': 'wav',
                'preferredquality': '192',
            }],
            'noplaylist': True,
            'verbose': False,
        }

        logger.info(f"{translations['starting_download']}: {url}...")

        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.extract_info(url)

        logger.info(f"{translations['download_success']}: {url}")

    return os.path.join(dataset_temp, f"{name}" + ".wav")


def skip_start(input_file, seconds):
    """Remove the first ``seconds`` seconds of ``input_file`` in place.

    ``seconds`` may be a number or a numeric string. Nothing is written when
    it is not positive or not shorter than the audio; a warning is logged.
    """
    seconds = float(seconds)
    data, sr = sf.read(input_file)
    total_duration = len(data) / sr

    if seconds <= 0:
        logger.warning(translations["=<0"])
    elif seconds >= total_duration:
        logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
    else:
        logger.info(f"{translations['skip_start']}: {input_file}...")
        sf.write(input_file, data[int(seconds * sr):], sr)

    logger.info(translations["skip_start_audio"].format(input_file=input_file))


def skip_end(input_file, seconds):
    """Remove the last ``seconds`` seconds of ``input_file`` in place.

    ``seconds`` may be a number or a numeric string. Nothing is written when
    it is not positive or not shorter than the audio; a warning is logged.
    """
    seconds = float(seconds)
    data, sr = sf.read(input_file)
    total_duration = len(data) / sr

    if seconds <= 0:
        logger.warning(translations["=<0"])
    elif seconds >= total_duration:
        logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
    else:
        logger.info(f"{translations['skip_end']}: {input_file}...")
        # An explicit end index is used because data[:-0] would be empty when
        # the trim rounds down to zero samples.
        sf.write(input_file, data[:len(data) - int(seconds * sr)], sr)

    logger.info(translations["skip_end_audio"].format(input_file=input_file))


def process_audio(file_path):
    """Drop silent stretches from ``file_path`` and export the joined result.

    Non-silent ranges are detected with pydub (``min_silence_len=750``,
    ``silence_thresh=-70``); ranges shorter than 30 ms are skipped with a
    warning and the remaining ones are concatenated in their original order.

    Returns:
        Path of the exported file, ``<file_path stem>_processed.wav``.

    Raises:
        RuntimeError: if loading, splitting or exporting fails.
    """
    try:
        song = AudioSegment.from_file(file_path)
        nonsilent_parts = silence.detect_nonsilent(song, min_silence_len=750, silence_thresh=-70)

        kept_chunks = []
        for i, (start_i, end_i) in enumerate(nonsilent_parts):
            chunk = song[start_i:end_i]
            if len(chunk) >= 30:
                kept_chunks.append(chunk)
            else:
                logger.warning(translations["skip_file"].format(i=i, chunk=len(chunk)))

        logger.info(f"{translations['split_total']}: {len(kept_chunks)}")

        # detect_nonsilent yields ranges in timeline order, so the chunks can be
        # joined directly without a round trip through temporary files.
        combined = AudioSegment.empty()
        for chunk in kept_chunks:
            combined += chunk

        output_path = os.path.splitext(file_path)[0] + "_processed" + ".wav"
        logger.info(translations["merge_audio"])
        combined.export(output_path, format="wav")

        return output_path
    except Exception as e:
        raise RuntimeError(f"{translations['process_audio_error']}: {e}") from e


def separator_music_main(input, output, segments_size, overlap, denoise, version, hop_length, batch_size):
    """Separate ``input`` with the ``Kim_Vocal_<version>.onnx`` model.

    Both produced stems are renamed so their paths contain no parentheses.

    Returns:
        Path of the renamed vocals file, or ``None`` when ``input`` or
        ``output`` does not exist or no vocals stem was produced.
    """
    if not os.path.exists(input):
        logger.warning(translations["input_not_valid"])
        return None

    if not os.path.exists(output):
        logger.warning(translations["output_not_valid"])
        return None

    model = f"Kim_Vocal_{version}.onnx"
    logger.info(translations["separator_process"].format(input=input))

    output_separator = separator_main(
        audio_file=input,
        model_filename=model,
        output_format="wav",
        output_dir=output,
        mdx_segment_size=segments_size,
        mdx_overlap=overlap,
        mdx_batch_size=batch_size,
        mdx_hop_length=hop_length,
        mdx_enable_denoise=denoise,
    )

    vocals_path = None
    for f in output_separator:
        path = os.path.join(output, f)

        if not os.path.exists(path):
            logger.error(translations["not_found"].format(name=path))
            continue  # nothing to rename for a file that is not there

        renamed = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"

        if '_(Instrumental)_' in f:
            os.rename(path, renamed)
        elif '_(Vocals)_' in f:
            os.rename(path, renamed)
            # NOTE(review): this message has no label; it looks like a
            # translation key was dropped here - confirm the intended key.
            logger.info(f": {renamed}")
            vocals_path = renamed

    return vocals_path


def separator_reverb_audio(input, output, segments_size, overlap, denoise, hop_length, batch_size):
    """Run the ``Reverb_HQ_By_FoxJoy.onnx`` model on ``input``.

    Both produced stems are renamed so their paths contain no parentheses.

    Returns:
        Path of the renamed "No Reverb" file, or ``None`` when ``input`` or
        ``output`` does not exist or no such stem was produced.
    """
    reverb_models = "Reverb_HQ_By_FoxJoy.onnx"

    if not os.path.exists(input):
        logger.warning(translations["input_not_valid"])
        return None

    if not os.path.exists(output):
        logger.warning(translations["output_not_valid"])
        return None

    logger.info(f"{translations['dereverb']}: {input}...")

    output_dereverb = separator_main(
        audio_file=input,
        model_filename=reverb_models,
        output_format="wav",
        output_dir=output,
        mdx_segment_size=segments_size,
        mdx_overlap=overlap,
        mdx_batch_size=batch_size,
        mdx_hop_length=hop_length,
        mdx_enable_denoise=denoise,
    )

    no_reverb_path = None
    for f in output_dereverb:
        path = os.path.join(output, f)

        if not os.path.exists(path):
            logger.error(translations["not_found"].format(name=path))
            continue  # nothing to rename for a file that is not there

        renamed = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"

        if '_(Reverb)_' in f:
            os.rename(path, renamed)
        elif '_(No Reverb)_' in f:
            os.rename(path, renamed)
            logger.info(f"{translations['dereverb_success']}: {renamed}")
            no_reverb_path = renamed

    return no_reverb_path


def separator_main(audio_file=None, model_filename="Kim_Vocal_1.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True):
    """Build a ``Separator``, load ``model_filename`` and separate ``audio_file``.

    The ``mdx_*`` arguments are forwarded in the ``mdx_params`` dictionary.

    Returns:
        Whatever ``Separator.separate`` returns; the callers in this file
        iterate it as file names relative to ``output_dir``.
    """
    separator = Separator(
        log_formatter=file_formatter,
        log_level=logging.INFO,
        output_dir=output_dir,
        output_format=output_format,
        output_bitrate=None,
        normalization_threshold=0.9,
        output_single_stem=None,
        invert_using_spec=False,
        sample_rate=44100,
        mdx_params={
            "hop_length": mdx_hop_length,
            "segment_size": mdx_segment_size,
            "overlap": mdx_overlap,
            "batch_size": mdx_batch_size,
            "enable_denoise": mdx_enable_denoise,
        },
    )

    separator.load_model(model_filename=model_filename)
    return separator.separate(audio_file)


if __name__ == "__main__":
    main()