import os
import sys
from pathlib import Path
import textwrap
import re
import ast

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
import cv2
import base64
import io

from decode_string import decode_bbox_from_caption

pylab.rcParams['figure.figsize'] = 20, 12

# NOTE(review): every token literal below is the empty string in this copy of
# the file (and `add_location_symbols` emits `f""` for each location token),
# so all "special symbols" compare equal to one another. This looks like
# markup that was stripped from the source -- confirm the intended strings.
EOD_SYMBOL = ""
BOI_SYMBOL = ""
EOI_SYMBOL = ""
EOC_SYMBOL = ""
EOL_SYMBOL = ""

BOP_SYMBOL = ""
EOP_SYMBOL = ""

BOO_SYMBOL = ""
EOO_SYMBOL = ""

DOM_SYMBOL = ""

SPECIAL_SYMBOLS = [EOD_SYMBOL, BOI_SYMBOL, EOI_SYMBOL, EOC_SYMBOL, EOL_SYMBOL]

# Per-channel mean/std that tensor inputs to `draw_entity_box_on_image` are
# assumed to have been normalised with; used to undo that normalisation.
_REVERSE_NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)
_REVERSE_NORM_STD = (0.26862954, 0.26130258, 0.27577711)


def add_location_symbols(quantized_size):
    """Build the special-token list followed by one location token per grid cell.

    Args:
        quantized_size: side length of the location grid; ``quantized_size ** 2``
            location tokens are appended.

    Returns:
        list of str: ``SPECIAL_SYMBOLS``, then the phrase/object/delimiter
        symbols, then the per-cell location tokens.
    """
    custom_sp_symbols = list(SPECIAL_SYMBOLS)
    custom_sp_symbols.extend([BOP_SYMBOL, EOP_SYMBOL, BOO_SYMBOL, EOO_SYMBOL, DOM_SYMBOL])
    # NOTE(review): this f-string has no placeholder, so every location token
    # is identical; presumably it was meant to embed the index `i` -- confirm.
    custom_sp_symbols.extend(f"" for i in range(quantized_size ** 2))
    return custom_sp_symbols


def imshow(img, file_name="tmp.jpg", caption='test'):
    """Save ``img`` to ``file_name`` with a word-wrapped caption underneath.

    Args:
        img: array indexed as ``img[:, :, channel]``; channels are reversed
            before display, so it is expected in BGR order (as returned by
            ``draw_entity_box_on_image``).
        file_name: output path passed to ``savefig``.
        caption: text drawn below the image, wrapped at 120 characters.
    """
    fig, ax = plt.subplots()
    ax.imshow(img[:, :, [2, 1, 0]])  # BGR -> RGB for matplotlib
    ax.set_axis_off()
    ax.text(0.5, -0.1, '\n'.join(textwrap.wrap(caption, 120)), ha='center', transform=ax.transAxes, fontsize=18)
    fig.savefig(file_name)
    plt.close(fig)


def is_overlapping(rect1, rect2):
    """Return True if two ``(x1, y1, x2, y2)`` rectangles intersect.

    Comparisons are strict, so rectangles that merely share an edge count as
    overlapping.
    """
    x1, y1, x2, y2 = rect1
    x3, y3, x4, y4 = rect2
    return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)


def _to_bgr_array(image):
    """Convert a PIL image, an image path or a normalised tensor to a BGR array.

    Returns:
        tuple: ``(bgr_array, height, width)`` where ``bgr_array`` is a
        ``(height, width, 3)`` uint8 array in BGR channel order.

    Raises:
        ValueError: if ``image`` is a path that does not exist, or is of an
            unsupported type.
    """
    if isinstance(image, Image.Image):
        # Convert first so grayscale / palette / RGBA images yield 3 channels.
        pil_img = image.convert("RGB")
    elif isinstance(image, str):
        if not os.path.exists(image):
            raise ValueError(f"invaild image path, {image}")
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(image) as opened:
            pil_img = opened.convert("RGB")
    elif isinstance(image, torch.Tensor):
        # The (3, 1, 1) mean/std broadcast assumes a channel-first tensor with
        # 3 channels -- TODO confirm against callers.
        mean = torch.tensor(_REVERSE_NORM_MEAN)[:, None, None]
        std = torch.tensor(_REVERSE_NORM_STD)[:, None, None]
        image_tensor = image.detach().cpu() * std + mean
        # ToPILImage scales floats by 255 and casts to uint8; clamp so values
        # slightly outside [0, 1] saturate instead of wrapping around.
        pil_img = T.ToPILImage()(image_tensor.clamp(0, 1))
    else:
        raise ValueError(f"invaild image format, {type(image)} for {image}")
    return np.array(pil_img)[:, :, [2, 1, 0]], pil_img.height, pil_img.width


