File size: 7,210 Bytes
1e4a2ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import os
import sys
import json
import onnx
import torch
import datetime
from collections import OrderedDict
sys.path.append(os.getcwd())
from main.app.core.ui import gr_info, gr_warning, gr_error
from main.library.algorithm.onnx_export import onnx_exporter
from main.app.variables import config, logger, translations, configs
def fushion_model_pth(name, pth_1, pth_2, ratio):
    """Blend two .pth checkpoints into a single model file.

    Each shared weight is mixed as ``ratio * model1 + (1 - ratio) * model2``
    (in float, then stored as half precision). Metadata (config, sr, f0,
    version, vocoder, energy) is inherited from the first model. The result
    is written to ``configs["weights_path"]/name``.

    Args:
        name: Output file name; ".pth" is appended when missing.
        pth_1: Path to the first .pth checkpoint.
        pth_2: Path to the second .pth checkpoint.
        ratio: Mixing weight of the first model (expected in [0, 1]).

    Returns:
        ``[status_message, output_path]`` on success,
        ``[error_message, None]`` on any failure.
    """
    if not name.endswith(".pth"): name = name + ".pth"
    if not pth_1 or not os.path.exists(pth_1) or not pth_1.endswith(".pth"):
        gr_warning(translations["provide_file"].format(filename=translations["model"] + " 1"))
        return [translations["provide_file"].format(filename=translations["model"] + " 1"), None]
    if not pth_2 or not os.path.exists(pth_2) or not pth_2.endswith(".pth"):
        gr_warning(translations["provide_file"].format(filename=translations["model"] + " 2"))
        return [translations["provide_file"].format(filename=translations["model"] + " 2"), None]
    def extract(ckpt):
        # Pull the state dict out of a full training checkpoint, dropping the
        # "enc_q" encoder branch (not needed for inference/fusion).
        a = ckpt["model"]
        opt = OrderedDict()
        opt["weight"] = {}
        for key in a.keys():
            if "enc_q" in key: continue
            opt["weight"][key] = a[key]
        return opt
    try:
        ckpt1 = torch.load(pth_1, map_location="cpu", weights_only=True)
        ckpt2 = torch.load(pth_2, map_location="cpu", weights_only=True)
        if ckpt1["sr"] != ckpt2["sr"]:
            gr_warning(translations["sr_not_same"])
            return [translations["sr_not_same"], None]
        # Metadata of the merged model is taken from model 1 before ckpt1 is
        # rebound to its bare weight dict below.
        cfg = ckpt1["config"]
        cfg_f0 = ckpt1["f0"]
        cfg_version = ckpt1["version"]
        cfg_sr = ckpt1["sr"]
        vocoder = ckpt1.get("vocoder", "Default")
        rms_extract = ckpt1.get("energy", False)
        # Accept both full training checkpoints ("model") and already-extracted
        # inference checkpoints ("weight").
        ckpt1 = extract(ckpt1) if "model" in ckpt1 else ckpt1["weight"]
        ckpt2 = extract(ckpt2) if "model" in ckpt2 else ckpt2["weight"]
        if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())):
            gr_warning(translations["architectures_not_same"])
            return [translations["architectures_not_same"], None]
        gr_info(translations["start"].format(start=translations["fushion_model"]))
        opt = OrderedDict()
        opt["weight"] = {}
        for key in ckpt1.keys():
            if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape:
                # Speaker-embedding tables may differ in speaker count; blend
                # only the overlapping rows.
                min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0])
                opt["weight"][key] = (ratio * (ckpt1[key][:min_shape0].float()) + (1 - ratio) * (ckpt2[key][:min_shape0].float())).half()
            else: opt["weight"][key] = (ratio * (ckpt1[key].float()) + (1 - ratio) * (ckpt2[key].float())).half()
        opt["config"] = cfg
        opt["sr"] = cfg_sr
        opt["f0"] = cfg_f0
        opt["version"] = cfg_version
        opt["infos"] = translations["model_fushion_info"].format(name=name, pth_1=pth_1, pth_2=pth_2, ratio=ratio)
        opt["vocoder"] = vocoder
        opt["energy"] = rms_extract
        output_model = configs["weights_path"]
        # exist_ok=True already makes the pre-check unnecessary.
        os.makedirs(output_model, exist_ok=True)
        torch.save(opt, os.path.join(output_model, name))
        gr_info(translations["success"])
        return [translations["success"], os.path.join(output_model, name)]
    except Exception as e:
        gr_error(message=translations["error_occurred"].format(e=e))
        # Fix: return the message as a string like every other branch, not the
        # raw exception object.
        return [str(e), None]
