File size: 7,210 Bytes
1e4a2ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
import sys
import json
import onnx
import torch
import datetime

from collections import OrderedDict

sys.path.append(os.getcwd())

from main.app.core.ui import gr_info, gr_warning, gr_error
from main.library.algorithm.onnx_export import onnx_exporter
from main.app.variables import config, logger, translations, configs

def fushion_model_pth(name, pth_1, pth_2, ratio):
    """Fuse two .pth voice-model checkpoints by linear weight interpolation.

    Each shared weight becomes ratio * w1 + (1 - ratio) * w2, stored as
    float16. Metadata (config, sr, f0, version, vocoder, energy) is taken
    from the first checkpoint.

    Args:
        name: Output filename; ".pth" is appended when missing.
        pth_1: Path to the first .pth checkpoint.
        pth_2: Path to the second .pth checkpoint.
        ratio: Blend weight of model 1, expected in [0, 1].

    Returns:
        A two-element list: [status_message, saved_path_or_None].
    """
    if not name.endswith(".pth"): name = name + ".pth"

    if not pth_1 or not os.path.exists(pth_1) or not pth_1.endswith(".pth"):
        gr_warning(translations["provide_file"].format(filename=translations["model"] + " 1"))
        return [translations["provide_file"].format(filename=translations["model"] + " 1"), None]
    
    if not pth_2 or not os.path.exists(pth_2) or not pth_2.endswith(".pth"):
        gr_warning(translations["provide_file"].format(filename=translations["model"] + " 2"))
        return [translations["provide_file"].format(filename=translations["model"] + " 2"), None]
    
    def extract(ckpt):
        # Keep only inference weights; "enc_q" entries are training-only and
        # are dropped from the fused model.
        a = ckpt["model"]
        opt = OrderedDict()
        opt["weight"] = {}

        for key in a.keys():
            if "enc_q" in key: continue

            opt["weight"][key] = a[key]

        return opt
    
    try:
        ckpt1 = torch.load(pth_1, map_location="cpu", weights_only=True)
        ckpt2 = torch.load(pth_2, map_location="cpu", weights_only=True)

        # Models trained at different sample rates cannot be blended.
        if ckpt1["sr"] != ckpt2["sr"]: 
            gr_warning(translations["sr_not_same"])
            return [translations["sr_not_same"], None]

        # Metadata is inherited from the first checkpoint.
        cfg = ckpt1["config"]
        cfg_f0 = ckpt1["f0"]
        cfg_version = ckpt1["version"]
        cfg_sr = ckpt1["sr"]

        vocoder = ckpt1.get("vocoder", "Default")
        rms_extract = ckpt1.get("energy", False)

        # Normalize both checkpoints to a plain {key: tensor} weight dict,
        # whether they carry a raw "model" state dict or a saved "weight" dict.
        ckpt1 = extract(ckpt1) if "model" in ckpt1 else ckpt1["weight"]
        ckpt2 = extract(ckpt2) if "model" in ckpt2 else ckpt2["weight"]

        if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())): 
            gr_warning(translations["architectures_not_same"])
            return [translations["architectures_not_same"], None]
         
        gr_info(translations["start"].format(start=translations["fushion_model"]))

        opt = OrderedDict()
        opt["weight"] = {}

        for key in ckpt1.keys():
            if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape:
                # Speaker-embedding tables may differ in speaker count;
                # blend only the overlapping rows.
                min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0])
                opt["weight"][key] = (ratio * (ckpt1[key][:min_shape0].float()) + (1 - ratio) * (ckpt2[key][:min_shape0].float())).half()
            else: opt["weight"][key] = (ratio * (ckpt1[key].float()) + (1 - ratio) * (ckpt2[key].float())).half()

        opt["config"] = cfg
        opt["sr"] = cfg_sr
        opt["f0"] = cfg_f0
        opt["version"] = cfg_version
        opt["infos"] = translations["model_fushion_info"].format(name=name, pth_1=pth_1, pth_2=pth_2, ratio=ratio)
        opt["vocoder"] = vocoder
        opt["energy"] = rms_extract

        output_model = configs["weights_path"]
        os.makedirs(output_model, exist_ok=True)  # exist_ok=True makes a pre-check redundant

        output_path = os.path.join(output_model, name)
        torch.save(opt, output_path)

        gr_info(translations["success"])
        return [translations["success"], output_path]
    except Exception as e:
        gr_error(message=translations["error_occurred"].format(e=e))
        # BUG FIX: return the message string, not the raw Exception object,
        # to match every other [message, path] return in this module.
        return [str(e), None]

