File size: 3,586 Bytes
814a594 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import random
import torch
import torch.nn.functional as F
import numpy as np
from scipy import ndimage
class Point:
    """Sample point prompts from a 2-D binary mask.

    In training mode a single boolean map with a random number of
    foreground pixels is produced.  In evaluation mode a stack of signed
    maps is produced, where each map adds one more point to the previous
    one (``+1`` for a pixel inside the mask, ``-1`` for a pixel outside).
    """

    # Masks with fewer foreground pixels than this yield an all-False map.
    MIN_MASK_AREA = 10

    def __init__(self, cfg, is_train=True):
        """Read the sampling limits from ``cfg``.

        Args:
            cfg: Nested mapping providing
                ``cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS']`` and
                ``cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']``.
            is_train: When true ``draw`` samples a single training map,
                otherwise it delegates to ``draw_eval``.
        """
        self.max_points = cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS']
        self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
        self.is_train = is_train

    def draw(self, mask=None, box=None) -> torch.Tensor:
        """Sample point prompts for ``mask``.

        Args:
            mask: 2-D tensor of shape ``(h, w)``; non-zero entries are
                foreground.  Evaluation mode applies ``~mask``, so a
                boolean tensor is expected there.
            box: Unused by this sampler; accepted so the call signature
                stays the same for callers that pass it.

        Returns:
            * An all-False boolean ``(h, w)`` tensor when the mask has
              fewer than ``MIN_MASK_AREA`` foreground pixels.  This check
              runs before the train/eval branch, so it applies to both
              modes.
            * In training mode, a boolean ``(h, w)`` tensor with between
              1 and ``max_points`` foreground pixels set to True.
            * In evaluation mode, the result of ``draw_eval``.
        """
        num_foreground = int(mask.sum().item())
        if num_foreground < self.MIN_MASK_AREA:
            return torch.zeros(mask.shape).bool()
        if not self.is_train:
            return self.draw_eval(mask=mask, box=box)

        # Never request more points than there are foreground pixels.
        max_points = min(self.max_points, num_foreground)
        num_points = random.randint(1, max_points)

        h, w = mask.shape
        # reshape (not view) so non-contiguous masks, e.g. a transposed
        # or sliced tensor, are accepted as well.
        flat_mask = mask.reshape(-1)
        foreground_idx = flat_mask.nonzero()[:, 0]
        chosen = torch.randperm(len(foreground_idx))[:num_points]

        point_map = torch.zeros(flat_mask.shape, dtype=torch.bool)
        point_map[foreground_idx[chosen]] = True  # mark the sampled pixels
        # Points are left as single pixels; no dilation is applied.
        return point_map.reshape(h, w)

    def draw_eval(self, mask=None, box=None) -> torch.Tensor:
        """Build a cumulative sequence of signed point maps for evaluation.

        Up to ``max_eval // 2`` background pixels and enough foreground
        pixels to reach ``max_eval + 1`` points in total (limited by how
        many pixels of each kind exist) are sampled.  The first point in
        the sequence is always a foreground pixel; the remaining points
        are shuffled.

        Args:
            mask: 2-D boolean tensor of shape ``(h, w)`` with at least
                one True entry.
            box: Unused by this sampler.

        Returns:
            Float tensor of shape ``(n, h, w)``.  Map ``i`` contains the
            first ``i + 1`` points: ``1`` at foreground points, ``-1`` at
            background points and ``0`` everywhere else.
        """
        background = ~mask
        neg_num = min(self.max_eval // 2, background.sum().item())
        # The "- 1 ... + 1" keeps at least one positive point even when
        # the remaining budget is zero.
        pos_num = min(self.max_eval - neg_num, mask.sum().item() - 1) + 1

        h, w = mask.shape
        # reshape (not view): ``~mask`` keeps the strides of ``mask``, so
        # both flattenings must tolerate non-contiguous input.
        flat_mask = mask.reshape(-1)
        flat_background = background.reshape(-1)

        pos_candidates = flat_mask.nonzero()[:, 0]
        pos_pixels = pos_candidates[torch.randperm(len(pos_candidates))[:pos_num]]
        pos_labels = torch.ones(pos_pixels.shape)

        neg_candidates = flat_background.nonzero()[:, 0]
        neg_pixels = neg_candidates[torch.randperm(len(neg_candidates))[:neg_num]]
        neg_labels = -torch.ones(neg_pixels.shape)

        pixels = torch.cat([pos_pixels, neg_pixels])
        labels = torch.cat([pos_labels, neg_labels])

        # Keep entry 0 (a positive point) first and shuffle the rest.
        order = torch.cat([
            torch.zeros(1, dtype=torch.long),
            torch.randperm(len(pixels) - 1) + 1,
        ])
        pixels = pixels[order]
        labels = labels[order]

        rand_masks = []
        for step in range(1, len(pixels) + 1):
            point_map = torch.zeros(flat_mask.shape)
            point_map[pixels[:step]] = labels[:step]  # first `step` points
            rand_masks.append(point_map.reshape(h, w))
        return torch.stack(rand_masks)

    def __repr__(self) -> str:
        return 'point'