import cv2
import numpy as np
import tensorflow as tf
from PIL import Image

if tf.__version__ >= '2.0':
    tf = tf.compat.v1


class ImageUniversalMatting:

    def __init__(self, weight_path):
        super().__init__()
        # Run on CPU only; allow TF to fall back to supported devices.
        config = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 0})
        config.gpu_options.allow_growth = True
        self._session = tf.Session(config=config)
        with self._session.as_default():
            print(f'loading model from {weight_path}')
            # Load the frozen graph and look up the input/output tensors by name.
            with tf.gfile.FastGFile(weight_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name='')
                self.output = self._session.graph.get_tensor_by_name('output_png:0')
                self.input_name = 'input_image:0'
                print('load model done')
        self._session.graph.finalize()

    def __call__(self, image):
        output = self.preprocess(image)
        output = self.forward(output)
        output = self.postprocess(output)
        return output

    def resize_image(self, img, limit_side_len):
        """Resize image so each side is a multiple of 32, as required by the network.

        Args:
            img (np.ndarray): array with shape [h, w, c]
            limit_side_len (int): maximum allowed length of the longer side

        Returns:
            np.ndarray: the resized image
        """
        h, w, _ = img.shape

        # Limit the max side.
        if max(h, w) > limit_side_len:
            if h > w:
                ratio = float(limit_side_len) / h
            else:
                ratio = float(limit_side_len) / w
        else:
            ratio = 1.

        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        # Round each side to the nearest multiple of 32.
        resize_h = int(round(resize_h / 32) * 32)
        resize_w = int(round(resize_w / 32) * 32)

        img = cv2.resize(img, (int(resize_w), int(resize_h)))
        return img

    @staticmethod
    def convert_to_ndarray(img):
        if isinstance(img, Image.Image):
            img = np.array(img.convert('RGB'))
        elif isinstance(img, np.ndarray):
            if len(img.shape) == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            img = img[:, :, ::-1]  # convert BGR to RGB
        else:
            raise TypeError(f'input should be either PIL.Image or np.ndarray,'
                            f' but got {type(img)}')
        return img

    def preprocess(self, input, limit_side_len=800):
        img = self.convert_to_ndarray(input)  # RGB input
        img = img.astype(float)
        orig_h, orig_w, _ = img.shape
        img = self.resize_image(img, limit_side_len)
        result = {'img': img, 'orig_h': orig_h, 'orig_w': orig_w}
        return result

    def forward(self, input):
        orig_h, orig_w = input['orig_h'], input['orig_w']
        with self._session.as_default():
            feed_dict = {self.input_name: input['img']}
            output_img = self._session.run(self.output, feed_dict=feed_dict)  # RGBA
            # output_img = cv2.cvtColor(output_img, cv2.COLOR_RGBA2BGRA)
            # Resize the matting result back to the original resolution.
            output_img = cv2.resize(output_img, (int(orig_w), int(orig_h)))
        return {'output_img': output_img}

    def postprocess(self, inputs):
        return inputs['output_img']
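

if __name__ == '__main__':
    # Usage sketch: the file names below are illustrative assumptions, not part
    # of the original module. The model is assumed to be a frozen TensorFlow
    # graph whose 'output_png:0' tensor yields an RGBA uint8 image.
    matting = ImageUniversalMatting('matting_model.pb')  # hypothetical weight path
    rgba = matting(Image.open('portrait.jpg'))           # RGBA array at original size

    # Save via PIL to preserve the RGBA channel order (cv2.imwrite expects BGRA).
    Image.fromarray(rgba.astype(np.uint8)).save('portrait_matted.png')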