Spaces:
Running
Running
from .datasets.ab_dataset import ABDataset | |
import matplotlib.pyplot as plt | |
from torchvision.utils import make_grid | |
import math | |
import torch | |
def visualize_classes_image_classification(dataset: ABDataset, class_to_idx_map, rename_map,
                                           fig_save_path: str, num_imgs_per_class=2, max_num_classes=20,
                                           unknown_class_idx=None):
    """Save a figure showing a few sample images for each class of a classification dataset.

    For every known class in ``class_to_idx_map`` (up to a limit), the first
    ``num_imgs_per_class`` samples of that class found by scanning ``dataset``
    in index order are tiled with ``make_grid`` and drawn in one subplot. Each
    subplot is titled with the class's position in ``dataset.raw_classes``, its
    name (plus its new name if it is a key of ``rename_map``) and its current
    index. Samples labelled ``unknown_class_idx`` share one extra
    "(unknown classes)" subplot.

    Args:
        dataset: indexable dataset. ``dataset[i]`` must return
            ``(image, label)`` with ``int(label)`` being a current class index,
            and ``dataset.raw_classes`` must list the original class names.
            Images are passed to ``make_grid``, so presumably CHW tensors.
        class_to_idx_map: mapping from class name to current class index.
        rename_map: mapping from original class name to its new name. Only
            used in subplot titles.
        fig_save_path: where the figure is saved (300 dpi).
        num_imgs_per_class: number of samples shown per class.
        max_num_classes: maximum number of class subplots. The unknown-class
            subplot counts towards this limit. If classes had to be left out,
            an extra blank "(Show up to ... classes...)" panel is appended.
        unknown_class_idx: label shared by all unknown classes, or ``None``.
    """
    # The unknown-class panel takes one of the ``max_num_classes`` slots.
    limit = max_num_classes - 1 if unknown_class_idx is not None else max_num_classes
    limit = max(limit, 0)

    idx_to_class = {}
    idx_to_original_idx = {}
    reach_max_num_class_limit = False
    for c, idx in class_to_idx_map.items():
        if unknown_class_idx is not None and idx == unknown_class_idx:
            continue
        # Only report truncation when a class would actually be left out.
        if idx not in idx_to_class and len(idx_to_class) >= limit:
            reach_max_num_class_limit = True
            break
        idx_to_class[idx] = c
        idx_to_original_idx[idx] = dataset.raw_classes.index(c)

    idx_to_images = {idx: [] for idx in idx_to_class}
    if unknown_class_idx is not None:
        idx_to_images[unknown_class_idx] = []
        idx_to_class[unknown_class_idx] = '(unknown classes)'

    # Scan the dataset in order until every shown class has enough samples
    # or the dataset runs out.
    full_flags = {k: False for k in idx_to_images}
    i = 0
    while not all(full_flags.values()):
        try:
            x, y = dataset[i]
        except IndexError:
            break  # dataset exhausted: some classes get fewer images
        i += 1
        y = int(y)
        # Classes dropped by ``max_num_classes`` are not tracked; skip them.
        if y not in full_flags or full_flags[y]:
            continue
        idx_to_images[y].append(x)
        if len(idx_to_images[y]) == num_imgs_per_class:
            full_flags[y] = True

    shown_num_classes = len(idx_to_images) + (1 if reach_max_num_class_limit else 0)
    num_cols = 3
    num_rows = max(1, math.ceil(shown_num_classes / num_cols))
    fig = plt.figure(figsize=(6.4, 4.8 * num_rows // 2))

    grid = None
    for draw_i, (class_idx, imgs) in enumerate(idx_to_images.items(), start=1):
        if not imgs:
            continue  # no sample found for this class: leave its panel empty
        grid = make_grid(imgs, normalize=True)
        plt.subplot(num_rows, num_cols, draw_i)
        plt.axis('off')
        plt.imshow(grid.permute(1, 2, 0).numpy())
        if unknown_class_idx is not None and class_idx == unknown_class_idx:
            plt.title(f'(unknown classes)\n'
                      f'current index: {class_idx}')
        else:
            class_name = idx_to_class[class_idx]
            class_i = idx_to_original_idx[class_idx]
            if class_name in rename_map:
                renamed_class = rename_map[class_name]
                plt.title(f'{class_i}-th original class\n'
                          f'"{class_name}" (→ "{renamed_class}")\n'
                          f'current index: {class_idx}')
            else:
                plt.title(f'{class_i}-th original class\n'
                          f'"{class_name}"\n'
                          f'current index: {class_idx}')

    if reach_max_num_class_limit:
        plt.subplot(num_rows, num_cols, len(idx_to_images) + 1)
        plt.axis('off')
        # Blank (all-ones, i.e. white) panel shaped like the last grid drawn.
        placeholder = torch.ones_like(grid) if grid is not None else torch.ones(3, 2, 2)
        plt.imshow(placeholder.permute(1, 2, 0).numpy())
        plt.title(f'(Show up to {max_num_classes} classes...)')

    plt.tight_layout()
    plt.savefig(fig_save_path, dpi=300)
    # close (not just clf) so repeated calls do not accumulate open figures
    plt.close(fig)
def visualize_classes_in_object_detection(dataset: ABDataset, class_to_idx_map, rename_map,
                                          fig_save_path: str, num_imgs_per_class=2, max_num_classes=20,
                                          unknown_class_idx=None, max_search_seconds=40):
    """Save a figure showing a few annotated sample images for each class of a detection dataset.

    The dataset is scanned in index order. For every shown class, up to
    ``num_imgs_per_class`` images containing an object of that class are
    collected; at most one object per class is taken from any single image.
    The object's box is drawn in red, the images are tiled with ``make_grid``,
    and each tile is titled like in ``visualize_classes_image_classification``.

    The scan stops when the dataset is exhausted, when more than 70% of the
    shown classes have enough images, or after ``max_search_seconds``.

    Args:
        dataset: indexable dataset. The first two items of ``dataset[i]`` are
            ``(image, labels)``. Each row of ``labels`` is
            ``[class_idx, *box]``, and a row whose box values sum to 0 is
            padding that ends the row list. ``image`` is transposed with
            ``.transpose(1, 2, 0)``, so presumably a CHW numpy array in
            0-255 (TODO confirm). ``dataset.raw_classes`` must list the
            original class names.
        class_to_idx_map: mapping from class name to current class index.
        rename_map: mapping from original class name to its new name. Only
            used in subplot titles.
        fig_save_path: where the figure is saved (300 dpi).
        num_imgs_per_class: number of images shown per class.
        max_num_classes: maximum number of class subplots. The unknown-class
            subplot counts towards this limit. If classes had to be left out,
            an extra blank "(Show up to ... classes...)" panel is appended.
        unknown_class_idx: label shared by all unknown classes, or ``None``.
        max_search_seconds: wall-clock budget for scanning the dataset.
    """
    import time

    import numpy as np
    from PIL import Image, ImageDraw
    from torchvision.transforms import ToTensor

    # The unknown-class panel takes one of the ``max_num_classes`` slots.
    limit = max_num_classes - 1 if unknown_class_idx is not None else max_num_classes
    limit = max(limit, 0)

    idx_to_class = {}
    idx_to_original_idx = {}
    reach_max_num_class_limit = False
    for c, idx in class_to_idx_map.items():
        if unknown_class_idx is not None and idx == unknown_class_idx:
            continue
        # Only report truncation when a class would actually be left out.
        if idx not in idx_to_class and len(idx_to_class) >= limit:
            reach_max_num_class_limit = True
            break
        idx_to_class[idx] = c
        idx_to_original_idx[idx] = dataset.raw_classes.index(c)

    idx_to_images = {idx: [] for idx in idx_to_class}
    if unknown_class_idx is not None:
        idx_to_images[unknown_class_idx] = []
        idx_to_class[unknown_class_idx] = '(unknown classes)'

    full_flags = {k: False for k in idx_to_images}
    ii = 0
    start_time = time.monotonic()
    while full_flags:
        try:
            sample = dataset[ii]
        except IndexError:
            break  # dataset exhausted
        ii += 1
        x, y = sample[:2]
        classes_in_this_image = set()
        for label_info in y:
            if sum(label_info[1:]) == 0:  # pad label: no real objects follow
                break
            ci = int(label_info[0])
            # Skip classes that are not shown, already have enough images,
            # or were already taken from this image (one object per class
            # per image). Keep looking at the image's other objects.
            if ci not in full_flags or full_flags[ci] or ci in classes_in_this_image:
                continue
            idx_to_images[ci].append((x, label_info[1:]))
            classes_in_this_image.add(ci)
            if len(idx_to_images[ci]) >= num_imgs_per_class:
                full_flags[ci] = True
        if time.monotonic() - start_time > max_search_seconds:
            break
        if sum(full_flags.values()) > len(full_flags) * 0.7:
            break

    shown_num_classes = len(idx_to_images) + (1 if reach_max_num_class_limit else 0)
    num_cols = 3
    num_rows = max(1, math.ceil(shown_num_classes / num_cols))
    fig = plt.figure(figsize=(6.4, 4.8 * num_rows // 2))

    def draw_bbox(img, bbox):
        # Convert a CHW image to HWC uint8 and outline ``bbox`` in red.
        pil_img = Image.fromarray(np.uint8(img.transpose(1, 2, 0)))
        draw = ImageDraw.Draw(pil_img)
        draw.rectangle([float(v) for v in bbox], outline=(255, 0, 0), width=6)
        return np.array(pil_img)

    grid = None
    for draw_i, (class_idx, samples) in enumerate(idx_to_images.items(), start=1):
        if not samples:
            continue  # no object of this class found: leave its panel empty
        imgs = [ToTensor()(draw_bbox(img, bbox)) for img, bbox in samples]
        grid = make_grid(imgs, normalize=True)
        plt.subplot(num_rows, num_cols, draw_i)
        plt.axis('off')
        plt.imshow(grid.permute(1, 2, 0).numpy())
        if unknown_class_idx is not None and class_idx == unknown_class_idx:
            plt.title(f'(unknown classes)\n'
                      f'current index: {class_idx}')
        else:
            class_name = idx_to_class[class_idx]
            class_i = idx_to_original_idx[class_idx]
            if class_name in rename_map:
                renamed_class = rename_map[class_name]
                plt.title(f'{class_i}-th original class\n'
                          f'"{class_name}" (→ "{renamed_class}")\n'
                          f'current index: {class_idx}')
            else:
                plt.title(f'{class_i}-th original class\n'
                          f'"{class_name}"\n'
                          f'current index: {class_idx}')

    if reach_max_num_class_limit:
        plt.subplot(num_rows, num_cols, len(idx_to_images) + 1)
        plt.axis('off')
        # Blank (all-ones, i.e. white) panel shaped like the last grid drawn.
        placeholder = torch.ones_like(grid) if grid is not None else torch.ones(3, 2, 2)
        plt.imshow(placeholder.permute(1, 2, 0).numpy())
        plt.title(f'(Show up to {max_num_classes} classes...)')

    plt.tight_layout()
    plt.savefig(fig_save_path, dpi=300)
    # close (not just clf) so repeated calls do not accumulate open figures
    plt.close(fig)