|
import os
|
|
import sys
|
|
import logging
|
|
import argparse
|
|
import warnings
|
|
|
|
import torch.multiprocessing as mp
|
|
|
|
from distutils.util import strtobool
|
|
|
|
# Put the current working directory on the import path before the
# project-local ``main.*`` imports that follow are executed.
sys.path.append(os.getcwd())
|
|
|
|
from main.library import opencl
|
|
from main.library.utils import check_assets
|
|
from main.inference.extracting.rms import run_rms_extraction
|
|
from main.inference.extracting.feature import run_pitch_extraction
|
|
from main.app.variables import config, logger, translations, configs
|
|
from main.inference.extracting.embedding import run_embedding_extraction
|
|
from main.inference.extracting.preparing_files import generate_config, generate_filelist
|
|
|
|
# Install a blanket "ignore" filter so no Python warnings are shown.
warnings.filterwarnings("ignore")

# Raise the threshold of these third-party loggers so that only records at
# ERROR level or above get through.
for noisy_logger in (
    "torch",
    "faiss",
    "httpx",
    "httpcore",
    "faiss.loader",
    "numba.core",
    "urllib3",
    "matplotlib",
):
    logging.getLogger(noisy_logger).setLevel(logging.ERROR)
|
|
|
|
def _str_to_bool(value):
    """Convert a command-line truth string to a ``bool``.

    Accepts, case-insensitively, ``y``/``yes``/``t``/``true``/``on``/``1``
    as true and ``n``/``no``/``f``/``false``/``off``/``0`` as false. This is
    the same vocabulary as ``distutils.util.strtobool``; it is implemented
    here because ``distutils`` was removed from the standard library in
    Python 3.12.

    Raises:
        ValueError: if ``value`` is not one of the accepted strings.
            ``argparse`` turns this into a normal usage error.
    """
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {normalized!r}")


def parse_arguments(argv=None):
    """Parse the extraction command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) ``argparse`` reads ``sys.argv[1:]``, which is the
            original behaviour.

    Returns:
        The populated ``argparse.Namespace``. ``--model_name`` and
        ``--sample_rate`` are required; every other option has a default.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--extract", action="store_true")
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--rvc_version", type=str, default="v2")
    parser.add_argument("--f0_method", type=str, default="rmvpe")
    parser.add_argument("--pitch_guidance", type=_str_to_bool, default=True)
    parser.add_argument("--hop_length", type=int, default=128)
    parser.add_argument("--cpu_cores", type=int, default=2)
    # "-" doubles as the separator between GPU indices and, on its own,
    # as the "no GPU" marker that main() checks for.
    parser.add_argument("--gpu", type=str, default="-")
    parser.add_argument("--sample_rate", type=int, required=True)
    parser.add_argument("--embedder_model", type=str, default="contentvec_base")
    parser.add_argument("--f0_onnx", type=_str_to_bool, default=False)
    parser.add_argument("--embedders_mode", type=str, default="fairseq")
    parser.add_argument("--f0_autotune", type=_str_to_bool, default=False)
    # argparse does not run ``type`` on non-string defaults, so the default
    # is spelled as a float to match what a user-supplied value produces.
    parser.add_argument("--f0_autotune_strength", type=float, default=1.0)
    parser.add_argument("--rms_extract", type=_str_to_bool, default=False)

    return parser.parse_args(argv)
|
|
|
|
def main():
    """Run the feature-extraction pipeline for one model.

    Parses the command-line options, writes this process's PID to
    ``extract_pid.txt`` inside the model's experiment directory, and then
    runs pitch extraction, embedding extraction, RMS extraction, config
    generation and filelist generation in that order.

    An exception raised by any pipeline step is logged and swallowed, so
    the function still returns normally. The success message is logged
    only when every step completed. The PID file is removed however the
    pipeline ends.
    """
    args = parse_arguments()

    f0_method = args.f0_method
    hop_length = args.hop_length
    num_processes = args.cpu_cores
    gpus = args.gpu
    version = args.rvc_version
    pitch_guidance = args.pitch_guidance
    sample_rate = args.sample_rate
    embedder_model = args.embedder_model
    f0_onnx = args.f0_onnx
    embedders_mode = args.embedders_mode
    f0_autotune = args.f0_autotune
    f0_autotune_strength = args.f0_autotune_strength
    rms_extract = args.rms_extract

    exp_dir = os.path.join(configs["logs_path"], args.model_name)

    # A bare "-" means "no GPU"; otherwise "-" separates the GPU indices.
    if gpus == "-":
        devices = ["cpu"]
    else:
        # The backend is the same for every index, so decide it once.
        use_opencl = opencl.is_available() and config.device.startswith("ocl")
        prefix = "ocl" if use_opencl else "cuda"
        devices = [f"{prefix}:{idx}" for idx in gpus.split("-")]

    check_assets(f0_method, embedder_model, f0_onnx=f0_onnx, embedders_mode=embedders_mode)

    log_data = {
        translations['modelname']: args.model_name,
        translations['export_process']: exp_dir,
        translations['f0_method']: f0_method,
        translations['pretrain_sr']: sample_rate,
        translations['cpu_core']: num_processes,
        "Gpu": gpus,
        "Hop length": hop_length,
        translations['training_version']: version,
        translations['extract_f0']: pitch_guidance,
        translations['hubert_model']: embedder_model,
        translations["f0_onnx_mode"]: f0_onnx,
        translations["embed_mode"]: embedders_mode,
        translations["train&energy"]: rms_extract,
    }
    for key, value in log_data.items():
        logger.debug(f"{key}: {value}")

    pid_path = os.path.join(exp_dir, "extract_pid.txt")
    with open(pid_path, "w") as pid_file:
        pid_file.write(str(os.getpid()))

    succeeded = False
    try:
        run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, devices, f0_onnx, config.is_half, f0_autotune, f0_autotune_strength)
        run_embedding_extraction(exp_dir, version, num_processes, devices, embedder_model, embedders_mode, config.is_half)
        run_rms_extraction(exp_dir, num_processes, devices, rms_extract)
        generate_config(version, sample_rate, exp_dir)
        generate_filelist(pitch_guidance, exp_dir, version, sample_rate, embedders_mode, embedder_model, rms_extract)
        succeeded = True
    except Exception as e:
        logger.error(f"{translations['extract_error']}: {e}")
    finally:
        # Runs on every exit path (including KeyboardInterrupt) so a stale
        # PID file is not left behind.
        if os.path.exists(pid_path):
            os.remove(pid_path)

    # Report success only when no step failed; previously this message was
    # logged even right after the error above.
    if succeeded:
        logger.info(f"{translations['extract_success']} {args.model_name}.")
|
|
|
|
if __name__ == "__main__":
    # Force the "spawn" start method (overriding any method already set)
    # before the pipeline runs.
    mp.set_start_method(method="spawn", force=True)
    main()