# AnyStory / src/matting.py
# Junjie96's picture
# Upload 46 files
# 9c18e52 verified
import cv2
import numpy as np
import tensorflow as tf
from PIL import Image
# The code below uses the TF1 graph/session API (Session, GraphDef, gfile),
# which TensorFlow 2.x keeps under ``tf.compat.v1``.
# The major version is compared numerically: a plain string comparison is
# lexicographic and would order e.g. '10.0' before '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
class ImageUniversalMatting:
    """Run a frozen TensorFlow matting graph on a single image.

    The graph is loaded once in the constructor. Calling the instance runs
    ``preprocess`` -> ``forward`` -> ``postprocess`` and returns the array
    fetched from the ``output_png:0`` tensor, resized back to the height and
    width of the input image.
    """

    # Both sides of the network input are snapped to a multiple of this value.
    _SIZE_MULTIPLE = 32

    def __init__(self, weight_path):
        """Load the frozen graph stored at ``weight_path``.

        Args:
            weight_path: path of a serialized ``GraphDef`` file. The graph
                must contain the tensors ``input_image:0`` (fed) and
                ``output_png:0`` (fetched).
        """
        super().__init__()
        # device_count={'GPU': 0} asks TensorFlow for zero GPU devices.
        config = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 0})
        config.gpu_options.allow_growth = True

        # Import into a graph owned by this instance instead of the
        # process-wide default graph: the graph is finalized below, and
        # finalizing the default graph would block a second instance (or any
        # other graph construction in the process) from adding nodes.
        graph = tf.Graph()
        self._session = tf.Session(graph=graph, config=config)
        with graph.as_default(), self._session.as_default():
            print(f'loading model from {weight_path}')
            with tf.gfile.GFile(weight_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
            self.output = self._session.graph.get_tensor_by_name(
                'output_png:0')
            self.input_name = 'input_image:0'
            print('load model done')
        # Make the graph read-only so inference cannot grow it by accident.
        self._session.graph.finalize()

    def __call__(self, image):
        """Run the full pipeline on ``image`` and return the network output.

        Args:
            image: a ``PIL.Image.Image`` or a ``numpy.ndarray``.

        Returns:
            The array produced by ``forward``, at the input's original size.
        """
        output = self.preprocess(image)
        output = self.forward(output)
        output = self.postprocess(output)
        return output

    @staticmethod
    def _target_size(h, w, limit_side_len):
        """Return the ``(height, width)`` the network input is resized to.

        The longer side is scaled down to ``limit_side_len`` when it exceeds
        it (aspect ratio kept), then each side is rounded to the nearest
        multiple of 32. Each side is at least 32 so that very small images
        never produce a zero-sized target.
        """
        multiple = ImageUniversalMatting._SIZE_MULTIPLE
        longest = max(h, w)
        ratio = float(limit_side_len) / longest if longest > limit_side_len else 1.
        scaled_h = int(h * ratio)
        scaled_w = int(w * ratio)
        target_h = int(round(scaled_h / multiple) * multiple)
        target_w = int(round(scaled_w / multiple) * multiple)
        # Sides shorter than half a multiple round down to 0; clamp them.
        return max(target_h, multiple), max(target_w, multiple)

    def resize_image(self, img, limit_side_len):
        """Resize ``img`` to a size that is a multiple of 32 on both sides.

        Args:
            img: array with shape ``[h, w, c]``.
            limit_side_len: upper bound for the longer side before rounding
                to a multiple of 32.

        Returns:
            The resized image array.
        """
        h, w, _ = img.shape
        resize_h, resize_w = self._target_size(h, w, limit_side_len)
        # cv2.resize takes the target size as (width, height).
        return cv2.resize(img, (resize_w, resize_h))

    @staticmethod
    def convert_to_ndarray(img):
        """Convert the input to an RGB ``numpy.ndarray``.

        Args:
            img: a ``PIL.Image.Image`` (converted with ``convert('RGB')``) or
                a ``numpy.ndarray``. A 2-D array is treated as grayscale and
                expanded to three identical channels; any other array has its
                last axis reversed (BGR -> RGB).

        Returns:
            The image as an ``[h, w, c]`` array.

        Raises:
            TypeError: if ``img`` is neither a PIL image nor an ndarray.
        """
        if isinstance(img, np.ndarray):
            if img.ndim == 2:
                # Replicating the single channel works for every dtype and
                # yields the same values in BGR and RGB order.
                return np.repeat(img[:, :, np.newaxis], 3, axis=2)
            return img[:, :, ::-1]  # convert to rgb
        if isinstance(img, Image.Image):
            return np.array(img.convert('RGB'))
        raise TypeError(f'input should be either PIL.Image,'
                        f' np.array, but got {type(img)}')

    def preprocess(self, input, limit_side_len=800):
        """Prepare an image for the network.

        Args:
            input: a ``PIL.Image.Image`` or ``numpy.ndarray``.
            limit_side_len: maximum length of the longer side fed to the
                network (before rounding to a multiple of 32).

        Returns:
            A dict with the resized float image under ``'img'`` and the
            original size under ``'orig_h'`` / ``'orig_w'``.
        """
        img = self.convert_to_ndarray(input)  # rgb input
        img = img.astype(float)
        orig_h, orig_w, _ = img.shape
        img = self.resize_image(img, limit_side_len)
        return {'img': img, 'orig_h': orig_h, 'orig_w': orig_w}

    def forward(self, input):
        """Run the graph and resize its output back to the original size.

        Args:
            input: the dict returned by ``preprocess``.

        Returns:
            A dict whose ``'output_img'`` entry is the fetched output
            (RGBA according to the original author's note), resized to
            ``(orig_h, orig_w)``. No channel reordering is applied.
        """
        orig_h, orig_w = input['orig_h'], input['orig_w']
        with self._session.as_default():
            feed_dict = {self.input_name: input['img']}
            output_img = self._session.run(self.output, feed_dict=feed_dict)  # RGBA
            output_img = cv2.resize(output_img, (int(orig_w), int(orig_h)))
        return {"output_img": output_img}

    def postprocess(self, inputs):
        """Return the image stored under ``'output_img'`` in ``inputs``."""
        return inputs["output_img"]