"""Universal (trimap-free) image matting backed by a frozen TensorFlow graph."""
import cv2
import numpy as np
import tensorflow as tf
from PIL import Image

# The frozen-graph API used below (Session / GraphDef / gfile) was removed in
# TF 2.x, so fall back to the v1 compatibility layer when running under TF >= 2.
# NOTE(review): this is a lexicographic string compare — it happens to work for
# real TF version strings ('1.x' < '2.0' <= '2.x') but is not a general
# version comparison; confirm if unusual version strings are possible.
if tf.__version__ >= '2.0':
    tf = tf.compat.v1
class ImageUniversalMatting:
    """Universal (trimap-free) image matting using a frozen TF1 inference graph.

    Loads a serialized ``GraphDef`` from ``weight_path`` into a CPU-only
    session. Calling the instance with an image (``PIL.Image`` or
    ``np.ndarray``) runs preprocess -> forward -> postprocess and returns
    the network's RGBA output resized back to the input resolution.
    """

    def __init__(self, weight_path):
        """Build a CPU session and import the frozen graph at *weight_path*.

        Args:
            weight_path (str): path to a serialized (frozen) GraphDef .pb file.
        """
        super().__init__()
        # Force CPU execution; allow_soft_placement lets TF remap any ops
        # that were pinned to GPU devices when the graph was frozen.
        config = tf.ConfigProto(
            allow_soft_placement=True, device_count={'GPU': 0})
        config.gpu_options.allow_growth = True
        self._session = tf.Session(config=config)
        with self._session.as_default():
            print(f'loading model from {weight_path}')
            with tf.gfile.FastGFile(weight_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name='')
                # Tensor names are fixed by the exported graph.
                self.output = self._session.graph.get_tensor_by_name(
                    'output_png:0')
                self.input_name = 'input_image:0'
            print('load model done')
        # Freeze the graph so later code cannot accidentally mutate it.
        self._session.graph.finalize()

    def __call__(self, image):
        """Run the full matting pipeline on *image* and return the result."""
        output = self.preprocess(image)
        output = self.forward(output)
        output = self.postprocess(output)
        return output

    def resize_image(self, img, limit_side_len):
        """
        resize image to a size multiple of 32 which is required by the network
        args:
            img(array): array with shape [h, w, c]
            limit_side_len(int): upper bound for the longer image side
        return:
            resized img (array)
        """
        h, w, _ = img.shape
        # Limit the max side, keeping aspect ratio.
        if max(h, w) > limit_side_len:
            if h > w:
                ratio = float(limit_side_len) / h
            else:
                ratio = float(limit_side_len) / w
        else:
            ratio = 1.
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)
        # Snap each side to the nearest multiple of 32; clamp to at least 32
        # so that very small inputs cannot round down to a zero-sized dim,
        # which would make cv2.resize raise.
        resize_h = max(32, int(round(resize_h / 32) * 32))
        resize_w = max(32, int(round(resize_w / 32) * 32))
        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img

    @staticmethod
    def convert_to_ndarray(img):
        """Convert *img* (PIL.Image or np.ndarray) to an RGB ndarray.

        Fix: the original defined this without ``self`` yet it is invoked as
        ``self.convert_to_ndarray(...)``, which raised TypeError at runtime;
        ``@staticmethod`` keeps the declared signature while making both
        instance- and class-style calls work.
        """
        if isinstance(img, Image.Image):
            img = np.array(img.convert('RGB'))
        elif isinstance(img, np.ndarray):
            if len(img.shape) == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            img = img[:, :, ::-1]  # BGR -> RGB
        else:
            raise TypeError(f'input should be either PIL.Image,'
                            f' np.array, but got {type(img)}')
        return img

    def preprocess(self, input, limit_side_len=800):
        """Convert to float RGB and resize; remember the original size.

        Returns a dict with keys 'img', 'orig_h', 'orig_w'.
        """
        img = self.convert_to_ndarray(input)  # rgb input
        img = img.astype(float)
        orig_h, orig_w, _ = img.shape
        img = self.resize_image(img, limit_side_len)
        result = {'img': img, 'orig_h': orig_h, 'orig_w': orig_w}
        return result

    def forward(self, input):
        """Run the session on the preprocessed image and restore its size."""
        orig_h, orig_w = input['orig_h'], input['orig_w']
        with self._session.as_default():
            feed_dict = {self.input_name: input['img']}
            output_img = self._session.run(self.output, feed_dict=feed_dict)  # RGBA
            output_img = cv2.resize(output_img, (int(orig_w), int(orig_h)))
        return {"output_img": output_img}

    def postprocess(self, inputs):
        """Extract the final RGBA image from the forward-pass result dict."""
        return inputs["output_img"]