File size: 857 Bytes
0891b79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torchvision
import io
import numpy as np
import torch.onnx
import onnx
from tools import Inference, Matting, log, extract_matte, harmonize, css, execute_onnx_model
from omegaconf import OmegaConf
import os
import sys
import torch
import numpy as np
import torchvision.transforms.functional as tf
from PIL import Image
import cv2 as cv
from onnxruntime import InferenceSession

args = OmegaConf.load(os.path.join(f"./config/test.yaml"))

log("Model loading")
phnet = Inference(**args)
stylematte = Matting(**args)
log("Model loaded")
model = stylematte.model

x = torch.randn((1, 3, 720, 1280))
mask = torch.ones((1, 1, 512, 512))
path = 'checkpoints/stylematte-test.onnx'

# Export
torch.onnx.export(model, x, path, opset_version=16)

# Validation
onnx_model = onnx.load(path)
onnx.checker.check_model(onnx_model)
# execute_onnx_model(x, onnx_model)