Commit 0ce9d2a · 0 parent(s)
Duplicate from PaulHilders/CLIPGroundingExplainability
Co-authored-by: Paul Hilders <[email protected]>
- .gitattributes +27 -0
- .gitmodules +3 -0
- CLIP_explainability/Transformer-MM-Explainability +1 -0
- CLIP_explainability/utils.py +152 -0
- README.md +14 -0
- app.py +67 -0
- clip_grounding/datasets/png.py +231 -0
- clip_grounding/datasets/png_utils.py +135 -0
- clip_grounding/evaluation/clip_on_png.py +362 -0
- clip_grounding/evaluation/qualitative_results.py +93 -0
- clip_grounding/utils/image.py +46 -0
- clip_grounding/utils/io.py +116 -0
- clip_grounding/utils/log.py +57 -0
- clip_grounding/utils/paths.py +10 -0
- clip_grounding/utils/visualize.py +183 -0
- example_images/Amsterdam.png +0 -0
- example_images/London.png +0 -0
- example_images/dogs_on_bed.png +0 -0
- example_images/harrypotter.png +0 -0
- requirements.txt +121 -0
.gitattributes
ADDED
@@ -0,0 +1,27 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zstandard filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitmodules
ADDED
@@ -0,0 +1,3 @@
+[submodule "CLIP_explainability/Transformer-MM-Explainability"]
+    path = CLIP_explainability/Transformer-MM-Explainability
+    url = https://github.com/hila-chefer/Transformer-MM-Explainability.git
CLIP_explainability/Transformer-MM-Explainability
ADDED
@@ -0,0 +1 @@
+Subproject commit 6a2c3c9da3fc186878e0c2bcf238c3a4c76d8af8
CLIP_explainability/utils.py
ADDED
@@ -0,0 +1,152 @@
+import torch
+import CLIP.clip as clip
+from PIL import Image
+import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+from captum.attr import visualization
+import os
+
+
+from CLIP.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
+_tokenizer = _Tokenizer()
+
+#@title Control context expansion (number of attention layers to consider)
+#@title Number of layers for image Transformer
+start_layer = 11#@param {type:"number"}
+
+#@title Number of layers for text Transformer
+start_layer_text = 11#@param {type:"number"}
+
+
+def interpret(image, texts, model, device):
+    batch_size = texts.shape[0]
+    images = image.repeat(batch_size, 1, 1, 1)
+    logits_per_image, logits_per_text = model(images, texts)
+    probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()
+    index = [i for i in range(batch_size)]
+    one_hot = np.zeros((logits_per_image.shape[0], logits_per_image.shape[1]), dtype=np.float32)
+    one_hot[torch.arange(logits_per_image.shape[0]), index] = 1
+    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
+    one_hot = torch.sum(one_hot.to(device) * logits_per_image)
+    model.zero_grad()
+
+    image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
+    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
+    R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
+    R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
+    for i, blk in enumerate(image_attn_blocks):
+        if i < start_layer:
+            continue
+        grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
+        cam = blk.attn_probs.detach()
+        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
+        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
+        cam = grad * cam
+        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
+        cam = cam.clamp(min=0).mean(dim=1)
+        R = R + torch.bmm(cam, R)
+    image_relevance = R[:, 0, 1:]
+
+
+    text_attn_blocks = list(dict(model.transformer.resblocks.named_children()).values())
+    num_tokens = text_attn_blocks[0].attn_probs.shape[-1]
+    R_text = torch.eye(num_tokens, num_tokens, dtype=text_attn_blocks[0].attn_probs.dtype).to(device)
+    R_text = R_text.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
+    for i, blk in enumerate(text_attn_blocks):
+        if i < start_layer_text:
+            continue
+        grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
+        cam = blk.attn_probs.detach()
+        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
+        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
+        cam = grad * cam
+        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
+        cam = cam.clamp(min=0).mean(dim=1)
+        R_text = R_text + torch.bmm(cam, R_text)
+    text_relevance = R_text
+
+    return text_relevance, image_relevance
+
+
+def show_image_relevance(image_relevance, image, orig_image, device, show=True):
+    # create heatmap from mask on image
+    def show_cam_on_image(img, mask):
+        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+        heatmap = np.float32(heatmap) / 255
+        cam = heatmap + np.float32(img)
+        cam = cam / np.max(cam)
+        return cam
+
+    # plt.axis('off')
+    # f, axarr = plt.subplots(1,2)
+    # axarr[0].imshow(orig_image)
+
+    if show:
+        fig, axs = plt.subplots(1, 2)
+        axs[0].imshow(orig_image);
+        axs[0].axis('off');
+
+    image_relevance = image_relevance.reshape(1, 1, 7, 7)
+    image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
+    image_relevance = image_relevance.reshape(224, 224).to(device).data.cpu().numpy()
+    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
+    image = image[0].permute(1, 2, 0).data.cpu().numpy()
+    image = (image - image.min()) / (image.max() - image.min())
+    vis = show_cam_on_image(image, image_relevance)
+    vis = np.uint8(255 * vis)
+    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
+
+    if show:
+        # axar[1].imshow(vis)
+        axs[1].imshow(vis);
+        axs[1].axis('off');
+        # plt.imshow(vis)
+
+    return image_relevance
+
+
+def show_heatmap_on_text(text, text_encoding, R_text, show=True):
+    CLS_idx = text_encoding.argmax(dim=-1)
+    R_text = R_text[CLS_idx, 1:CLS_idx]
+    text_scores = R_text / R_text.sum()
+    text_scores = text_scores.flatten()
+    # print(text_scores)
+    text_tokens=_tokenizer.encode(text)
+    text_tokens_decoded=[_tokenizer.decode([a]) for a in text_tokens]
+    vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,text_tokens_decoded,1)]
+
+    if show:
+        visualization.visualize_text(vis_data_records)
+
+    return text_scores, text_tokens_decoded
+
+
+def show_img_heatmap(image_relevance, image, orig_image, device, show=True):
+    return show_image_relevance(image_relevance, image, orig_image, device, show=show)
+
+
+def show_txt_heatmap(text, text_encoding, R_text, show=True):
+    return show_heatmap_on_text(text, text_encoding, R_text, show=show)
+
+
+def load_dataset():
+    dataset_path = os.path.join('..', '..', 'dummy-data', '71226_segments' + '.pt')
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    data = torch.load(dataset_path, map_location=device)
+
+    return data
+
+
+class color:
+    PURPLE = '\033[95m'
+    CYAN = '\033[96m'
+    DARKCYAN = '\033[36m'
+    BLUE = '\033[94m'
+    GREEN = '\033[92m'
+    YELLOW = '\033[93m'
+    RED = '\033[91m'
+    BOLD = '\033[1m'
+    UNDERLINE = '\033[4m'
+    END = '\033[0m'
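A minimal sketch of how these helpers compose, assuming the Transformer-MM-Explainability submodule is checked out on sys.path and that the example image path below exists (app.py, further down, performs the same setup and additionally pins the CLIP download URLs in clip.clip._MODELS):

import sys
sys.path.append("CLIP_explainability/Transformer-MM-Explainability/")

import torch
import CLIP.clip as clip  # the submodule's CLIP, whose blocks expose attn_probs needed by interpret()
from PIL import Image
from CLIP_explainability.utils import interpret, show_img_heatmap, show_heatmap_on_text

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

caption = "London Eye"                              # hypothetical query
image = Image.open("example_images/London.png")     # hypothetical example path
img = preprocess(image).unsqueeze(0).to(device)
text_input = clip.tokenize([caption]).to(device)

# Relevance maps for both modalities (batch of one caption).
R_text, R_image = interpret(model=model, image=img, texts=text_input, device=device)
image_relevance = show_img_heatmap(R_image[0], img, orig_image=image, device=device, show=False)
text_scores, tokens = show_heatmap_on_text(caption, text_input, R_text[0], show=False)

text_scores lines up with the decoded BPE tokens in tokens, which is how app.py builds its highlighted-text output.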
README.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: CLIPGroundingExplainabilityDemo
+emoji: 💩
+colorFrom: yellow
+colorTo: yellow
+sdk: gradio
+sdk_version: 3.0.22
+app_file: app.py
+pinned: false
+license: afl-3.0
+duplicated_from: PaulHilders/CLIPGroundingExplainability
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,67 @@
+import sys
+import gradio as gr
+
+# sys.path.append("../")
+sys.path.append("CLIP_explainability/Transformer-MM-Explainability/")
+
+import torch
+import CLIP.clip as clip
+
+
+from clip_grounding.utils.image import pad_to_square
+from clip_grounding.datasets.png import (
+    overlay_relevance_map_on_image,
+)
+from CLIP_explainability.utils import interpret, show_img_heatmap, show_heatmap_on_text
+
+clip.clip._MODELS = {
+    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+}
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+
+# Gradio Section:
+def run_demo(image, text):
+    orig_image = pad_to_square(image)
+    img = preprocess(orig_image).unsqueeze(0).to(device)
+    text_input = clip.tokenize([text]).to(device)
+
+    R_text, R_image = interpret(model=model, image=img, texts=text_input, device=device)
+
+    image_relevance = show_img_heatmap(R_image[0], img, orig_image=orig_image, device=device, show=False)
+    overlapped = overlay_relevance_map_on_image(image, image_relevance)
+
+    text_scores, text_tokens_decoded = show_heatmap_on_text(text, text_input, R_text[0], show=False)
+
+    highlighted_text = []
+    for i, token in enumerate(text_tokens_decoded):
+        highlighted_text.append((str(token), float(text_scores[i])))
+
+    return overlapped, highlighted_text
+
+input_img = gr.inputs.Image(type='pil', label="Original Image")
+input_txt = "text"
+inputs = [input_img, input_txt]
+
+outputs = [gr.inputs.Image(type='pil', label="Output Image"), "highlight"]
+
+
+iface = gr.Interface(fn=run_demo,
+                     inputs=inputs,
+                     outputs=outputs,
+                     title="CLIP Grounding Explainability",
+                     description="A demonstration based on the Generic Attention-model Explainability method for Interpreting Bi-Modal Transformers by Chefer et al. (2021): https://github.com/hila-chefer/Transformer-MM-Explainability.",
+                     examples=[["example_images/London.png", "London Eye"],
+                               ["example_images/London.png", "Big Ben"],
+                               ["example_images/harrypotter.png", "Harry"],
+                               ["example_images/harrypotter.png", "Hermione"],
+                               ["example_images/harrypotter.png", "Ron"],
+                               ["example_images/Amsterdam.png", "Amsterdam canal"],
+                               ["example_images/Amsterdam.png", "Old buildings"],
+                               ["example_images/Amsterdam.png", "Pink flowers"],
+                               ["example_images/dogs_on_bed.png", "Two dogs"],
+                               ["example_images/dogs_on_bed.png", "Book"],
+                               ["example_images/dogs_on_bed.png", "Cat"]])
+iface.launch(debug=True)
clip_grounding/datasets/png.py
ADDED
@@ -0,0 +1,231 @@
+"""
+Dataset object for Panoptic Narrative Grounding.
+
+Paper: https://openaccess.thecvf.com/content/ICCV2021/papers/Gonzalez_Panoptic_Narrative_Grounding_ICCV_2021_paper.pdf
+"""
+
+import os
+from os.path import join, isdir, exists
+
+import torch
+from torch.utils.data import Dataset
+import cv2
+from PIL import Image
+from skimage import io
+import numpy as np
+import textwrap
+import matplotlib.pyplot as plt
+from matplotlib import transforms
+from imgaug.augmentables.segmaps import SegmentationMapsOnImage
+import matplotlib.colors as mc
+
+from clip_grounding.utils.io import load_json
+from clip_grounding.datasets.png_utils import show_image_and_caption
+
+
+class PNG(Dataset):
+    """Panoptic Narrative Grounding."""
+
+    def __init__(self, dataset_root, split) -> None:
+        """
+        Initializer.
+
+        Args:
+            dataset_root (str): path to the folder containing PNG dataset
+            split (str): MS-COCO split such as train2017/val2017
+        """
+        super().__init__()
+
+        assert isdir(dataset_root)
+        self.dataset_root = dataset_root
+
+        assert split in ["val2017"], f"Split {split} not supported. "\
+            "Currently, only supports split `val2017`."
+        self.split = split
+
+        self.ann_dir = join(self.dataset_root, "annotations")
+        # feat_dir = join(self.dataset_root, "features")
+
+        panoptic = load_json(join(self.ann_dir, "panoptic_{:s}.json".format(split)))
+        images = panoptic["images"]
+        self.images_info = {i["id"]: i for i in images}
+        panoptic_anns = panoptic["annotations"]
+        self.panoptic_anns = {int(a["image_id"]): a for a in panoptic_anns}
+
+        # self.panoptic_pred_path = join(
+        #     feat_dir, split, "panoptic_seg_predictions"
+        # )
+        # assert isdir(self.panoptic_pred_path)
+
+        panoptic_narratives_path = join(self.dataset_root, "annotations", f"png_coco_{split}.json")
+        self.panoptic_narratives = load_json(panoptic_narratives_path)
+
+    def __len__(self):
+        return len(self.panoptic_narratives)
+
+    def get_image_path(self, image_id: str):
+        image_path = join(self.dataset_root, "images", self.split, f"{image_id.zfill(12)}.jpg")
+        return image_path
+
+    def __getitem__(self, idx: int):
+        narr = self.panoptic_narratives[idx]
+
+        image_id = narr["image_id"]
+        image_path = self.get_image_path(image_id)
+        assert exists(image_path)
+
+        image = Image.open(image_path)
+        caption = narr["caption"]
+
+        # show_single_image(image, title=caption, titlesize=12)
+
+        segments = narr["segments"]
+
+        image_id = int(narr["image_id"])
+        panoptic_ann = self.panoptic_anns[image_id]
+        panoptic_ann = self.panoptic_anns[image_id]
+        segment_infos = {}
+        for s in panoptic_ann["segments_info"]:
+            idi = s["id"]
+            segment_infos[idi] = s
+
+        image_info = self.images_info[image_id]
+        panoptic_segm = io.imread(
+            join(
+                self.ann_dir,
+                "panoptic_segmentation",
+                self.split,
+                "{:012d}.png".format(image_id),
+            )
+        )
+        panoptic_segm = (
+            panoptic_segm[:, :, 0]
+            + panoptic_segm[:, :, 1] * 256
+            + panoptic_segm[:, :, 2] * 256 ** 2
+        )
+
+        panoptic_ann = self.panoptic_anns[image_id]
+        # panoptic_pred = io.imread(
+        #     join(self.panoptic_pred_path, "{:012d}.png".format(image_id))
+        # )[:, :, 0]
+
+
+        # # select a single utterance to visualize
+        # segment = segments[7]
+        # segment_ids = segment["segment_ids"]
+        # segment_mask = np.zeros((image_info["height"], image_info["width"]))
+        # for segment_id in segment_ids:
+        #     segment_id = int(segment_id)
+        #     segment_mask[panoptic_segm == segment_id] = 1.
+
+        utterances = [s["utterance"] for s in segments]
+        outputs = []
+        for i, segment in enumerate(segments):
+
+            # create segmentation mask on image
+            segment_ids = segment["segment_ids"]
+
+            # if no annotation for this word, skip
+            if not len(segment_ids):
+                continue
+
+            segment_mask = np.zeros((image_info["height"], image_info["width"]))
+            for segment_id in segment_ids:
+                segment_id = int(segment_id)
+                segment_mask[panoptic_segm == segment_id] = 1.
+
+            # store the outputs
+            text_mask = np.zeros(len(utterances))
+            text_mask[i] = 1.
+            segment_data = dict(
+                image=image,
+                text=utterances,
+                image_mask=segment_mask,
+                text_mask=text_mask,
+                full_caption=caption,
+            )
+            outputs.append(segment_data)
+
+            # # visualize segmentation mask with associated text
+            # segment_color = "red"
+            # segmap = SegmentationMapsOnImage(
+            #     segment_mask.astype(np.uint8), shape=segment_mask.shape,
+            # )
+            # image_with_segmap = segmap.draw_on_image(np.asarray(image), colors=[0, COLORS[segment_color]])[0]
+            # image_with_segmap = Image.fromarray(image_with_segmap)
+
+            # colors = ["black" for _ in range(len(utterances))]
+            # colors[i] = segment_color
+            # show_image_and_caption(image_with_segmap, utterances, colors)
+
+        return outputs
+
+
+def overlay_segmask_on_image(image, image_mask, segment_color="red"):
+    segmap = SegmentationMapsOnImage(
+        image_mask.astype(np.uint8), shape=image_mask.shape,
+    )
+    rgb_color = mc.to_rgb(segment_color)
+    rgb_color = 255 * np.array(rgb_color)
+    image_with_segmap = segmap.draw_on_image(np.asarray(image), colors=[0, rgb_color])[0]
+    image_with_segmap = Image.fromarray(image_with_segmap)
+    return image_with_segmap
+
+
+def get_text_colors(text, text_mask, segment_color="red"):
+    colors = ["black" for _ in range(len(text))]
+    colors[text_mask.nonzero()[0][0]] = segment_color
+    return colors
+
+
+def overlay_relevance_map_on_image(image, heatmap):
+    width, height = image.size
+
+    # resize the heatmap to image size
+    heatmap = cv2.resize(heatmap, (width, height))
+    heatmap = np.uint8(255 * heatmap)
+    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
+    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+
+    # create overlapped super image
+    img = np.asarray(image)
+    super_img = heatmap * 0.4 + img * 0.6
+    super_img = np.uint8(super_img)
+    super_img = Image.fromarray(super_img)
+
+    return super_img
+
+
+def visualize_item(image, text, image_mask, text_mask, segment_color="red"):
+
+    segmap = SegmentationMapsOnImage(
+        image_mask.astype(np.uint8), shape=image_mask.shape,
+    )
+    rgb_color = mc.to_rgb(segment_color)
+    rgb_color = 255 * np.array(rgb_color)
+    image_with_segmap = segmap.draw_on_image(np.asarray(image), colors=[0, rgb_color])[0]
+    image_with_segmap = Image.fromarray(image_with_segmap)
+
+    colors = ["black" for _ in range(len(text))]
+
+    text_idx = text_mask.argmax()
+    colors[text_idx] = segment_color
+    show_image_and_caption(image_with_segmap, text, colors)
+
+
+
+if __name__ == "__main__":
+    from clip_grounding.utils.paths import REPO_PATH, DATASET_ROOTS
+
+    PNG_ROOT = DATASET_ROOTS["PNG"]
+    dataset = PNG(dataset_root=PNG_ROOT, split="val2017")
+
+    item = dataset[0]
+    sub_item = item[1]
+    visualize_item(
+        image=sub_item["image"],
+        text=sub_item["text"],
+        image_mask=sub_item["image_mask"],
+        text_mask=sub_item["text_mask"],
+        segment_color="red",
+    )
+
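A short sketch of consuming the dataset above, assuming the PNG annotations have been placed under data/panoptic_narrative_grounding inside the repo (the layout the evaluation script further down expects); one dataset index yields one entry per annotated phrase:

from os.path import join
from clip_grounding.datasets.png import PNG, overlay_segmask_on_image
from clip_grounding.utils.paths import REPO_PATH

dataset = PNG(dataset_root=join(REPO_PATH, "data", "panoptic_narrative_grounding"), split="val2017")

entries = dataset[0]                 # list of per-phrase dicts for the first narrative
entry = entries[0]
print(entry["full_caption"])         # the whole narrative caption
print(entry["text"])                 # its phrases; entry["text_mask"] marks the grounded one

# Overlay the phrase's panoptic segment on the image (assumes an RGB image).
overlay = overlay_segmask_on_image(entry["image"], entry["image_mask"])
overlay.save("segment_overlay.png")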
clip_grounding/datasets/png_utils.py
ADDED
@@ -0,0 +1,135 @@
+"""Helper functions for Panoptic Narrative Grounding."""
+
+import os
+from os.path import join, isdir, exists
+from typing import List
+
+import torch
+from PIL import Image
+from skimage import io
+import numpy as np
+import textwrap
+import matplotlib.pyplot as plt
+from matplotlib import transforms
+from imgaug.augmentables.segmaps import SegmentationMapsOnImage
+
+
+def rainbow_text(x,y,ls,lc,fig, ax,**kw):
+    """
+    Take a list of strings ``ls`` and colors ``lc`` and place them next to each
+    other, with text ls[i] being shown in color lc[i].
+
+    Ref: https://stackoverflow.com/questions/9169052/partial-coloring-of-text-in-matplotlib
+    """
+    t = ax.transAxes
+
+    for s,c in zip(ls,lc):
+
+        text = ax.text(x,y,s+" ",color=c, transform=t, **kw)
+        text.draw(fig.canvas.get_renderer())
+        ex = text.get_window_extent()
+        t = transforms.offset_copy(text._transform, x=ex.width, units='dots')
+
+
+def find_first_index_greater_than(elements, key):
+    return next(x[0] for x in enumerate(elements) if x[1] > key)
+
+
+def split_caption_phrases(caption_phrases, colors, max_char_in_a_line=50):
+    char_lengths = np.cumsum([len(x) for x in caption_phrases])
+    thresholds = [max_char_in_a_line * i for i in range(1, 1 + char_lengths[-1] // max_char_in_a_line)]
+
+    utt_per_line = []
+    col_per_line = []
+    start_index = 0
+    for t in thresholds:
+        index = find_first_index_greater_than(char_lengths, t)
+        utt_per_line.append(caption_phrases[start_index:index])
+        col_per_line.append(colors[start_index:index])
+        start_index = index
+
+    return utt_per_line, col_per_line
+
+
+def show_image_and_caption(image: Image, caption_phrases: list, colors: list = None):
+
+    if colors is None:
+        colors = ["black" for _ in range(len(caption_phrases))]
+
+    fig, axes = plt.subplots(1, 2, figsize=(15, 4))
+
+    ax = axes[0]
+    ax.imshow(image)
+    ax.set_xticks([])
+    ax.set_yticks([])
+
+    ax = axes[1]
+    utt_per_line, col_per_line = split_caption_phrases(caption_phrases, colors, max_char_in_a_line=50)
+    y = 0.7
+    for U, C in zip(utt_per_line, col_per_line):
+        rainbow_text(
+            0., y,
+            U,
+            C,
+            size=15, ax=ax, fig=fig,
+            horizontalalignment='left',
+            verticalalignment='center',
+        )
+        y -= 0.11
+
+    ax.axis("off")
+
+    fig.tight_layout()
+    plt.show()
+
+
+def show_images_and_caption(
+        images: List,
+        caption_phrases: list,
+        colors: list = None,
+        image_xlabels: List=[],
+        figsize=None,
+        show=False,
+        xlabelsize=14,
+    ):
+
+    if colors is None:
+        colors = ["black" for _ in range(len(caption_phrases))]
+    caption_phrases[0] = caption_phrases[0].capitalize()
+
+    if figsize is None:
+        figsize = (5 * len(images) + 8, 4)
+
+    if image_xlabels is None:
+        image_xlabels = ["" for _ in range(len(images))]
+
+    fig, axes = plt.subplots(1, len(images) + 1, figsize=figsize)
+
+    for i, image in enumerate(images):
+        ax = axes[i]
+        ax.imshow(image)
+        ax.set_xticks([])
+        ax.set_yticks([])
+        ax.set_xlabel(image_xlabels[i], fontsize=xlabelsize)
+
+    ax = axes[-1]
+    utt_per_line, col_per_line = split_caption_phrases(caption_phrases, colors, max_char_in_a_line=40)
+    y = 0.7
+    for U, C in zip(utt_per_line, col_per_line):
+        rainbow_text(
+            0., y,
+            U,
+            C,
+            size=23, ax=ax, fig=fig,
+            horizontalalignment='left',
+            verticalalignment='center',
+            # weight='bold'
+        )
+        y -= 0.11
+
+    ax.axis("off")
+
+    fig.tight_layout()
+
+    if show:
+        plt.show()
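For reference, a small sketch of the caption-rendering helper above, using a hypothetical example image and phrase list with one phrase highlighted:

from PIL import Image
from clip_grounding.datasets.png_utils import show_image_and_caption

image = Image.open("example_images/Amsterdam.png")           # hypothetical example path
phrases = ["A canal in Amsterdam", "lined with", "old buildings."]
colors = ["black", "black", "red"]                           # colour only the grounded phrase
show_image_and_caption(image, phrases, colors)               # opens a matplotlib figure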
clip_grounding/evaluation/clip_on_png.py
ADDED
@@ -0,0 +1,362 @@
+"""Evaluates cross-modal correspondence of CLIP on PNG images."""
+
+import os
+import sys
+from os.path import join, exists
+
+import warnings
+warnings.filterwarnings('ignore')
+
+from clip_grounding.utils.paths import REPO_PATH
+sys.path.append(join(REPO_PATH, "CLIP_explainability/Transformer-MM-Explainability/"))
+
+import torch
+import CLIP.clip as clip
+from PIL import Image
+import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+from captum.attr import visualization
+from torchmetrics import JaccardIndex
+from collections import defaultdict
+from IPython.core.display import display, HTML
+from skimage import filters
+
+from CLIP_explainability.utils import interpret, show_img_heatmap, show_txt_heatmap, color, _tokenizer
+from clip_grounding.datasets.png import PNG
+from clip_grounding.utils.image import pad_to_square
+from clip_grounding.utils.visualize import show_grid_of_images
+from clip_grounding.utils.log import tqdm_iterator, print_update
+
+
+# global usage
+# specify device
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# load CLIP model
+model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+
+
+def show_cam(mask):
+    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
+    heatmap = np.float32(heatmap) / 255
+    cam = heatmap
+    cam = cam / np.max(cam)
+    return cam
+
+
+def interpret_and_generate(model, img, texts, orig_image, return_outputs=False, show=True):
+    text = clip.tokenize(texts).to(device)
+    R_text, R_image = interpret(model=model, image=img, texts=text, device=device)
+    batch_size = text.shape[0]
+
+    outputs = []
+    for i in range(batch_size):
+        text_scores, text_tokens_decoded = show_txt_heatmap(texts[i], text[i], R_text[i], show=show)
+        image_relevance = show_img_heatmap(R_image[i], img, orig_image=orig_image, device=device, show=show)
+        plt.show()
+        outputs.append({"text_scores": text_scores, "image_relevance": image_relevance, "tokens_decoded": text_tokens_decoded})
+
+    if return_outputs:
+        return outputs
+
+
+def process_entry_text_to_image(entry, unimodal=False):
+    image = entry['image']
+    text_mask = entry['text_mask']
+    text = entry['text']
+    orig_image = pad_to_square(image)
+
+    img = preprocess(orig_image).unsqueeze(0).to(device)
+    text_index = text_mask.argmax()
+    texts = [text[text_index]] if not unimodal else ['']
+
+    return img, texts, orig_image
+
+
+def preprocess_ground_truth_mask(mask, resize_shape):
+    mask = Image.fromarray(mask.astype(np.uint8) * 255)
+    mask = pad_to_square(mask, color=0)
+    mask = mask.resize(resize_shape)
+    mask = np.asarray(mask) / 255.
+    return mask
+
+
+def apply_otsu_threshold(relevance_map):
+    threshold = filters.threshold_otsu(relevance_map)
+    otsu_map = (relevance_map > threshold).astype(np.uint8)
+    return otsu_map
+
+
+def evaluate_text_to_image(method, dataset, debug=False):
+
+    instance_level_metrics = defaultdict(list)
+    entry_level_metrics = defaultdict(list)
+
+    jaccard = JaccardIndex(num_classes=2)
+    jaccard = jaccard.to(device)
+
+    num_iter = len(dataset)
+    if debug:
+        num_iter = 100
+
+    iterator = tqdm_iterator(range(num_iter), desc=f"Evaluating on {type(dataset).__name__} dataset")
+    for idx in iterator:
+        instance = dataset[idx]
+
+        instance_iou = 0.
+        for entry in instance:
+
+            # preprocess the image and text
+            unimodal = True if method == "clip-unimodal" else False
+            test_img, test_texts, orig_image = process_entry_text_to_image(entry, unimodal=unimodal)
+
+            if method in ["clip", "clip-unimodal"]:
+
+                # compute the relevance scores
+                outputs = interpret_and_generate(model, test_img, test_texts, orig_image, return_outputs=True, show=False)
+
+                # use the image relevance score to compute IoU w.r.t. ground truth segmentation masks
+
+                # NOTE: since we pass single entry (1-sized batch), outputs[0] contains our reqd outputs
+                relevance_map = outputs[0]["image_relevance"]
+            elif method == "random":
+                relevance_map = np.random.uniform(low=0., high=1., size=tuple(test_img.shape[2:]))
+
+            otsu_relevance_map = apply_otsu_threshold(relevance_map)
+
+            ground_truth_mask = entry["image_mask"]
+            ground_truth_mask = preprocess_ground_truth_mask(ground_truth_mask, relevance_map.shape)
+
+            entry_iou = jaccard(
+                torch.from_numpy(otsu_relevance_map).to(device),
+                torch.from_numpy(ground_truth_mask.astype(np.uint8)).to(device),
+            )
+            entry_iou = entry_iou.item()
+            instance_iou += (entry_iou / len(entry))
+
+            entry_level_metrics["iou"].append(entry_iou)
+
+        # capture instance (image-sentence pair) level IoU
+        instance_level_metrics["iou"].append(instance_iou)
+
+    average_metrics = {k: np.mean(v) for k, v in entry_level_metrics.items()}
+
+    return (
+        average_metrics,
+        instance_level_metrics,
+        entry_level_metrics
+    )
+
+
+def process_entry_image_to_text(entry, unimodal=False):
+
+    if not unimodal:
+        if len(np.asarray(entry["image"]).shape) == 3:
+            mask = np.repeat(np.expand_dims(entry['image_mask'], -1), 3, axis=-1)
+        else:
+            mask = np.asarray(entry['image_mask'])
+
+        masked_image = (mask * np.asarray(entry['image'])).astype(np.uint8)
+        masked_image = Image.fromarray(masked_image)
+        orig_image = pad_to_square(masked_image)
+        img = preprocess(orig_image).unsqueeze(0).to(device)
+    else:
+        orig_image_shape = max(np.asarray(entry['image']).shape[:2])
+        orig_image = Image.fromarray(np.zeros((orig_image_shape, orig_image_shape, 3), dtype=np.uint8))
+        # orig_image = Image.fromarray(np.random.randint(0, 256, (orig_image_shape, orig_image_shape, 3), dtype=np.uint8))
+        img = preprocess(orig_image).unsqueeze(0).to(device)
+
+    texts = [' '.join(entry['text'])]
+
+    return img, texts, orig_image
+
+
+def process_text_mask(text, text_mask, tokens):
+
+    token_level_mask = np.zeros(len(tokens))
+
+    for label, subtext in zip(text_mask, text):
+
+        subtext_tokens=_tokenizer.encode(subtext)
+        subtext_tokens_decoded=[_tokenizer.decode([a]) for a in subtext_tokens]
+
+        if label == 1:
+            start = tokens.index(subtext_tokens_decoded[0])
+            end = tokens.index(subtext_tokens_decoded[-1])
+            token_level_mask[start:end + 1] = 1
+
+    return token_level_mask
+
+
+def evaluate_image_to_text(method, dataset, debug=False, clamp_sentence_len=70):
+
+    instance_level_metrics = defaultdict(list)
+    entry_level_metrics = defaultdict(list)
+
+    # skipped if text length > 77 which is CLIP limit
+    num_entries_skipped = 0
+    num_total_entries = 0
+
+    num_iter = len(dataset)
+    if debug:
+        num_iter = 100
+
+    jaccard_image_to_text = JaccardIndex(num_classes=2).to(device)
+
+    iterator = tqdm_iterator(range(num_iter), desc=f"Evaluating on {type(dataset).__name__} dataset")
+    for idx in iterator:
+        instance = dataset[idx]
+
+        instance_iou = 0.
+        for entry in instance:
+            num_total_entries += 1
+
+            # preprocess the image and text
+            unimodal = True if method == "clip-unimodal" else False
+            img, texts, orig_image = process_entry_image_to_text(entry, unimodal=unimodal)
+
+            appx_total_sent_len = np.sum([len(x.split(" ")) for x in texts])
+            if appx_total_sent_len > clamp_sentence_len:
+                # print(f"Skipping an entry since it's text has appx"\
+                #     " {appx_total_sent_len} while CLIP cannot process beyond {clamp_sentence_len}")
+                num_entries_skipped += 1
+                continue
+
+            # compute the relevance scores
+            if method in ["clip", "clip-unimodal"]:
+                try:
+                    outputs = interpret_and_generate(model, img, texts, orig_image, return_outputs=True, show=False)
+                except:
+                    num_entries_skipped += 1
+                    continue
+            elif method == "random":
+                text = texts[0]
+                text_tokens = _tokenizer.encode(text)
+                text_tokens_decoded=[_tokenizer.decode([a]) for a in text_tokens]
+                outputs = [
+                    {
+                        "text_scores": np.random.uniform(low=0., high=1., size=len(text_tokens_decoded)),
+                        "tokens_decoded": text_tokens_decoded,
+                    }
+                ]
+
+            # use the text relevance score to compute IoU w.r.t. ground truth text masks
+            # NOTE: since we pass single entry (1-sized batch), outputs[0] contains our reqd outputs
+            token_relevance_scores = outputs[0]["text_scores"]
+            if isinstance(token_relevance_scores, torch.Tensor):
+                token_relevance_scores = token_relevance_scores.cpu().numpy()
+            token_relevance_scores = apply_otsu_threshold(token_relevance_scores)
+            token_ground_truth_mask = process_text_mask(entry["text"], entry["text_mask"], outputs[0]["tokens_decoded"])
+
+            entry_iou = jaccard_image_to_text(
+                torch.from_numpy(token_relevance_scores).to(device),
+                torch.from_numpy(token_ground_truth_mask.astype(np.uint8)).to(device),
+            )
+            entry_iou = entry_iou.item()
+
+            instance_iou += (entry_iou / len(entry))
+            entry_level_metrics["iou"].append(entry_iou)
+
+        # capture instance (image-sentence pair) level IoU
+        instance_level_metrics["iou"].append(instance_iou)
+
+    print(f"CAUTION: Skipped {(num_entries_skipped / num_total_entries) * 100} % since these had length > 77 (CLIP limit).")
+    average_metrics = {k: np.mean(v) for k, v in entry_level_metrics.items()}
+
+    return (
+        average_metrics,
+        instance_level_metrics,
+        entry_level_metrics
+    )
+
+
+if __name__ == "__main__":
+
+    import argparse
+    parser = argparse.ArgumentParser("Evaluate Image-to-Text & Text-to-Image model")
+    parser.add_argument(
+        "--eval_method", type=str, default="clip",
+        choices=["clip", "random", "clip-unimodal"],
+        help="Evaluation method to use",
+    )
+    parser.add_argument(
+        "--ignore_cache", action="store_true",
+        help="Ignore cache and force re-generation of the results",
+    )
+    parser.add_argument(
+        "--debug", action="store_true",
+        help="Run evaluation on a small subset of the dataset",
+    )
+    args = parser.parse_args()
+
+    print_update("Using evaluation method: {}".format(args.eval_method))
+
+
+    clip.clip._MODELS = {
+        "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+        "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+    }
+
+    # specify device
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # load CLIP model
+    print_update("Loading CLIP model...")
+    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+    print()
+
+    # load PNG dataset
+    print_update("Loading PNG dataset...")
+    dataset = PNG(dataset_root=join(REPO_PATH, "data", "panoptic_narrative_grounding"), split="val2017")
+    print()
+
+    # evaluate
+
+    # save metrics
+    metrics_dir = join(REPO_PATH, "outputs")
+    os.makedirs(metrics_dir, exist_ok=True)
+
+    metrics_path = join(metrics_dir, f"{args.eval_method}_on_{type(dataset).__name__}_text2image_metrics.pt")
+    if (not exists(metrics_path)) or args.ignore_cache:
+        print_update("Computing metrics for text-to-image grounding")
+        average_metrics, instance_level_metrics, entry_level_metrics = evaluate_text_to_image(
+            args.eval_method, dataset, debug=args.debug,
+        )
+        metrics = {
+            "average_metrics": average_metrics,
+            "instance_level_metrics":instance_level_metrics,
+            "entry_level_metrics": entry_level_metrics
+        }
+
+        torch.save(metrics, metrics_path)
+        print("TEXT2IMAGE METRICS SAVED TO:", metrics_path)
+    else:
+        print(f"Metrics already exist at: {metrics_path}. Loading cached metrics.")
+        metrics = torch.load(metrics_path)
+        average_metrics = metrics["average_metrics"]
+    print("TEXT2IMAGE METRICS:", np.round(average_metrics["iou"], 4))
+
+    print()
+
+    metrics_path = join(metrics_dir, f"{args.eval_method}_on_{type(dataset).__name__}_image2text_metrics.pt")
+    if (not exists(metrics_path)) or args.ignore_cache:
+        print_update("Computing metrics for image-to-text grounding")
+        average_metrics, instance_level_metrics, entry_level_metrics = evaluate_image_to_text(
+            args.eval_method, dataset, debug=args.debug,
+        )
+
+        torch.save(
+            {
+                "average_metrics": average_metrics,
+                "instance_level_metrics":instance_level_metrics,
+                "entry_level_metrics": entry_level_metrics
+            },
+            metrics_path,
+        )
+        print("IMAGE2TEXT METRICS SAVED TO:", metrics_path)
+    else:
+        print(f"Metrics already exist at: {metrics_path}. Loading cached metrics.")
+        metrics = torch.load(metrics_path)
+        average_metrics = metrics["average_metrics"]
+    print("IMAGE2TEXT METRICS:", np.round(average_metrics["iou"], 4))
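The per-entry scoring above reduces to an Otsu binarisation of the relevance map followed by a Jaccard index against the ground-truth mask. A self-contained sketch of just that step on synthetic arrays (no dataset or model required; the random arrays are stand-ins, not real data):

import numpy as np
import torch
from skimage import filters
from torchmetrics import JaccardIndex

relevance_map = np.random.uniform(size=(224, 224))        # stand-in for CLIP's image relevance map
ground_truth = np.random.uniform(size=(224, 224)) > 0.5   # stand-in for the PNG segment mask

# Same binarisation as apply_otsu_threshold above.
pred = (relevance_map > filters.threshold_otsu(relevance_map)).astype(np.uint8)

jaccard = JaccardIndex(num_classes=2)                     # matches the call used in this script
iou = jaccard(torch.from_numpy(pred), torch.from_numpy(ground_truth.astype(np.uint8)))
print(float(iou))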
clip_grounding/evaluation/qualitative_results.py
ADDED
@@ -0,0 +1,93 @@
+"""Converts notebook for qualitative results to a python script."""
+import sys
+from os.path import join
+
+from clip_grounding.utils.paths import REPO_PATH
+sys.path.append(join(REPO_PATH, "CLIP_explainability/Transformer-MM-Explainability/"))
+
+import os
+import torch
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.patches import Patch
+import CLIP.clip as clip
+import cv2
+from PIL import Image
+from glob import glob
+from natsort import natsorted
+
+from clip_grounding.utils.paths import REPO_PATH
+from clip_grounding.utils.io import load_json
+from clip_grounding.utils.visualize import set_latex_fonts, show_grid_of_images
+from clip_grounding.utils.image import pad_to_square
+from clip_grounding.datasets.png_utils import show_images_and_caption
+from clip_grounding.datasets.png import (
+    PNG,
+    visualize_item,
+    overlay_segmask_on_image,
+    overlay_relevance_map_on_image,
+    get_text_colors,
+)
+from clip_grounding.evaluation.clip_on_png import (
+    process_entry_image_to_text,
+    process_entry_text_to_image,
+    interpret_and_generate,
+)
+
+# load dataset
+dataset = PNG(dataset_root=join(REPO_PATH, "data/panoptic_narrative_grounding"), split="val2017")
+
+# load CLIP model
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
+
+
+def visualize_entry_text_to_image(entry, pad_images=True, figsize=(18, 5)):
+    test_img, test_texts, orig_image = process_entry_text_to_image(entry, unimodal=False)
+    outputs = interpret_and_generate(model, test_img, test_texts, orig_image, return_outputs=True, show=False)
+    relevance_map = outputs[0]["image_relevance"]
+
+    image_with_mask = overlay_segmask_on_image(entry["image"], entry["image_mask"])
+    if pad_images:
+        image_with_mask = pad_to_square(image_with_mask)
+
+    image_with_relevance_map = overlay_relevance_map_on_image(entry["image"], relevance_map)
+    if pad_images:
+        image_with_relevance_map = pad_to_square(image_with_relevance_map)
+
+    text_colors = get_text_colors(entry["text"], entry["text_mask"])
+
+    show_images_and_caption(
+        [image_with_mask, image_with_relevance_map],
+        entry["text"], text_colors, figsize=figsize,
+        image_xlabels=["Ground truth segmentation", "Predicted relevance map"]
+    )
+
+
+def create_and_save_gif(filenames, save_path, **kwargs):
+    import imageio
+    images = []
+    for filename in filenames:
+        images.append(imageio.imread(filename))
+    imageio.mimsave(save_path, images, **kwargs)
+
+
+idx = 100
+instance = dataset[idx]
+
+instance_dir = join(REPO_PATH, "figures", f"instance-{idx}")
+os.makedirs(instance_dir, exist_ok=True)
+
+for i, entry in enumerate(instance):
+    del entry["full_caption"]
+
+    visualize_entry_text_to_image(entry, pad_images=False, figsize=(19, 4))
+
+    save_path = instance_dir
+    plt.savefig(join(instance_dir, f"viz-{i}.png"), bbox_inches="tight")
+
+
+filenames = natsorted(glob(join(instance_dir, "viz-*.png")))
+save_path = join(REPO_PATH, "media", "sample.gif")
+
+create_and_save_gif(filenames, save_path, duration=3)
clip_grounding/utils/image.py
ADDED
@@ -0,0 +1,46 @@
+"""Image operations."""
+from copy import deepcopy
+from PIL import Image
+
+
+def center_crop(im: Image):
+    width, height = im.size
+    new_width = width if width < height else height
+    new_height = height if height < width else width
+
+    left = (width - new_width)/2
+    top = (height - new_height)/2
+    right = (width + new_width)/2
+    bottom = (height + new_height)/2
+
+    # Crop the center of the image
+    im = im.crop((left, top, right, bottom))
+
+    return im
+
+
+def pad_to_square(im: Image, color=(0, 0, 0)):
+    im = deepcopy(im)
+    width, height = im.size
+
+    vert_pad = (max(width, height) - height) // 2
+    hor_pad = (max(width, height) - width) // 2
+
+    if len(im.mode) == 3:
+        color = (0, 0, 0)
+    elif len(im.mode) == 1:
+        color = 0
+    else:
+        raise ValueError(f"Image mode not supported. Image has {im.mode} channels.")
+
+    return add_margin(im, vert_pad, hor_pad, vert_pad, hor_pad, color=color)
+
+
+def add_margin(pil_img, top, right, bottom, left, color=(0, 0, 0)):
+    """Ref: https://note.nkmk.me/en/python-pillow-add-margin-expand-canvas/"""
+    width, height = pil_img.size
+    new_width = width + right + left
+    new_height = height + top + bottom
+    result = Image.new(pil_img.mode, (new_width, new_height), color)
+    result.paste(pil_img, (left, top))
+    return result
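A quick sketch of the padding helper, assuming a 3-channel input (pad_to_square raises for modes other than RGB or single-channel, hence the convert; the example path is hypothetical):

from PIL import Image
from clip_grounding.utils.image import pad_to_square

image = Image.open("example_images/Amsterdam.png").convert("RGB")
square = pad_to_square(image)          # zero-padding added along the shorter side
print(image.size, "->", square.size)   # width and height become equal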
clip_grounding/utils/io.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Utilities for input-output loading/saving.
"""

from typing import Any, List
import yaml
import pickle
import json


class PrettySafeLoader(yaml.SafeLoader):
    """Custom loader for reading YAML files"""
    def construct_python_tuple(self, node):
        return tuple(self.construct_sequence(node))


PrettySafeLoader.add_constructor(
    u'tag:yaml.org,2002:python/tuple',
    PrettySafeLoader.construct_python_tuple
)


def load_yml(path: str, loader_type: str = 'default'):
    """Read params from a yml file.

    Args:
        path (str): path to the .yml file
        loader_type (str, optional): type of loader used to load yml files. Defaults to 'default'.

    Returns:
        Any: object (typically dict) loaded from .yml file
    """
    assert loader_type in ['default', 'safe']

    loader = yaml.Loader if (loader_type == "default") else PrettySafeLoader

    with open(path, 'r') as f:
        data = yaml.load(f, Loader=loader)

    return data


def save_yml(data: dict, path: str):
    """Save params in the given yml file path.

    Args:
        data (dict): data object to save
        path (str): path to .yml file to be saved
    """
    with open(path, 'w') as f:
        yaml.dump(data, f, default_flow_style=False)


def load_pkl(path: str, encoding: str = "ascii") -> Any:
    """Loads a .pkl file.

    Args:
        path (str): path to the .pkl file
        encoding (str, optional): encoding to use for loading. Defaults to "ascii".

    Returns:
        Any: unpickled object
    """
    return pickle.load(open(path, "rb"), encoding=encoding)


def save_pkl(data: Any, path: str) -> None:
    """Saves given object into .pkl file

    Args:
        data (Any): object to be saved
        path (str): path to the location to be saved at
    """
    with open(path, 'wb') as f:
        pickle.dump(data, f)


def load_json(path: str) -> dict:
    """Helper to load json file"""
    with open(path, 'rb') as f:
        data = json.load(f)
    return data


def save_json(data: dict, path: str):
    """Helper to save `dict` as .json file."""
    with open(path, 'w') as f:
        json.dump(data, f)


def load_txt(path: str) -> List:
    """Loads lines of a .txt file.

    Args:
        path (str): path to the .txt file

    Returns:
        List: lines of .txt file
    """
    with open(path) as f:
        lines = f.read().splitlines()
    return lines


def save_txt(data: List, path: str):
    """Writes data (lines) to a txt file.

    Args:
        data (List): List of strings
        path (str): path to .txt file
    """
    assert isinstance(data, list)

    lines = "\n".join(data)
    with open(path, "w") as f:
        f.write(str(lines))
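For context, a small round-trip sketch with these I/O helpers (editor's illustration, not part of the commit; the /tmp paths are placeholders):

# Illustrative sketch (not part of the commit): JSON/YAML round-trips with the helpers above.
from clip_grounding.utils.io import save_json, load_json, save_yml, load_yml

record = {"caption": "two dogs on a bed", "score": 0.87}
save_json(record, "/tmp/sample.json")                       # placeholder path
assert load_json("/tmp/sample.json") == record

save_yml(record, "/tmp/sample.yml")                         # placeholder path
assert load_yml("/tmp/sample.yml", loader_type="safe") == record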
clip_grounding/utils/log.py
ADDED
@@ -0,0 +1,57 @@
"""Utilities for logging"""
import logging
from tqdm import tqdm
from termcolor import colored


def color(string: str, color_name: str = 'yellow') -> str:
    """Returns colored string for output to terminal"""
    return colored(string, color_name)


def print_update(message: str, width: int = 140, fillchar: str = ":", color="yellow") -> None:
    """Prints an update message

    Args:
        message (str): message
        width (int): width of new update message
        fillchar (str): character to be filled to L and R of message
        color (str): color of the printed message
    """
    message = message.center(len(message) + 2, " ")
    print(colored(message.center(width, fillchar), color))


def set_logger(log_path):
    """Set the logger to log info in terminal and file `log_path`.

    Args:
        log_path (str): path to the log file
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        # Logging to a file
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        logger.addHandler(file_handler)

        # Logging to console
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(stream_handler)


def tqdm_iterator(items, desc=None, bar_format=None, **kwargs):
    """Returns a tqdm progress bar over `items` with a compact bar format."""
    tqdm._instances.clear()
    iterator = tqdm(
        items,
        desc=desc,
        bar_format=bar_format if bar_format is not None else '{l_bar}{bar:10}{r_bar}{bar:-10b}',
        **kwargs,
    )

    return iterator
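For context, a short sketch of how these logging helpers fit together (editor's illustration, not part of the commit; the log path is a placeholder):

# Illustrative sketch (not part of the commit): colored updates, file logging, and a progress bar.
import logging
from clip_grounding.utils.log import print_update, set_logger, tqdm_iterator

print_update("Evaluating CLIP on PNG", width=80, fillchar=":")
set_logger("/tmp/eval.log")                 # placeholder path; logs to file and console
logging.info("Started evaluation.")

for _ in tqdm_iterator(range(100), desc="processing"):
    pass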
clip_grounding/utils/paths.py
ADDED
@@ -0,0 +1,10 @@
"""Path helpers for the project."""
from os.path import join, abspath, dirname


REPO_PATH = dirname(dirname(dirname(abspath(__file__))))
DATA_ROOT = join(REPO_PATH, "data")

DATASET_ROOTS = {
    "PNG": join(DATA_ROOT, "panoptic_narrative_grounding"),
}
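For context, downstream modules resolve dataset and media locations through these constants, e.g. (editor's illustration, not part of the commit):

# Illustrative sketch (not part of the commit): resolving the PNG dataset root and media dir.
from os.path import join
from clip_grounding.utils.paths import REPO_PATH, DATASET_ROOTS

png_root = DATASET_ROOTS["PNG"]     # <repo>/data/panoptic_narrative_grounding
media_dir = join(REPO_PATH, "media")  # e.g. where sample GIFs are written
print(png_root, media_dir)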
clip_grounding/utils/visualize.py
ADDED
@@ -0,0 +1,183 @@
"""Helpers for visualization"""
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import cv2
from PIL import Image


# define predominant colors
COLORS = {
    "pink": (242, 116, 223),
    "cyan": (46, 242, 203),
    "red": (255, 0, 0),
    "green": (0, 255, 0),
    "blue": (0, 0, 255),
    "yellow": (255, 255, 0),
}


def show_single_image(image: np.ndarray, figsize: tuple = (8, 8), title: str = None, titlesize=18, cmap: str = None, ticks=False, save=False, save_path=None):
    """Show a single image."""
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    if isinstance(image, Image.Image):
        image = np.asarray(image)

    ax.set_title(title, fontsize=titlesize)
    ax.imshow(image, cmap=cmap)

    if not ticks:
        ax.set_xticks([])
        ax.set_yticks([])

    if save:
        plt.savefig(save_path, bbox_inches='tight')

    plt.show()


def show_grid_of_images(
        images: np.ndarray, n_cols: int = 4, figsize: tuple = (8, 8),
        cmap=None, subtitles=None, title=None, subtitlesize=18,
        save=False, save_path=None, titlesize=20,
    ):
    """Show a grid of images."""
    n_cols = min(n_cols, len(images))

    copy_of_images = images.copy()
    for i, image in enumerate(copy_of_images):
        if isinstance(image, Image.Image):
            image = np.asarray(image)
        images[i] = image

    if subtitles is None:
        subtitles = [None] * len(images)

    n_rows = int(np.ceil(len(images) / n_cols))
    # squeeze=False keeps `axes` a 2D array so `axes.flat` also works for a single image
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < len(images):
            if len(images[i].shape) == 2 and cmap is None:
                cmap = "gray"
            ax.imshow(images[i], cmap=cmap)
            ax.set_title(subtitles[i], fontsize=subtitlesize)
        ax.axis('off')
    fig.set_tight_layout(True)
    plt.suptitle(title, y=0.8, fontsize=titlesize)

    if save:
        plt.savefig(save_path, bbox_inches='tight')
        plt.close()
    else:
        plt.show()


def show_keypoint_matches(
        img1, kp1, img2, kp2, matches,
        K=10, figsize=(10, 5), drawMatches_args=dict(matchesThickness=3, singlePointColor=(0, 0, 0)),
        choose_matches="random",
    ):
    """Displays matches found in the pair of images"""
    if choose_matches == "random":
        selected_matches = np.random.choice(matches, K)
    elif choose_matches == "all":
        K = len(matches)
        selected_matches = matches
    elif choose_matches == "topk":
        selected_matches = matches[:K]
    else:
        raise ValueError(f"Unknown value for choose_matches: {choose_matches}")

    # color each match with a different color
    cmap = matplotlib.cm.get_cmap('gist_rainbow', K)
    colors = [[int(x * 255) for x in cmap(i)[:3]] for i in np.arange(0, K)]
    drawMatches_args.update({"matchColor": -1, "singlePointColor": (100, 100, 100)})

    img3 = cv2.drawMatches(img1, kp1, img2, kp2, selected_matches, outImg=None, **drawMatches_args)
    show_single_image(
        img3,
        figsize=figsize,
        title=f"[{choose_matches.upper()}] Selected K = {K} matches between the pair of images.",
    )
    return img3


def draw_kps_on_image(image: np.ndarray, kps: np.ndarray, color=COLORS["red"], radius=3, thickness=-1, return_as="numpy"):
    """
    Draw keypoints on image.

    Args:
        image: Image to draw keypoints on.
        kps: Keypoints to draw. Note these should be in (x, y) format.
    """
    if isinstance(image, Image.Image):
        image = np.asarray(image)

    for kp in kps:
        image = cv2.circle(
            image, (int(kp[0]), int(kp[1])), radius=radius, color=color, thickness=thickness)

    if return_as == "PIL":
        return Image.fromarray(image)

    return image


def get_concat_h(im1, im2):
    """Concatenate two images horizontally"""
    dst = Image.new('RGB', (im1.width + im2.width, im1.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst


def get_concat_v(im1, im2):
    """Concatenate two images vertically"""
    dst = Image.new('RGB', (im1.width, im1.height + im2.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (0, im1.height))
    return dst


def show_images_with_keypoints(images: list, kps: list, radius=15, color=(0, 220, 220), figsize=(10, 8), return_images=False, save=False, save_path="sample.png"):
    assert len(images) == len(kps)

    # generate
    images_with_kps = []
    for i in range(len(images)):
        img_with_kps = draw_kps_on_image(images[i], kps[i], radius=radius, color=color, return_as="PIL")
        images_with_kps.append(img_with_kps)

    # show
    show_grid_of_images(images_with_kps, n_cols=len(images), figsize=figsize, save=save, save_path=save_path)

    if return_images:
        return images_with_kps


def set_latex_fonts(usetex=True, fontsize=14, show_sample=False, **kwargs):
    try:
        plt.rcParams.update({
            "text.usetex": usetex,
            "font.family": "serif",
            "font.serif": ["Computer Modern Roman"],
            "font.size": fontsize,
            **kwargs,
        })
        if show_sample:
            plt.figure()
            plt.title("Sample $y = x^2$")
            plt.plot(np.arange(0, 10), np.arange(0, 10)**2, "--o")
            plt.grid()
            plt.show()
    except:
        print("Failed to setup LaTeX fonts. Proceeding without.")
        pass


def get_colors(num_colors, palette="jet"):
    cmap = plt.get_cmap(palette)
    colors = [cmap(i) for i in np.linspace(0, 1, num_colors)]
    return colors
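For context, a brief sketch of the plotting helpers in action (editor's illustration, not part of the commit; the keypoint coordinates are made-up values, and the example image is one shipped with this Space):

# Illustrative sketch (not part of the commit): keypoint overlay and grid plotting.
import numpy as np
from PIL import Image
from clip_grounding.utils.visualize import show_grid_of_images, draw_kps_on_image, COLORS

im = Image.open("example_images/dogs_on_bed.png").convert("RGB")
im_np = np.array(im)                            # writable copy for OpenCV drawing
kps = np.array([[50, 60], [120, 80]])           # made-up (x, y) keypoints
overlay = draw_kps_on_image(im_np, kps, color=COLORS["green"], radius=5, return_as="PIL")
show_grid_of_images([im, overlay], n_cols=2, subtitles=["input", "with keypoints"], figsize=(8, 4))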
example_images/Amsterdam.png
ADDED
example_images/London.png
ADDED
example_images/dogs_on_bed.png
ADDED
example_images/harrypotter.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,121 @@
anyio==3.6.1
appnope==0.1.3
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
asttokens==2.0.5
attrs==21.4.0
Babel==2.10.3
backcall==0.2.0
beautifulsoup4==4.11.1
bleach==5.0.0
captum==0.5.0
certifi==2022.6.15
cffi==1.15.0
charset-normalizer==2.0.12
cycler==0.11.0
debugpy==1.6.0
decorator==5.1.1
defusedxml==0.7.1
entrypoints==0.4
executing==0.8.3
fastjsonschema==2.15.3
fonttools==4.33.3
ftfy==6.1.1
htmlmin==0.1.12
idna==3.3
ImageHash==4.2.1
imageio==2.19.3
imgaug==0.4.0
importlib-metadata==4.11.4
ipdb==0.13.9
ipykernel==6.15.0
ipython==8.4.0
ipython-genutils==0.2.0
ipywidgets==7.7.1
jedi==0.18.1
Jinja2==3.1.2
joblib==1.1.0
json5==0.9.8
jsonschema==4.6.0
jupyter-client==7.3.4
jupyter-core==4.10.0
jupyter-server==1.18.0
jupyterlab==3.4.3
jupyterlab-pygments==0.2.2
jupyterlab-server==2.14.0
jupyterlab-widgets==1.1.1
kiwisolver==1.4.3
MarkupSafe==2.1.1
matplotlib==3.5.2
matplotlib-inline==0.1.3
missingno==0.5.1
mistune==0.8.4
multimethod==1.8
natsort==8.1.0
nbclassic==0.3.7
nbclient==0.6.4
nbconvert==6.5.0
nbformat==5.4.0
nest-asyncio==1.5.5
networkx==2.8.4
notebook==6.4.12
notebook-shim==0.1.0
numpy==1.23.0
opencv-python==4.6.0.66
packaging==21.3
pandas==1.4.3
pandas-profiling==3.2.0
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
phik==0.12.2
pickleshare==0.7.5
Pillow==9.1.1
prometheus-client==0.14.1
prompt-toolkit==3.0.29
psutil==5.9.1
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
pydantic==1.9.1
Pygments==2.12.0
pyparsing==3.0.9
pyrsistent==0.18.1
python-dateutil==2.8.2
pytz==2022.1
PyWavelets==1.3.0
PyYAML==6.0
pyzmq==23.2.0
regex==2022.6.2
requests==2.28.0
scikit-image==0.19.3
scikit-learn==1.1.1
scipy==1.8.1
seaborn==0.11.2
Send2Trash==1.8.0
Shapely==1.8.2
six==1.16.0
sniffio==1.2.0
soupsieve==2.3.2.post1
stack-data==0.3.0
tangled-up-in-unicode==0.2.0
termcolor==1.1.0
terminado==0.15.0
threadpoolctl==3.1.0
tifffile==2022.5.4
tinycss2==1.1.1
toml==0.10.2
torch==1.11.0
torchmetrics==0.9.1
torchvision==0.12.0
tornado==6.1
tqdm==4.64.0
traitlets==5.3.0
typing_extensions==4.2.0
urllib3==1.26.9
visions==0.7.4
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==1.3.3
widgetsnbextension==3.6.1
zipp==3.8.0