def fushion_model(name, path_1, path_2, ratio):
    """Validate inputs and dispatch model fusion (only .pth pairs are supported).

    Returns ``[status_message, output_path_or_None]`` like fushion_model_pth.
    """
    if not name:
        gr_warning(translations["provide_name_is_save"])
        return [translations["provide_name_is_save"], None]
    if path_1.endswith(".pth") and path_2.endswith(".pth"):
        # Keep a .pth extension on the output even if the user typed .onnx.
        return fushion_model_pth(name.replace(".onnx", ".pth"), path_1, path_2, ratio)
    gr_warning(translations["format_not_valid"])
    # Fix: surface the validation message instead of [None, None] so the UI
    # shows why fusion was rejected, consistent with every other failure path.
    return [translations["format_not_valid"], None]
def onnx_export(model_path):
    """Export a .pth checkpoint to an ONNX file next to the source file.

    Args:
        model_path: Path to the .pth checkpoint (".pth" appended if missing).

    Returns:
        The exporter's output path on success; otherwise the value of
        gr_warning/gr_error (error reporting to the UI).
    """
    if not model_path:
        return gr_warning(translations["provide_file"].format(filename=translations["model"]))
    # Fix: the original computed `model_path + ".pth"` and discarded the
    # result, so the extension was never actually appended.
    if not model_path.endswith(".pth"): model_path = model_path + ".pth"
    if not os.path.exists(model_path):
        return gr_warning(translations["provide_file"].format(filename=translations["model"]))
    try:
        gr_info(translations["start_onnx_export"])
        output = onnx_exporter(model_path, model_path.replace(".pth", ".onnx"), is_half=config.is_half, device=config.device)
        gr_info(translations["success"])
        return output
    except Exception as e:
        # Consistent error reporting with fushion_model_pth.
        return gr_error(message=translations["error_occurred"].format(e=e))
def model_info(path):
    """Read and format the metadata of a .pth or .onnx voice model.

    For .pth files the checkpoint dict is loaded directly; for .onnx files the
    metadata is read from the "model_info" entry of ``metadata_props``.

    Args:
        path: Path to an existing .pth or .onnx model file.

    Returns:
        The formatted ``translations["model_info"]`` string, or the value of
        gr_warning when the input is invalid.
    """
    if not path or not os.path.exists(path) or os.path.isdir(path) or not path.endswith((".pth", ".onnx")): return gr_warning(translations["provide_file"].format(filename=translations["model"]))
    def prettify_date(date_str):
        # ISO-8601 timestamp -> "YYYY-MM-DD HH:MM:SS"; None when not recorded.
        if date_str == translations["not_found_create_time"]: return None
        try:
            return datetime.datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%S.%f").strftime("%Y-%m-%d %H:%M:%S")
        except ValueError as e:
            logger.debug(e)
            return translations["format_not_valid"]
    if path.endswith(".pth"):
        # weights_only=True matches the loads in fushion_model_pth and avoids
        # executing arbitrary pickled code from an untrusted checkpoint.
        model_data = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
    else:
        model = onnx.load(path)
        model_data = None
        for prop in model.metadata_props:
            if prop.key == "model_info":
                model_data = json.loads(prop.value)
                break
        if model_data is None:
            # Fix: previously fell through with model_data=None and crashed on
            # model_data.get(...) below when the ONNX file had no metadata.
            return gr_warning(translations["provide_file"].format(filename=translations["model"]))
    gr_info(translations["read_info"])
    epochs = model_data.get("epoch", None)
    if epochs is None: epochs = model_data.get("info", None)
    if epochs is None:
        # Fix: the original guard (`if epoch and epochs is None`) was dead code
        # because `epochs.replace(...)` raised AttributeError first whenever
        # epochs was None (swallowed by a bare except). Report "not found".
        epochs = translations["not_found"].format(name=translations["epoch"])
    steps = model_data.get("step", translations["not_found"].format(name=translations["step"]))
    sr = model_data.get("sr", translations["not_found"].format(name=translations["sr"]))
    f0 = model_data.get("f0", translations["not_found"].format(name=translations["f0"]))
    version = model_data.get("version", translations["not_found"].format(name=translations["version"]))
    creation_date = model_data.get("creation_date", translations["not_found_create_time"])
    model_hash = model_data.get("model_hash", translations["not_found"].format(name="model_hash"))
    pitch_guidance = translations["trained_f0"] if f0 else translations["not_f0"]
    creation_date_str = prettify_date(creation_date) if creation_date else translations["not_found_create_time"]
    model_name = model_data.get("model_name", translations["unregistered"])
    model_author = model_data.get("author", translations["not_author"])
    vocoder = model_data.get("vocoder", "Default")
    rms_extract = model_data.get("energy", False)
    gr_info(translations["success"])
    return translations["model_info"].format(model_name=model_name, model_author=model_author, epochs=epochs, steps=steps, version=version, sr=sr, pitch_guidance=pitch_guidance, model_hash=model_hash, creation_date_str=creation_date_str, vocoder=vocoder, rms_extract=rms_extract)