|
import os
|
|
import sys
|
|
import json
|
|
import onnx
|
|
import torch
|
|
import datetime
|
|
|
|
from collections import OrderedDict
|
|
|
|
sys.path.append(os.getcwd())
|
|
|
|
from main.app.core.ui import gr_info, gr_warning, gr_error
|
|
from main.library.algorithm.onnx_export import onnx_exporter
|
|
from main.app.variables import config, logger, translations, configs
|
|
|
|
def _extract_weights(ckpt):
    """Return the weights of a loaded checkpoint as a flat ``{key: tensor}`` mapping.

    When the checkpoint has a ``"model"`` entry, its items are copied except
    for keys containing ``"enc_q"``; otherwise ``ckpt["weight"]`` is returned
    unchanged. Both branches return the same flat shape, so the two sides of
    a fusion can be compared and blended key-for-key.
    """
    if "model" in ckpt:
        return OrderedDict((key, value) for key, value in ckpt["model"].items() if "enc_q" not in key)

    return ckpt["weight"]


def _blend_weights(weights_1, weights_2, ratio):
    """Return ``ratio * weights_1[k] + (1 - ratio) * weights_2[k]`` for every key, cast to fp16.

    Both mappings are expected to share the same keys. ``emb_g.weight`` is the
    only key whose shapes are reconciled here: only the leading rows present in
    both tensors are blended. Other shape mismatches are left to torch's
    tensor arithmetic (which typically raises).
    """
    blended = {}

    for key, tensor_1 in weights_1.items():
        tensor_2 = weights_2[key]

        if key == "emb_g.weight" and tensor_1.shape != tensor_2.shape:
            shared_rows = min(tensor_1.shape[0], tensor_2.shape[0])
            tensor_1, tensor_2 = tensor_1[:shared_rows], tensor_2[:shared_rows]

        # Blend in fp32 to avoid fp16 rounding during the arithmetic, then store as fp16.
        blended[key] = (ratio * tensor_1.float() + (1 - ratio) * tensor_2.float()).half()

    return blended


def fushion_model_pth(name, pth_1, pth_2, ratio):
    """Blend two ``.pth`` models and save the result in ``configs["weights_path"]``.

    Args:
        name: Output file name; ``.pth`` is appended when missing.
        pth_1: Path to the first ``.pth`` model.
        pth_2: Path to the second ``.pth`` model.
        ratio: Weight of ``pth_1`` in the blend; ``pth_2`` receives ``1 - ratio``.

    Returns:
        ``[message, output_path]`` where ``message`` is a status string and
        ``output_path`` is ``None`` whenever no fused model was written.
    """
    if not name.endswith(".pth"):
        name += ".pth"

    for index, pth in ((1, pth_1), (2, pth_2)):
        if not pth or not os.path.exists(pth) or not pth.endswith(".pth"):
            message = translations["provide_file"].format(filename=translations["model"] + f" {index}")
            gr_warning(message)
            return [message, None]

    try:
        ckpt1 = torch.load(pth_1, map_location="cpu", weights_only=True)
        ckpt2 = torch.load(pth_2, map_location="cpu", weights_only=True)

        if ckpt1["sr"] != ckpt2["sr"]:
            gr_warning(translations["sr_not_same"])
            return [translations["sr_not_same"], None]

        # The fused model inherits all of its metadata from the first model.
        cfg = ckpt1["config"]
        cfg_f0 = ckpt1["f0"]
        cfg_version = ckpt1["version"]
        cfg_sr = ckpt1["sr"]
        vocoder = ckpt1.get("vocoder", "Default")
        rms_extract = ckpt1.get("energy", False)

        weights_1 = _extract_weights(ckpt1)
        weights_2 = _extract_weights(ckpt2)

        if set(weights_1) != set(weights_2):
            gr_warning(translations["architectures_not_same"])
            return [translations["architectures_not_same"], None]

        gr_info(translations["start"].format(start=translations["fushion_model"]))

        opt = OrderedDict()
        opt["weight"] = _blend_weights(weights_1, weights_2, ratio)
        opt["config"] = cfg
        opt["sr"] = cfg_sr
        opt["f0"] = cfg_f0
        opt["version"] = cfg_version
        opt["infos"] = translations["model_fushion_info"].format(name=name, pth_1=pth_1, pth_2=pth_2, ratio=ratio)
        opt["vocoder"] = vocoder
        opt["energy"] = rms_extract

        output_dir = configs["weights_path"]
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, name)

        torch.save(opt, output_path)

        gr_info(translations["success"])
        return [translations["success"], output_path]
    except Exception as e:
        # Return the same readable text shown to the user rather than the raw exception object,
        # so the first element is always a string like on every other path.
        message = translations["error_occurred"].format(e=e)
        gr_error(message=message)
        return [message, None]
