File size: 4,027 Bytes
1e4a2ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
# Standard library
import os
import io
import sys
# Third-party: ONNX toolchain + PyTorch
import onnx
import json
import torch
import onnxsim
import warnings
# Make project-local packages importable when run from the repo root.
sys.path.append(os.getcwd())
from main.library.algorithm.synthesizers import SynthesizerONNX
# Silence noisy export-time warnings (e.g. tracer warnings from torch.onnx).
warnings.filterwarnings("ignore")
def onnx_exporter(input_path, output_path, is_half=False, device="cpu"):
    """Export a trained voice-conversion checkpoint (.pth) to a simplified ONNX model.

    Loads the checkpoint, rebuilds the synthesizer, traces it with dummy
    inputs, simplifies the graph with onnxsim, embeds the training metadata
    as an ONNX metadata property, and optionally converts weights to fp16.

    Args:
        input_path: Path to the PyTorch checkpoint to convert.
        output_path: Destination path for the exported ``.onnx`` file.
        is_half: When True, convert the graph to float16 (I/O kept float32).
        device: Device used to build the dummy tracing inputs.

    Returns:
        ``output_path`` on success, ``None`` if export/simplification fails.

    Raises:
        FileNotFoundError: If ``input_path`` does not exist.
    """
    if not os.path.isfile(input_path):
        raise FileNotFoundError(f"Checkpoint not found: {input_path}")
    cpt = torch.load(input_path, map_location="cpu", weights_only=True)
    # Sync the stored config's speaker count with the actual embedding table
    # size (the saved config can be stale relative to the weights).
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]

    # Checkpoint metadata (all optional; defaults mirror the training pipeline).
    model_name = cpt.get("model_name", None)
    model_author = cpt.get("author", None)
    epochs = cpt.get("epoch", None)
    steps = cpt.get("step", None)
    version = cpt.get("version", "v1")
    f0 = cpt.get("f0", 1)
    model_hash = cpt.get("model_hash", None)
    vocoder = cpt.get("vocoder", "Default")
    creation_date = cpt.get("creation_date", None)
    energy_use = cpt.get("energy", False)

    # v2 models consume 768-dim content features; v1 uses 256.
    text_enc_hidden_dim = 768 if version == "v2" else 256
    tgt_sr = cpt["config"][-1]  # target sample rate is the last config entry

    net_g = SynthesizerONNX(*cpt["config"], use_f0=f0, text_enc_hidden_dim=text_enc_hidden_dim, vocoder=vocoder, checkpointing=False, energy=energy_use)
    net_g.load_state_dict(cpt["weight"], strict=False)
    net_g.eval().to(device)
    net_g = net_g.half() if is_half else net_g.float()

    # Dummy inputs (fixed 200-frame sequence) used only to trace the graph;
    # the sequence axes are declared dynamic below.
    seq_len = 200
    phone = torch.rand(1, seq_len, text_enc_hidden_dim).to(device)
    phone_length = torch.tensor([seq_len]).long().to(device)
    ds = torch.LongTensor([0]).to(device)  # speaker id
    rnd = torch.rand(1, 192, seq_len).to(device)  # latent noise

    args = [phone, phone_length, ds, rnd]
    input_names = ["phone", "phone_lengths", "ds", "rnd"]
    dynamic_axes = {"phone": [1], "rnd": [2]}
    if f0:
        # Pitch-conditioned models take quantized pitch ids plus raw f0.
        args += [torch.randint(size=(1, seq_len), low=5, high=255).to(device), torch.rand(1, seq_len).to(device)]
        input_names += ["pitch", "pitchf"]
        dynamic_axes.update({"pitch": [1], "pitchf": [1]})
    if energy_use:
        args.append(torch.rand(1, seq_len).to(device))
        input_names.append("energy")
        dynamic_axes.update({"energy": [1]})

    try:
        # Export into memory first so onnxsim can rewrite the graph before
        # anything touches the output path.
        with io.BytesIO() as buffer:
            torch.onnx.export(
                net_g,
                tuple(args),
                buffer,
                do_constant_folding=True,
                opset_version=17,
                verbose=False,
                input_names=input_names,
                output_names=["audio"],
                dynamic_axes=dynamic_axes
            )
            raw_model = buffer.getvalue()
        model, _ = onnxsim.simplify(onnx.load_model_from_string(raw_model))
        # Embed the checkpoint metadata so downstream tools can identify the
        # model without the original .pth file.
        model.metadata_props.append(
            onnx.StringStringEntryProto(
                key="model_info",
                value=json.dumps(
                    {
                        "model_name": model_name,
                        "author": model_author,
                        "epoch": epochs,
                        "step": steps,
                        "version": version,
                        "sr": tgt_sr,
                        "f0": f0,
                        "model_hash": model_hash,
                        "creation_date": creation_date,
                        "vocoder": vocoder,
                        "text_enc_hidden_dim": text_enc_hidden_dim,
                        "energy": energy_use
                    }
                )
            )
        )
        if is_half:
            try:
                import onnxconverter_common
            except ImportError:
                # Best-effort on-demand install; list form avoids shell
                # injection and quoting issues (shell=False by default).
                import subprocess
                subprocess.run([sys.executable, "-m", "pip", "install", "onnxconverter_common"], check=False)
                import onnxconverter_common
            model = onnxconverter_common.convert_float_to_float16(model, keep_io_types=True)
        onnx.save(model, output_path)
        return output_path
    except Exception:
        # Export failures are reported but not raised; callers check for None.
        import traceback
        traceback.print_exc()
        return None