def fushion_model(name, path_1, path_2, ratio):
    """Validate inputs and dispatch model fusion (only .pth pairs supported).

    Args:
        name: Output model name; a ".onnx" suffix is rewritten to ".pth".
        path_1: Path to the first model file.
        path_2: Path to the second model file.
        ratio: Blend weight of model 1, forwarded to fushion_model_pth.

    Returns:
        A two-element list: [status_message, saved_path_or_None].
    """
    if not name:
        gr_warning(translations["provide_name_is_save"]) 
        return [translations["provide_name_is_save"], None]

    if path_1.endswith(".pth") and path_2.endswith(".pth"):
        return fushion_model_pth(name.replace(".onnx", ".pth"), path_1, path_2, ratio)

    gr_warning(translations["format_not_valid"])
    # BUG FIX: report the failure message like every other error path in this
    # module, instead of returning [None, None].
    return [translations["format_not_valid"], None]
    
def onnx_export(model_path):    
    """Export a .pth model checkpoint to ONNX next to the source file.

    Args:
        model_path: Path to the source .pth file; ".pth" is appended
            when missing.

    Returns:
        The exporter's output on success, otherwise the gr_warning /
        gr_error result.
    """
    # Guard against a falsy path before any string operations (the original
    # called .endswith first, which would raise on None).
    if not model_path:
        return gr_warning(translations["provide_file"].format(filename=translations["model"]))

    # BUG FIX: the original computed `model_path + ".pth"` and discarded the
    # result, so the suffix was never actually appended.
    if not model_path.endswith(".pth"): model_path = model_path + ".pth"

    if not os.path.exists(model_path):
        return gr_warning(translations["provide_file"].format(filename=translations["model"]))
    
    try:
        gr_info(translations["start_onnx_export"])
        output = onnx_exporter(model_path, model_path.replace(".pth", ".onnx"), is_half=config.is_half, device=config.device)

        gr_info(translations["success"])
        return output
    except Exception as e:
        # Match the error-reporting style used elsewhere in this module.
        return gr_error(message=translations["error_occurred"].format(e=e))
    
def model_info(path):
    """Read metadata from a .pth or .onnx model file and format it for display.

    Args:
        path: Path to an existing .pth or .onnx model file.

    Returns:
        The formatted translations["model_info"] string, or the gr_warning
        result when the path is invalid or carries no readable metadata.
    """
    if not path or not os.path.exists(path) or os.path.isdir(path) or not path.endswith((".pth", ".onnx")): return gr_warning(translations["provide_file"].format(filename=translations["model"]))
    
    def prettify_date(date_str):
        # Convert an ISO-8601 timestamp into "YYYY-MM-DD HH:MM:SS" for display.
        if date_str == translations["not_found_create_time"]: return None

        try:
            return datetime.datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%S.%f").strftime("%Y-%m-%d %H:%M:%S")
        except ValueError as e:
            logger.debug(e)
            return translations["format_not_valid"]
    
    if path.endswith(".pth"):
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — only open model files from trusted sources.
        model_data = torch.load(path, map_location=torch.device("cpu"))
    else:        
        model = onnx.load(path)
        model_data = None

        # Model metadata is stored as a JSON blob under the "model_info" key.
        for prop in model.metadata_props:
            if prop.key == "model_info":
                model_data = json.loads(prop.value)
                break

    # BUG FIX: the original dereferenced model_data even when the ONNX file
    # carried no "model_info" metadata, crashing with AttributeError.
    if model_data is None:
        return gr_warning(translations["provide_file"].format(filename=translations["model"]))

    gr_info(translations["read_info"])

    epochs = model_data.get("epoch", None)
    if epochs is None:
        # Older checkpoints store the epoch count inside "info" (e.g. "100epoch").
        info = model_data.get("info", None)
        # BUG FIX: the original re-tested `epochs is None` after assigning it
        # from "info", so the not-found fallback was unreachable.
        if isinstance(info, str) and info.replace("epoch", "").replace("e", "").isdigit():
            epochs = info
        else:
            epochs = translations["not_found"].format(name=translations["epoch"])

    steps = model_data.get("step", translations["not_found"].format(name=translations["step"]))
    sr = model_data.get("sr", translations["not_found"].format(name=translations["sr"]))
    f0 = model_data.get("f0", translations["not_found"].format(name=translations["f0"]))
    version = model_data.get("version", translations["not_found"].format(name=translations["version"]))
    creation_date = model_data.get("creation_date", translations["not_found_create_time"])
    model_hash = model_data.get("model_hash", translations["not_found"].format(name="model_hash"))
    pitch_guidance = translations["trained_f0"] if f0 else translations["not_f0"]
    creation_date_str = prettify_date(creation_date) if creation_date else translations["not_found_create_time"]
    model_name = model_data.get("model_name", translations["unregistered"])
    model_author = model_data.get("author", translations["not_author"])
    vocoder = model_data.get("vocoder", "Default")
    rms_extract = model_data.get("energy", False)

    gr_info(translations["success"])
    return translations["model_info"].format(model_name=model_name, model_author=model_author, epochs=epochs, steps=steps, version=version, sr=sr, pitch_guidance=pitch_guidance, model_hash=model_hash, creation_date_str=creation_date_str, vocoder=vocoder, rms_extract=rms_extract)