|
|
|
|
def fushion_model(name, path_1, path_2, ratio):
    """Validate the fusion request and dispatch to :func:`fushion_model_pth`.

    Args:
        name: Output model name; a trailing ``.onnx`` is swapped for ``.pth``.
        path_1: Path to the first model; must be a ``.pth`` path.
        path_2: Path to the second model; must be a ``.pth`` path.
        ratio: Blend weight forwarded unchanged to ``fushion_model_pth``.

    Returns:
        ``[message, output_path]`` from ``fushion_model_pth``,
        ``[translations["provide_name_is_save"], None]`` when no name is given,
        or ``[None, None]`` when either path is missing or not a ``.pth`` path.
    """
    if not name:
        gr_warning(translations["provide_name_is_save"])
        return [translations["provide_name_is_save"], None]

    # Guard with isinstance so an empty (None) path input warns instead of raising AttributeError.
    if not all(isinstance(path, str) and path.endswith(".pth") for path in (path_1, path_2)):
        gr_warning(translations["format_not_valid"])
        return [None, None]

    # Only swap a trailing ".onnx" suffix; other occurrences inside the name are kept.
    if name.endswith(".onnx"):
        name = name[: -len(".onnx")] + ".pth"

    return fushion_model_pth(name, path_1, path_2, ratio)
|
|
|
|
def onnx_export(model_path):
    """Export a ``.pth`` model to an ``.onnx`` file next to it.

    Args:
        model_path: Path to the ``.pth`` model; ``.pth`` is appended when missing.

    Returns:
        The value returned by ``onnx_exporter`` on success; otherwise the value
        returned by ``gr_warning`` (invalid path) or ``gr_error`` (export failure).
    """
    # The original `model_path + ".pth"` discarded its result, so the suffix was never added;
    # the falsy check also has to come first so a None path does not crash here.
    if model_path and not model_path.endswith(".pth"):
        model_path += ".pth"

    if not model_path or not os.path.exists(model_path):
        return gr_warning(translations["provide_file"].format(filename=translations["model"]))

    # Replace only the file suffix so ".pth" appearing in a directory name is left untouched.
    onnx_path = model_path[: -len(".pth")] + ".onnx"

    try:
        gr_info(translations["start_onnx_export"])

        output = onnx_exporter(model_path, onnx_path, is_half=config.is_half, device=config.device)

        gr_info(translations["success"])
        return output
    except Exception as e:
        return gr_error(e)
|
|
|
|
def model_info(path):
    """Read the metadata stored in a ``.pth`` or ``.onnx`` model and format it for display.

    For ``.pth`` files the loaded checkpoint dict is the metadata source. For
    ``.onnx`` files it is the JSON value of the ``"model_info"`` entry in the
    model's ``metadata_props``.

    Args:
        path: Path to a ``.pth`` or ``.onnx`` model file.

    Returns:
        ``translations["model_info"]`` filled with the model's metadata, or the
        value returned by ``gr_warning`` when the path is invalid or no
        metadata dict can be found.
    """
    if not path or not os.path.exists(path) or os.path.isdir(path) or not path.endswith((".pth", ".onnx")):
        return gr_warning(translations["provide_file"].format(filename=translations["model"]))

    def not_found(label):
        return translations["not_found"].format(name=label)

    def prettify_date(date_str):
        """Render a stored timestamp as ``YYYY-MM-DD HH:MM:SS``, or a format-error message."""
        try:
            parsed = datetime.datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%S.%f")
        except (TypeError, ValueError):
            # Fall back to general ISO 8601 parsing, e.g. timestamps without fractional seconds.
            try:
                parsed = datetime.datetime.fromisoformat(str(date_str))
            except ValueError as e:
                logger.debug(e)
                return translations["format_not_valid"]

        return parsed.strftime("%Y-%m-%d %H:%M:%S")

    if path.endswith(".pth"):
        # weights_only=True (as in fushion_model_pth) refuses arbitrary pickled objects from user-supplied files.
        model_data = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
    else:
        model = onnx.load(path)
        model_data = next(
            (json.loads(prop.value) for prop in model.metadata_props if prop.key == "model_info"),
            None,
        )

    # An ONNX file without a "model_info" entry (or a non-dict .pth payload) previously crashed on .get().
    if not isinstance(model_data, dict):
        return gr_warning(not_found("model_info"))

    gr_info(translations["read_info"])

    epochs = model_data.get("epoch", None)
    if epochs is None:
        # Without an "epoch" field, accept an "info" string shaped like "300epoch" / "300e";
        # anything else is reported as not found.
        info = model_data.get("info", None)
        if isinstance(info, str) and info.replace("epoch", "").replace("e", "").isdigit():
            epochs = info
        else:
            epochs = not_found(translations["epoch"])

    steps = model_data.get("step", not_found(translations["step"]))
    sr = model_data.get("sr", not_found(translations["sr"]))
    # NOTE(review): a missing "f0" key yields a non-empty string, so it is reported as trained_f0 — confirm intended.
    f0 = model_data.get("f0", not_found(translations["f0"]))
    version = model_data.get("version", not_found(translations["version"]))
    model_hash = model_data.get("model_hash", not_found("model_hash"))
    pitch_guidance = translations["trained_f0"] if f0 else translations["not_f0"]

    creation_date = model_data.get("creation_date")
    creation_date_str = prettify_date(creation_date) if creation_date else translations["not_found_create_time"]

    model_name = model_data.get("model_name", translations["unregistered"])
    model_author = model_data.get("author", translations["not_author"])
    vocoder = model_data.get("vocoder", "Default")
    rms_extract = model_data.get("energy", False)

    gr_info(translations["success"])

    return translations["model_info"].format(
        model_name=model_name,
        model_author=model_author,
        epochs=epochs,
        steps=steps,
        version=version,
        sr=sr,
        pitch_guidance=pitch_guidance,
        model_hash=model_hash,
        creation_date_str=creation_date_str,
        vocoder=vocoder,
        rms_extract=rms_extract,
    )