u2net_rgba / app.py
ichtestenurmal's picture
Duplicate from xiongjie/u2net_rgba
6d8a73e
raw
history blame
1.81 kB
import os
import copy
import time
import cv2 as cv
import numpy as np
import onnxruntime
from PIL import Image
import gradio
def run_inference(onnx_session, input_size, image):
    """Run an ONNX saliency model on one image and return a uint8 mask.

    The image is resized to a square ``input_size`` x ``input_size`` input,
    converted to RGB, and normalised with the per-channel mean/std below.
    It is then laid out as an NCHW float32 tensor and fed to the session's
    first input. The session's first output is squeezed and min-max scaled
    to the 0..255 range.

    Args:
        onnx_session: an ``onnxruntime.InferenceSession``; only its first
            input and first output are used.
        input_size: side length in pixels of the square model input.
        image: ``uint8`` image array in OpenCV channel order (BGR, or BGRA
            whose alpha is dropped), or a 2-D grayscale array.

    Returns:
        The squeezed first model output as a ``uint8`` array scaled to
        0..255 (for a U^2-Net export presumably shaped
        ``(input_size, input_size)`` -- confirm against the model). A
        constant prediction yields an all-zero array.
    """
    # Resize to the square model input. cv.resize returns a new array, so
    # the caller's image is never modified (no defensive copy needed).
    resized = cv.resize(image, dsize=(input_size, input_size))
    if resized.ndim == 2:
        # Grayscale input: replicate to three channels for the model.
        rgb = cv.cvtColor(resized, cv.COLOR_GRAY2RGB)
    else:
        rgb = cv.cvtColor(resized, cv.COLOR_BGR2RGB)

    # Preprocessing: scale to [0, 1], normalise per channel (the widely used
    # ImageNet statistics), then HWC -> NCHW float32.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    x = (rgb.astype(np.float32) / 255 - mean) / std
    x = x.transpose(2, 0, 1).astype(np.float32)
    x = x.reshape(-1, 3, input_size, input_size)

    # Inference on the first input / first output of the graph.
    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    outputs = onnx_session.run([output_name], {input_name: x})

    # Postprocessing: min-max scale to 0..255.
    pred = np.squeeze(np.asarray(outputs))
    lo = pred.min()
    hi = pred.max()
    if hi <= lo:
        # A constant prediction would make (hi - lo) zero, so scaling would
        # produce NaNs whose uint8 cast is undefined. Return an empty mask.
        return np.zeros(pred.shape, dtype=np.uint8)
    pred = (pred - lo) / (hi - lo) * 255
    return pred.astype(np.uint8)
# Load the U^2-Net ONNX model once at import time; create_rgba reuses this
# module-level session for every call.
# The path is resolved next to this file so the app also starts when it is
# launched from a different working directory. Execution providers are
# passed explicitly because onnxruntime >= 1.9 refuses to build a session
# without them on builds that ship more than one provider (e.g. GPU builds).
onnx_session = onnxruntime.InferenceSession(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "u2net.onnx"),
    providers=onnxruntime.get_available_providers(),
)
def create_rgba(mode, image, threshold=125):
    """Cut out the salient foreground of ``image`` as an RGBA picture.

    The module-level ``onnx_session`` is run through ``run_inference`` at a
    320x320 input. The resulting mask is resized back to the image size and
    attached as the alpha channel.

    Args:
        mode: ``"binary"`` for a hard 0/255 alpha mask; any other value
            (e.g. ``"smooth"``) keeps the graded 0..255 mask.
        image: ``uint8`` NumPy image in RGB channel order, shape (H, W, 3)
            -- presumably what ``gradio.inputs.Image()`` delivers; confirm.
        threshold: in binary mode, mask values ``>= threshold`` become fully
            opaque and everything below becomes fully transparent.

    Returns:
        ``PIL.Image.Image`` in RGBA mode with the same size as ``image``.
    """
    # run_inference expects OpenCV's BGR order (it swaps BGR -> RGB before
    # normalising), but this image is RGB. Convert first so the model sees
    # correctly ordered channels.
    bgr_image = cv.cvtColor(image, cv.COLOR_RGB2BGR)
    # 320 is the model input side used by this app (assumed to match the
    # exported u2net.onnx -- confirm).
    mask = run_inference(onnx_session, 320, bgr_image)

    # Scale the mask back to the original image size (dsize is (width, height)).
    height, width = image.shape[:2]
    mask = cv.resize(mask, dsize=(width, height))

    if mode == "binary":
        # Hard threshold so the alpha channel holds only 0 or 255.
        mask = np.where(mask >= threshold, 255, 0).astype(np.uint8)

    rgba_image = Image.fromarray(image).convert("RGBA")
    rgba_image.putalpha(Image.fromarray(mask))
    return rgba_image
# Build the web UI: a mask-mode selector plus an input image go in, and the
# RGBA cut-out produced by create_rgba comes out. launch() starts the server.
mode_selector = gradio.inputs.Radio(["binary", "smooth"])
image_input = gradio.inputs.Image()
inputs = [mode_selector, image_input]
outputs = gradio.outputs.Image()
demo = gradio.Interface(fn=create_rgba, inputs=inputs, outputs=outputs)
demo.launch()