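# Scraped page header (not code): author "hjc-owo", commit 966ae59 ("init repo"),
# file size 5.01 kB; "raw" and "history blame" were page navigation links.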
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from torchvision.utils import make_grid
from skimage.transform import resize
from .u2net import U2NET
def plot_attn_dino(attn, threshold_map, inputs, inds, output_path):
    """Save a two-row diagnostic figure of DINO attention maps.

    Layout, with ``n = attn.shape[0]`` and ``n + 2`` columns per row:

    * Top row: the input image with the points ``inds`` overlaid in red,
      the sum of all attention maps ("atn map sum"), then each ``attn[i]``.
    * Bottom row: ``threshold_map[-1]`` ("prob sum"), the sum of
      ``threshold_map[:-1]`` ("thresh sum"), then each ``threshold_map[i]``
      for ``i < n``.

    The figure is saved to ``output_path`` and closed. Only a single image
    is supported (not a batch).

    Args:
        attn: tensor whose first dimension indexes the attention maps; each
            ``attn[i]`` is converted with ``.numpy()`` and shown as an image.
        threshold_map: tensor indexed along dim 0 like ``attn``; presumably
            it holds ``n + 1`` maps with the extra one last — TODO confirm.
        inputs: tensor accepted by ``torchvision.utils.make_grid``.
        inds: point coordinates; column 0 is plotted as y (row) and column 1
            as x (column).
        output_path: destination passed to ``plt.savefig``.
    """
    n_maps = attn.shape[0]
    n_cols = n_maps + 2

    def _panel(position, image, title=None, **imshow_kwargs):
        # Draw one image panel at a 1-based slot of the 2 x n_cols grid.
        plt.subplot(2, n_cols, position)
        plt.imshow(image, **imshow_kwargs)
        if title is not None:
            plt.title(title)
        plt.axis("off")

    plt.figure(figsize=(10, 5))

    # Slot 1: the input image with the selected points on top.
    plt.subplot(2, n_cols, 1)
    grid = make_grid(inputs, normalize=True, pad_value=2)
    plt.imshow(np.transpose(grid.cpu().numpy(), (1, 2, 0)), interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    # Summary panels: slot 2 on the top row, first two slots of the bottom row.
    _panel(2, attn.sum(0).numpy(), "atn map sum", interpolation='nearest')
    _panel(n_cols + 1, threshold_map[-1].numpy(), "prob sum", interpolation='nearest')
    _panel(n_cols + 2, threshold_map[:-1].sum(0).numpy(), "thresh sum", interpolation='nearest')

    # Per-map panels fill the remaining slots of each row, side by side.
    for idx in range(n_maps):
        _panel(idx + 3, attn[idx].numpy())
        _panel(n_cols + idx + 3, threshold_map[idx].numpy())

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()
def plot_attn_clip(attn, threshold_map, inputs, inds, output_path):
    """Save a one-row diagnostic figure for CLIP-based saliency.

    Panels, left to right:

    1. The input image (``make_grid(inputs, normalize=True)``) with the
       points ``inds`` overlaid in red.
    2. ``attn`` shown with a fixed colour range of [0, 1].
    3. ``threshold_map`` min-max rescaled to [0, 1], with ``inds`` overlaid.
       A constant map is shown as all zeros rather than NaN.

    The figure is saved to ``output_path`` and closed. Only a single image
    is supported (not a batch).

    Args:
        attn: map accepted by ``plt.imshow``; presumably already in [0, 1]
            given the fixed ``vmin``/``vmax`` — TODO confirm.
        threshold_map: array or tensor supporting ``min()``, ``max()`` and
            arithmetic; shown as an image.
        inputs: tensor accepted by ``torchvision.utils.make_grid``.
        inds: point coordinates; column 0 is plotted as y (row) and column 1
            as x (column).
        output_path: destination passed to ``plt.savefig``.
    """
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 3, 1)
    main_im = make_grid(inputs, normalize=True, pad_value=2)
    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
    plt.imshow(main_im, interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
    plt.title("attn map")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    lo, hi = threshold_map.min(), threshold_map.max()
    shifted = threshold_map - lo
    # Min-max rescale to [0, 1]. A constant map would give 0/0, which is all
    # NaN and renders as a blank panel, so keep the (all-zero) shifted map.
    threshold_map_ = shifted / (hi - lo) if hi > lo else shifted
    plt.imshow(threshold_map_, interpolation='nearest', vmin=0, vmax=1)
    plt.title("prob softmax")
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.axis("off")

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()
def plot_attn(attn, threshold_map, inputs, inds, output_path, saliency_model):
    """Dispatch to the attention plot matching ``saliency_model``.

    ``"dino"`` goes to ``plot_attn_dino`` and ``"clip"`` to
    ``plot_attn_clip``. Any other value draws nothing and returns ``None``.
    """
    plot_args = (attn, threshold_map, inputs, inds, output_path)
    if saliency_model == "dino":
        plot_attn_dino(*plot_args)
    elif saliency_model == "clip":
        plot_attn_clip(*plot_args)
def fix_image_scale(im):
    """Centre ``im`` on a square white canvas with a margin.

    The canvas side is ``max(height, width) + 20`` pixels. The longer image
    side therefore gets about 10 px of white padding at each end, and the
    shorter side is padded to the same length.

    Args:
        im: input PIL image, or an array accepted by ``np.array``. PIL images
            whose pixels are not 3-channel (e.g. modes "L", "P", "RGBA") are
            converted to RGB first. A 2-D array is treated as grayscale and
            replicated over three channels.

    Returns:
        A new RGB PIL image of size ``(max_len, max_len)``.
    """
    im_np = np.array(im)
    if im_np.ndim != 3 or im_np.shape[2] != 3:
        # The canvas is (max_len, max_len, 3); other channel layouts would
        # fail the broadcasted assignment below.
        if isinstance(im, Image.Image):
            im_np = np.array(im.convert("RGB"))
        elif im_np.ndim == 2:
            im_np = np.stack([im_np] * 3, axis=-1)
    im_np = im_np / 255

    height, width = im_np.shape[0], im_np.shape[1]
    max_len = max(height, width) + 20
    canvas = np.ones((max_len, max_len, 3))
    top = max_len // 2 - height // 2
    left = max_len // 2 - width // 2
    canvas[top:top + height, left:left + width] = im_np

    # Rescale to [0, 255]. Round instead of truncating, so the float
    # round-trip (x / 255 * 255) cannot turn a value that lands just below an
    # integer into the next integer down.
    canvas = np.rint(canvas / canvas.max() * 255).astype(np.uint8)
    return Image.fromarray(canvas)
def get_mask_u2net(pil_im, output_dir, u2net_path, device="cpu"):
    """Mask out the background of ``pil_im`` using a U^2-Net model.

    Steps:

    1. Resize the image so its shorter side is ``min(320, shorter side)``,
       convert it to a tensor and normalise it.
    2. Load ``U2NET(in_ch=3, out_ch=1)`` weights from ``u2net_path`` and take
       the first network output.
    3. Min-max normalise that output and binarise it at 0.5 (``predict``).
    4. Resize the binary mask back to the original ``(h, w)``, re-binarise
       it, and save it as ``output_dir / "mask.png"``.
    5. Keep image pixels where the mask is 1 and paint the rest white.

    Args:
        pil_im: input PIL image. It is converted to RGB first, because the
            network is built with ``in_ch=3`` and the mask is applied per RGB
            channel.
        output_dir: directory for ``mask.png``; a ``pathlib.Path`` or ``str``.
        u2net_path: path to the U^2-Net ``state_dict`` checkpoint.
        device: torch device used for inference (default ``"cpu"``).

    Returns:
        A tuple ``(im_final, predict)``.

        * ``im_final``: the masked PIL image at the original size, with the
          background white.
        * ``predict``: the binarised 0/1 prediction at the network's input
          resolution, left on ``device``.
    """
    from pathlib import Path

    pil_im = pil_im.convert("RGB")
    w, h = pil_im.size

    # --- input preprocessing ---
    # NOTE(review): these mean/std values are CLIP's normalisation constants,
    # not the ImageNet ones used by the reference U^2-Net; confirm intended.
    data_transforms = transforms.Compose([
        transforms.Resize(min(320, min(w, h)), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                             std=(0.26862954, 0.26130258, 0.27577711)),
    ])
    input_im_trans = data_transforms(pil_im).unsqueeze(0).to(device)

    # --- load U^2-Net and run inference ---
    net = U2NET(in_ch=3, out_ch=1)
    try:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        net.load_state_dict(torch.load(u2net_path, map_location=device))
        net.to(device)
        net.eval()
        with torch.no_grad():
            d1 = net(input_im_trans)[0]  # first output; the remaining ones are unused
    finally:
        # Free the model as soon as inference is done, even if it failed.
        del net
        torch.cuda.empty_cache()

    # --- binarise the prediction ---
    pred = d1[:, 0, :, :]
    lo, hi = pred.min(), pred.max()
    if hi > lo:
        pred = (pred - lo) / (hi - lo)
    else:
        # A constant output would give 0/0; those NaNs fail both threshold
        # tests and corrupt the mask. Treat it as "no foreground" instead.
        pred = torch.zeros_like(pred)
    predict = (pred >= 0.5).to(pred.dtype)

    # --- full-resolution mask, replicated over the three channels ---
    # Batch size is 1 (unsqueeze(0) above), so cat(dim=0) yields 3 channels.
    mask = torch.cat([predict, predict, predict], dim=0).permute(1, 2, 0)
    mask = resize(mask.cpu().numpy(), (h, w), anti_aliasing=False)
    # Interpolation in resize yields fractional values; snap back to 0/1.
    mask = (mask >= 0.5).astype(mask.dtype)

    mask_im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB')
    mask_im.save(Path(output_dir) / "mask.png")

    # --- composite: foreground kept, background painted white ---
    im_np = np.array(pil_im).astype(np.float64)
    peak = im_np.max()
    if peak > 0:  # an all-black image would otherwise give 0/0
        im_np = im_np / peak
    im_np = mask * im_np
    im_np[mask == 0] = 1
    peak = im_np.max()
    if peak > 0:
        im_np = im_np / peak
    im_final = Image.fromarray((im_np * 255).astype(np.uint8))
    return im_final, predict