import os
import sys
import logging
import argparse
import warnings

import torch.multiprocessing as mp
from distutils.util import strtobool

sys.path.append(os.getcwd())

from main.library import opencl
from main.library.utils import check_assets
from main.inference.extracting.rms import run_rms_extraction
from main.inference.extracting.feature import run_pitch_extraction
from main.app.variables import config, logger, translations, configs
from main.inference.extracting.embedding import run_embedding_extraction
from main.inference.extracting.preparing_files import generate_config, generate_filelist

warnings.filterwarnings("ignore")

# Silence noisy third-party loggers.
for l in ["torch", "faiss", "httpx", "httpcore", "faiss.loader", "numba.core", "urllib3", "matplotlib"]:
    logging.getLogger(l).setLevel(logging.ERROR)
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--extract", action='store_true')
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--rvc_version", type=str, default="v2")
    parser.add_argument("--f0_method", type=str, default="rmvpe")
    parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
    parser.add_argument("--hop_length", type=int, default=128)
    parser.add_argument("--cpu_cores", type=int, default=2)
    parser.add_argument("--gpu", type=str, default="-")
    parser.add_argument("--sample_rate", type=int, required=True)
    parser.add_argument("--embedder_model", type=str, default="contentvec_base")
    parser.add_argument("--f0_onnx", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--embedders_mode", type=str, default="fairseq")
    parser.add_argument("--f0_autotune", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--f0_autotune_strength", type=float, default=1)
    parser.add_argument("--rms_extract", type=lambda x: bool(strtobool(x)), default=False)
    return parser.parse_args()
def main():
    args = parse_arguments()
    f0_method, hop_length, num_processes, gpus, version, pitch_guidance, sample_rate = args.f0_method, args.hop_length, args.cpu_cores, args.gpu, args.rvc_version, args.pitch_guidance, args.sample_rate
    embedder_model, f0_onnx, embedders_mode, f0_autotune, f0_autotune_strength, rms_extract = args.embedder_model, args.f0_onnx, args.embedders_mode, args.f0_autotune, args.f0_autotune_strength, args.rms_extract

    exp_dir = os.path.join(configs["logs_path"], args.model_name)
    # "--gpu -" selects CPU; otherwise indices separated by "-" (e.g. "0-1") are mapped to OpenCL or CUDA devices.
    devices = ["cpu"] if gpus == "-" else [(f"ocl:{idx}" if opencl.is_available() and config.device.startswith("ocl") else f"cuda:{idx}") for idx in gpus.split("-")]
    check_assets(f0_method, embedder_model, f0_onnx=f0_onnx, embedders_mode=embedders_mode)

    log_data = {
        translations["modelname"]: args.model_name,
        translations["export_process"]: exp_dir,
        translations["f0_method"]: f0_method,
        translations["pretrain_sr"]: sample_rate,
        translations["cpu_core"]: num_processes,
        "Gpu": gpus,
        "Hop length": hop_length,
        translations["training_version"]: version,
        translations["extract_f0"]: pitch_guidance,
        translations["hubert_model"]: embedder_model,
        translations["f0_onnx_mode"]: f0_onnx,
        translations["embed_mode"]: embedders_mode,
        translations["train&energy"]: rms_extract,
    }
    for key, value in log_data.items():
        logger.debug(f"{key}: {value}")

    # Record the process id of this extraction run.
    pid_path = os.path.join(exp_dir, "extract_pid.txt")
    with open(pid_path, "w") as pid_file:
        pid_file.write(str(os.getpid()))

    try:
        run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, devices, f0_onnx, config.is_half, f0_autotune, f0_autotune_strength)
        run_embedding_extraction(exp_dir, version, num_processes, devices, embedder_model, embedders_mode, config.is_half)
        run_rms_extraction(exp_dir, num_processes, devices, rms_extract)
        generate_config(version, sample_rate, exp_dir)
        generate_filelist(pitch_guidance, exp_dir, version, sample_rate, embedders_mode, embedder_model, rms_extract)
    except Exception as e:
        logger.error(f"{translations['extract_error']}: {e}")

    if os.path.exists(pid_path): os.remove(pid_path)
    logger.info(f"{translations['extract_success']} {args.model_name}.")
if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    main()
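
# Example invocation (a sketch only; the script path and the argument values below are
# illustrative assumptions, not taken from the repository -- adjust them to your setup):
#
#   python main/inference/extracting/extract.py --extract --model_name MyModel \
#       --sample_rate 40000 --f0_method rmvpe --rvc_version v2 --cpu_cores 2 --gpu 0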