File size: 2,086 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union

import numpy as np
import torch

from ..utils import ext_loader

# Load the `_ext` extension through the project's loader, requesting the
# `contour_expand` op that the Python wrapper in this module dispatches to.
# NOTE(review): presumably a compiled C++/CUDA extension -- confirm against
# `ext_loader.load_ext`.
ext_module = ext_loader.load_ext('_ext', ['contour_expand'])


def contour_expand(kernel_mask: Union[np.ndarray, torch.Tensor],
                   internal_kernel_label: Union[np.ndarray, torch.Tensor],
                   min_kernel_area: int, kernel_num: int) -> list:
    """Expand kernel contours so that foreground pixels are assigned into
    instances.

    NumPy inputs are converted with ``torch.from_numpy`` before being handed
    to the extension op; tensors are passed through unchanged.

    Args:
        kernel_mask (np.ndarray or torch.Tensor): The instance kernel mask
            with size hxw.
        internal_kernel_label (np.ndarray or torch.Tensor): The instance
            internal kernel label with size hxw.
        min_kernel_area (int): The minimum kernel area.
        kernel_num (int): The instance kernel number.

    Returns:
        list: The instance index map with size hxw. On the parrots backend
        an empty list is returned when either input has a first dimension
        of size 0.

    Raises:
        AssertionError: If ``kernel_mask`` or ``internal_kernel_label`` is
            neither a ``torch.Tensor`` nor an ``np.ndarray``, or if
            ``min_kernel_area`` or ``kernel_num`` is not an ``int``.
    """
    assert isinstance(kernel_mask, (torch.Tensor, np.ndarray))
    assert isinstance(internal_kernel_label, (torch.Tensor, np.ndarray))
    assert isinstance(min_kernel_area, int)
    assert isinstance(kernel_num, int)

    # The extension op is always called with tensors, so wrap NumPy arrays.
    if isinstance(kernel_mask, np.ndarray):
        kernel_mask = torch.from_numpy(kernel_mask)
    if isinstance(internal_kernel_label, np.ndarray):
        internal_kernel_label = torch.from_numpy(internal_kernel_label)

    # Regular PyTorch build: positional call, result returned as-is.
    if torch.__version__ != 'parrots':
        return ext_module.contour_expand(kernel_mask, internal_kernel_label,
                                         min_kernel_area, kernel_num)

    # Parrots build: the op is skipped entirely for empty inputs.
    if kernel_mask.shape[0] == 0 or internal_kernel_label.shape[0] == 0:
        return []

    # Parrots takes the scalar options as keywords and its result is
    # converted with ``tolist()`` to match the declared return type.
    label = ext_module.contour_expand(
        kernel_mask,
        internal_kernel_label,
        min_kernel_area=min_kernel_area,
        kernel_num=kernel_num)
    return label.tolist()  # type: ignore