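"""Universal image matting with a TensorFlow frozen graph.

The model consumes an RGB image and returns an RGBA result (the graph's
'output_png' tensor), resized back to the input resolution.
"""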
import cv2
import numpy as np
import tensorflow as tf
from PIL import Image

if tf.__version__ >= '2.0':
    # run the TF1-style graph/session code below under the compat.v1 API
    tf = tf.compat.v1
    tf.disable_eager_execution()


class ImageUniversalMatting:
    """TF1 frozen-graph wrapper for universal image matting.

    Accepts a PIL.Image or np.ndarray and returns an RGBA image whose alpha
    channel carries the predicted matte, at the original resolution.
    """

    def __init__(self, weight_path):
        # CPU-only inference; soft placement keeps graph loading robust
        config = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 0})
        config.gpu_options.allow_growth = True
        self._session = tf.Session(config=config)
        with self._session.as_default():
            print(f'loading model from {weight_path}')
            # load the frozen inference graph and look up its input/output
            # tensors by name
            with tf.gfile.FastGFile(weight_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name='')
                self.output = self._session.graph.get_tensor_by_name(
                    'output_png:0')
                self.input_name = 'input_image:0'
            print('load model done')
        # freeze the graph so no further ops can be added by accident
        self._session.graph.finalize()

    def __call__(self, image):
        output = self.preprocess(image)
        output = self.forward(output)
        output = self.postprocess(output)
        return output

    def resize_image(self, img, limit_side_len):
        """
        resize image so that the longer side is at most limit_side_len and
        both sides are multiples of 32, as required by the network
        args:
            img(array): array with shape [h, w, c]
            limit_side_len(int): upper bound for the longer side
        return:
            resized img
        """
        h, w, _ = img.shape

        # limit the max side
        if max(h, w) > limit_side_len:
            if h > w:
                ratio = float(limit_side_len) / h
            else:
                ratio = float(limit_side_len) / w
        else:
            ratio = 1.
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        # round each side to the nearest multiple of 32, but never below 32
        # so very small or very thin inputs do not collapse to a zero side
        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        img = cv2.resize(img, (resize_w, resize_h))

        return img

    @staticmethod
    def convert_to_ndarray(img):
        if isinstance(img, Image.Image):
            img = np.array(img.convert('RGB'))
        elif isinstance(img, np.ndarray):
            if len(img.shape) == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            # ndarray input is assumed to be BGR (OpenCV convention); convert to RGB
            img = img[:, :, ::-1]
        else:
            raise TypeError(f'input should be either PIL.Image or np.ndarray,'
                            f' but got {type(img)}')
        return img

    def preprocess(self, image, limit_side_len=800):
        img = self.convert_to_ndarray(image)  # rgb input
        img = img.astype(float)
        orig_h, orig_w, _ = img.shape
        img = self.resize_image(img, limit_side_len)
        result = {'img': img, 'orig_h': orig_h, 'orig_w': orig_w}
        return result

    def forward(self, inputs):
        orig_h, orig_w = inputs['orig_h'], inputs['orig_w']
        with self._session.as_default():
            feed_dict = {self.input_name: inputs['img']}
            output_img = self._session.run(self.output, feed_dict=feed_dict)  # RGBA
            # resize the predicted matte back to the original resolution
            output_img = cv2.resize(output_img, (int(orig_w), int(orig_h)))
        return {"output_img": output_img}

    def postprocess(self, inputs):
        return inputs["output_img"]
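

# --- Usage sketch (illustrative, not part of the original module) ----------
# The paths below are placeholders; the frozen graph is assumed to expose the
# 'input_image:0' / 'output_png:0' tensors used above and to emit a uint8
# RGBA image.
if __name__ == '__main__':
    import sys

    weight_path = sys.argv[1] if len(sys.argv) > 1 else 'tf_graph.pb'  # hypothetical .pb file
    image_path = sys.argv[2] if len(sys.argv) > 2 else 'input.jpg'     # hypothetical test image

    matting = ImageUniversalMatting(weight_path)
    rgba = matting(Image.open(image_path))  # H x W x 4, RGBA

    # OpenCV writes BGR(A), so swap channel order before saving
    cv2.imwrite('matting_result.png', cv2.cvtColor(rgba, cv2.COLOR_RGBA2BGRA))
    print('saved matting_result.png')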