|
import random |
|
import torch |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from scipy import ndimage |
|
|
|
|
|
class Point:
    """Point sampler that picks individual pixels from a 2-D mask.

    In training mode :meth:`draw` returns a single boolean map with a random
    number of pixels taken from inside the mask.  In evaluation mode it
    delegates to :meth:`draw_eval`, which returns a stack of cumulative maps
    holding positive (inside the mask) and negative (outside the mask) points.
    """

    def __init__(self, cfg, is_train=True):
        """Read the sampling limits from ``cfg``.

        Args:
            cfg: Nested mapping providing
                ``cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS']`` (upper bound
                on points drawn in training mode) and
                ``cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']`` (point budget
                used by :meth:`draw_eval`).
            is_train: When false, :meth:`draw` dispatches to
                :meth:`draw_eval`.
        """
        sampler_cfg = cfg['STROKE_SAMPLER']
        self.max_points = sampler_cfg['POINT']['NUM_POINTS']
        self.max_eval = sampler_cfg['EVAL']['MAX_ITER']
        self.is_train = is_train

    def draw(self, mask=None, box=None):
        """Sample points from ``mask``.

        Args:
            mask: 2-D tensor of shape ``(h, w)``; non-zero entries are the
                pixels that may be selected.  Must be boolean in evaluation
                mode because :meth:`draw_eval` inverts it with ``~``.
            box: Unused here; only forwarded to :meth:`draw_eval`.

        Returns:
            * An all-``False`` boolean tensor shaped like ``mask`` when the
              mask sums to fewer than 10 (in both training and evaluation
              mode).
            * Otherwise, in evaluation mode, the result of :meth:`draw_eval`.
            * Otherwise, a boolean ``(h, w)`` tensor with between 1 and
              ``min(self.max_points, mask.sum())`` selected pixels, all of
              them inside ``mask``.
        """
        foreground_total = mask.sum()
        if foreground_total < 10:
            # Mask is too small to sample from: hand back an empty map.
            return torch.zeros(mask.shape).bool()
        if not self.is_train:
            return self.draw_eval(mask=mask, box=box)

        # Never ask for more points than the mask contains.
        max_points = min(self.max_points, foreground_total.item())
        num_points = random.randint(1, max_points)

        h, w = mask.shape
        # reshape (rather than view) so non-contiguous masks are accepted.
        flat_mask = mask.reshape(-1)
        foreground_idx = flat_mask.nonzero()[:, 0]
        # A random permutation truncated to num_points samples without
        # replacement.
        chosen = foreground_idx[torch.randperm(len(foreground_idx))[:num_points]]

        rand_mask = torch.zeros(flat_mask.shape).bool()
        rand_mask[chosen] = True
        return rand_mask.reshape(h, w)

    def draw_eval(self, mask=None, box=None):
        """Build a cumulative sequence of positive and negative points.

        Up to ``self.max_eval // 2`` pixels are sampled outside the mask
        (label ``-1``) and up to ``self.max_eval - neg_num + 1`` pixels inside
        it (label ``+1``), each limited by the number of pixels available.
        The first point of the sequence is always a positive one; the
        remaining points are shuffled.

        Args:
            mask: 2-D boolean tensor of shape ``(h, w)``.
            box: Unused.

        Returns:
            Float tensor of shape ``(num_points, h, w)``.  Slice ``i`` holds
            the labels of the first ``i + 1`` points and zeros elsewhere.
        """
        background = ~mask
        neg_num = min(self.max_eval // 2, background.sum().item())
        # The "- 1 ... + 1" keeps at least one positive point whenever the
        # mask is non-empty, even if the negatives use up the whole budget.
        pos_num = min(self.max_eval - neg_num, mask.sum().item() - 1) + 1

        h, w = mask.shape
        # reshape (rather than view) so non-contiguous masks are accepted.
        flat_mask = mask.reshape(-1)
        pos_pixels = flat_mask.nonzero()[:, 0]
        pos_pixels = pos_pixels[torch.randperm(len(pos_pixels))[:pos_num]]
        pos_labels = torch.ones(pos_pixels.shape)

        flat_background = background.reshape(-1)
        neg_pixels = flat_background.nonzero()[:, 0]
        neg_pixels = neg_pixels[torch.randperm(len(neg_pixels))[:neg_num]]
        neg_labels = -torch.ones(neg_pixels.shape)

        pixels = torch.cat([pos_pixels, neg_pixels])
        labels = torch.cat([pos_labels, neg_labels])

        # Keep element 0 (a positive point) in front and shuffle the rest.
        order = torch.cat([
            torch.zeros(1, dtype=torch.long),
            torch.randperm(len(pixels) - 1) + 1,
        ])
        pixels = pixels[order]
        labels = labels[order]

        # Row i of `clicks` carries only the i-th point.  The sampled pixels
        # are distinct, so a running sum over the rows gives, for every step,
        # the map containing all points drawn up to and including that step.
        num_steps = len(pixels)
        clicks = torch.zeros(num_steps, flat_mask.numel())
        clicks[torch.arange(num_steps), pixels] = labels
        return clicks.cumsum(dim=0).reshape(num_steps, h, w)

    def __repr__(self):
        """Return the short identifier of this sampler."""
        return 'point'