def _blend_rect(canvas, x1, y1, x2, y2, color, alpha=0.5):
    """Alpha-blend ``color`` into ``canvas[y1:y2, x1:x2]`` in place, clipped to the canvas."""
    height, width = canvas.shape[:2]
    # Clip to the image: a negative start index would otherwise address the
    # opposite edge of the array and paint the wrong rows/columns.
    y1, y2 = max(y1, 0), min(y2, height)
    x1, x2 = max(x1, 0), min(x2, width)
    if y1 >= y2 or x1 >= x2:
        return
    region = canvas[y1:y2, x1:x2]
    canvas[y1:y2, x1:x2] = (alpha * region + (1 - alpha) * np.array(color)).astype(np.uint8)


def draw_entity_box_on_image(image, collect_entity_location):
    """Draw labelled bounding boxes for grounded phrases onto an image.

    Args:
        image: a ``PIL.Image.Image``, a path to an image file, or a normalised
            ``torch.Tensor``.
        collect_entity_location: sequence of
            ``(phrase, x1_norm, y1_norm, x2_norm, y2_norm)`` tuples whose
            coordinates are fractions of the image width / height.

    Returns:
        numpy.ndarray: BGR uint8 image with one randomly coloured box per
        entity and its phrase drawn on a semi-transparent label. When
        ``collect_entity_location`` is empty, the converted image is returned
        unchanged.

    Raises:
        ValueError: for a non-existent image path or an unsupported image type.
    """
    image, image_h, image_w = _to_bgr_array(image)

    if len(collect_entity_location) == 0:
        return image

    new_image = image.copy()
    previous_locations = []
    previous_bboxes = []
    text_offset = 10
    text_offset_original = 4
    # Scale font size and line widths with the shorter image side.
    text_size = max(0.07 * min(image_h, image_w) / 100, 0.5)
    text_line = int(max(1 * min(image_h, image_w) / 512, 1))
    box_line = int(max(2 * min(image_h, image_w) / 512, 2))
    # Used for the first overlap nudge, before any text has been measured.
    text_height = text_offset

    for (phrase, x1_norm, y1_norm, x2_norm, y2_norm) in collect_entity_location:
        x1 = int(x1_norm * image_w)
        y1 = int(y1_norm * image_h)
        x2 = int(x2_norm * image_w)
        y2 = int(y2_norm * image_h)

        color = tuple(np.random.randint(0, 255, size=3).tolist())
        new_image = cv2.rectangle(new_image, (x1, y1), (x2, y2), color, box_line)

        # Push the label down when it would start on an earlier label's anchor.
        for x_prev, y_prev in previous_locations:
            if abs(x1 - x_prev) < abs(text_offset) and abs(y1 - y_prev) < abs(text_offset):
                y1 += text_height

        # Labels sit above the anchor, so move ones near the top edge down.
        if y1 < 2 * text_offset:
            y1 += text_offset + text_offset_original

        (text_width, text_height), _ = cv2.getTextSize(phrase, cv2.FONT_HERSHEY_SIMPLEX, text_size, text_line)
        text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = (
            x1, y1 - text_height - text_offset_original, x1 + text_width, y1
        )

        # Slide the label down until it no longer overlaps earlier labels,
        # pinning it to the bottom edge if it would leave the image.
        for prev_bbox in previous_bboxes:
            while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox):
                text_bg_y1 += text_offset
                text_bg_y2 += text_offset
                y1 += text_offset

                if text_bg_y2 >= image_h:
                    text_bg_y1 = max(0, image_h - text_height - text_offset_original)
                    text_bg_y2 = image_h
                    y1 = max(0, image_h - text_height - text_offset_original + text_offset)
                    break

        _blend_rect(new_image, text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2, color, alpha=0.5)
        cv2.putText(
            new_image, phrase, (x1, y1 - text_offset_original), cv2.FONT_HERSHEY_SIMPLEX,
            text_size, (0, 0, 0), text_line, cv2.LINE_AA
        )
        previous_locations.append((x1, y1))
        previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))

    return new_image


def visualize_results_on_image(img_path, caption, quantized_size=16, save_path="show_box_on_image.jpg", show=True):
    """Decode grounded boxes from ``caption`` and draw them on the image.

    Args:
        img_path: anything ``draw_entity_box_on_image`` accepts.
        caption: caption string passed to ``decode_bbox_from_caption``.
        quantized_size: forwarded to ``decode_bbox_from_caption``.
        save_path: output file used when ``show`` is true.
        show: when true, save the annotated image with its caption to
            ``save_path`` and return None; otherwise return it as a PIL image.

    Returns:
        PIL.Image.Image or None: the annotated RGB image when ``show`` is false.
    """
    collect_entity_location = decode_bbox_from_caption(caption, quantized_size=quantized_size)
    image = draw_entity_box_on_image(img_path, collect_entity_location)
    if show:
        imshow(image, file_name=save_path, caption=caption)
        return None
    # BGR -> RGB before handing back a PIL image.
    return Image.fromarray(image[:, :, [2, 1, 0]])


if __name__ == "__main__":
    caption = "a wet suit is at in the picture"
    print(decode_bbox_from_caption(caption))