import re
from typing import List, Union

import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image
from transformers import AutoTokenizer
from transformers.processing_utils import ProcessorMixin
from transformers.utils import logging

logger = logging.get_logger(__name__)

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN_INDEX = 0
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"

# For Objects
DEFAULT_OBJECT_TOKEN = "<obj<i>>"
DEFAULT_OBJECT_FEATURE_TOKEN = "<objfeat>"
DEFAULT_OBJECT_INDEX = -300

# For Grounding
DEFAULT_GROUNDING_START = "<ground>"
DEFAULT_GROUNDING_END = "</ground>"
DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
DEFAULT_GROUNDING_OBJECTS_END = "</objects>"

def xyxy_to_xywh(boxes):
    """Convert boxes from xyxy to xywh format.

    Parameters:
        boxes (numpy.ndarray): An array of shape (N, 4) where N is the number of boxes.
            Each box is represented as [x_min, y_min, x_max, y_max].

    Returns:
        numpy.ndarray: An array of shape (N, 4) where each box is represented as [x_min, y_min, w, h].
    """
    boxes = np.array(boxes)
    x_min, y_min, x_max, y_max = (
        boxes[:, 0],
        boxes[:, 1],
        boxes[:, 2],
        boxes[:, 3],
    )
    w = x_max - x_min
    h = y_max - y_min
    return np.stack([x_min, y_min, w, h], axis=1)

def xywh_to_xyxy(boxes):
    """Convert boxes from xywh to xyxy format.

    Parameters:
        boxes (numpy.ndarray): An array of shape (N, 4) where N is the number of boxes.
            Each box is represented as [x, y, width, height].

    Returns:
        numpy.ndarray: An array of shape (N, 4) where each box is represented as [x_min, y_min, x_max, y_max].
    """
    boxes = np.array(boxes)
    x, y, width, height = (
        boxes[:, 0],
        boxes[:, 1],
        boxes[:, 2],
        boxes[:, 3],
    )
    x_max = x + width
    y_max = y + height
    return np.stack([x, y, x_max, y_max], axis=1)

def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas filled with background_color, keeping the
    original image centered along the padded dimension."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

def pad_boxes(gt_boxes, old_size):
    """Shift xywh boxes to account for the centered square padding added by expand2square.

    Args:
        gt_boxes: An (N, 4) array of boxes in [x, y, w, h] format, in the original image's pixel coordinates.
        old_size: The original (width, height) of the image before padding.
    """
    old_w, old_h = old_size
    gt_boxes = np.array(gt_boxes).astype(np.float32)
    # Calculate the padding added on each side
    if old_w > old_h:
        pad_top = (old_w - old_h) // 2
        pad_bottom = old_w - old_h - pad_top
        pad_left, pad_right = 0, 0
    else:
        pad_left = (old_h - old_w) // 2
        pad_right = old_h - old_w - pad_left
        pad_top, pad_bottom = 0, 0
    # Adjust the box origins for the padding offsets
    gt_boxes[:, 0] += pad_left  # x
    gt_boxes[:, 1] += pad_top  # y
    return gt_boxes

def resize_boxes(gt_boxes, old_size, new_size):
    """Scale xywh boxes from the padded square image size to the processed image size.

    Args:
        gt_boxes: An (N, 4) array of boxes in [x, y, w, h] format.
        old_size: The padded (width, height); both values are equal after expand2square.
        new_size: The target (height, width) produced by the image processor.
    """
    old_w, old_h = old_size
    new_h, new_w = new_size
    gt_boxes = np.array(gt_boxes).astype(np.float32)
    # Calculate scale factors relative to the square side length
    scale_x = new_w / max(old_w, old_h)
    scale_y = new_h / max(old_w, old_h)
    # Resize the boxes
    gt_boxes[:, 0] *= scale_x  # x
    gt_boxes[:, 1] *= scale_y  # y
    gt_boxes[:, 2] *= scale_x  # w
    gt_boxes[:, 3] *= scale_y  # h
    return gt_boxes
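
# Worked example (illustrative numbers, not taken from any specific config):
# a 640x480 image is padded by expand2square to 640x640 (80 px added above the
# image), and the image processor resizes it to 336x336. For the xyxy box
# [100, 50, 300, 200]:
#   xyxy_to_xywh                              -> [100, 50, 200, 150]
#   pad_boxes(..., (640, 480))                -> [100, 130, 200, 150]
#   resize_boxes(..., (640, 640), (336, 336)) -> [52.5, 68.25, 105.0, 78.75]
#   xywh_to_xyxy                              -> [52.5, 68.25, 157.5, 147.0]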

def split_special_strings(input_string: str, special_strings: List[str] = None):
    """Split the input string into a list of strings, keeping the special strings.

    Args:
        input_string (str): The input string to split.
        special_strings (List[str]): The special strings to keep as separate items.

    Example:
        input_string = "<image>\n<obj0><objfeat><obj1><objfeat>\n I am happy today."
        output = ['<image>', '\n<obj0>', '<objfeat>', '<obj1>', '<objfeat>', '\n I am happy today.']

    Returns:
        list: A list of strings, with the special strings separated from the rest of the input string.
    """
    # Create a regex pattern that matches any of the special strings
    pattern = "|".join(map(re.escape, special_strings))
    # Split the input string on the pattern, keeping the special strings in the result
    split_list = re.split(f"({pattern})", input_string)
    # Remove empty strings from the list
    split_list = [s for s in split_list if s]
    return split_list

def tokenizer_image_object_token(prompt, tokenizer):
    """Tokenize the prompt, inserting IMAGE_TOKEN_INDEX for each <image> and
    DEFAULT_OBJECT_INDEX for each <objfeat> instead of real vocabulary ids."""
    bos_token_id = tokenizer.bos_token_id
    split_tokens = [DEFAULT_IMAGE_TOKEN, DEFAULT_OBJECT_FEATURE_TOKEN]
    chunks = split_special_strings(prompt, split_tokens)
    input_encode = [bos_token_id] if bos_token_id is not None else []
    for chunk in chunks:
        if chunk == DEFAULT_IMAGE_TOKEN:
            input_encode.append(IMAGE_TOKEN_INDEX)
        elif chunk == DEFAULT_OBJECT_FEATURE_TOKEN:
            input_encode.append(DEFAULT_OBJECT_INDEX)
        else:
            input_encode.extend(tokenizer.encode(chunk, add_special_tokens=False))
    return input_encode
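
# Illustrative example (actual token ids depend on the tokenizer in use):
#   tokenizer_image_object_token("<image>\n<obj0><objfeat>\nDescribe <obj0>.", tokenizer)
# returns something like
#   [bos_id, -200, *ids("\n<obj0>"), -300, *ids("\nDescribe <obj0>.")]
# where -200 (IMAGE_TOKEN_INDEX) and -300 (DEFAULT_OBJECT_INDEX) are sentinel
# values rather than real vocabulary ids.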

class RexSeekProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer: AutoTokenizer = None, **kwargs):
        # self.image_processor = image_processor
        # self.tokenizer = tokenizer
        super().__init__(image_processor, tokenizer)
        self._special_tokens = None
        self.template = dict(
            SYSTEM="<|im_start|>system\n{system}<|im_end|>\n",
            INSTRUCTION=(
                "<|im_start|>user\n{input}<|im_end|>\n" "<|im_start|>assistant\n"
            ),
            SUFFIX="<|im_end|>",
            SUFFIX_AS_EOS=True,
            SEP="\n",
            STOP_WORDS=["<|im_end|>", "<|endoftext|>"],
        )
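
    # For reference, template["INSTRUCTION"].format(input="hello") renders as:
    #   <|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n
    # i.e. the ChatML-style markup used when building the prompt in process().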

    def process(
        self,
        image: Union[str, Image.Image],
        bbox: List[List[int]],
        question: str,
    ):
        """Prepare input data for inference.

        Args:
            image (Union[str, Image.Image]): The image to process, given as a file path or a PIL image.
            bbox (List[List[int]]): A list of bounding boxes for the image, in absolute pixel
                coordinates and in [x_min, y_min, x_max, y_max] order.
            question (str): The question to ask about the image.

        Returns:
            dict: A dict with `pixel_values`, `pixel_values_aux`, `gt_boxes`, and `input_ids`
                tensors, each with a leading batch dimension of 1.
        """
        data_dict = {}
        # step 1: load and preprocess the image
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        ori_w, ori_h = F.get_image_size(image)
        # Pad to a square using the image processor's mean color as background
        image = expand2square(
            image,
            tuple(int(x * 255) for x in self.image_processor.image_mean),
        )
        pad_w, pad_h = F.get_image_size(image)
        image_aux = self.image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]
        resize_h, resize_w = image_aux.shape[-2:]
        data_dict["pixel_values_aux"] = image_aux.unsqueeze(0)
        # Additional 336x336 copy of the processed image, resized with bilinear interpolation
        image = image_aux.clone()
        image = torch.nn.functional.interpolate(
            image[None],
            size=[336, 336],
            mode="bilinear",
            align_corners=False,
        )[0]
        data_dict["pixel_values"] = image.unsqueeze(0)
        # step 2: map boxes into the processed image's coordinates
        bbox = xyxy_to_xywh(bbox)
        bbox = pad_boxes(bbox, (ori_w, ori_h))
        bbox = resize_boxes(bbox, (pad_w, pad_h), (resize_h, resize_w))
        data_dict["gt_boxes"] = torch.tensor(xywh_to_xyxy(bbox)).unsqueeze(0)
        # step 3: prepend the image token and object placeholder tokens to the question
        total_num_boxes = len(bbox)
        obj_tokens = [
            DEFAULT_OBJECT_TOKEN.replace("<i>", str(i)) for i in range(total_num_boxes)
        ]
        obj_tokens = (
            DEFAULT_OBJECT_FEATURE_TOKEN.join(obj_tokens) + DEFAULT_OBJECT_FEATURE_TOKEN
        )
        question = question.replace(DEFAULT_IMAGE_TOKEN, "")
        question = DEFAULT_IMAGE_TOKEN + "\n" + obj_tokens + "\n" + question
        inputs = ""
        inputs += self.template["INSTRUCTION"].format(input=question)
        # step 4: tokenize the prompt
        input_ids = tokenizer_image_object_token(inputs, self.tokenizer)
        data_dict["input_ids"] = torch.tensor(input_ids).unsqueeze(0)
        return data_dict

RexSeekProcessor.register_for_auto_class()
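
# Minimal usage sketch (the checkpoint path, image path, and question phrasing below
# are hypothetical; this file only defines the processor, so the snippet assumes a
# RexSeek checkpoint that ships a compatible image processor and tokenizer).
if __name__ == "__main__":
    from transformers import AutoImageProcessor

    checkpoint = "path/to/rexseek-checkpoint"  # hypothetical local path
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    processor = RexSeekProcessor(image_processor=image_processor, tokenizer=tokenizer)

    # Boxes are absolute pixel coordinates in [x_min, y_min, x_max, y_max] order.
    data = processor.process(
        image="demo.jpg",  # hypothetical image path
        bbox=[[48, 60, 200, 320]],
        question="Please describe <obj0>.",  # assumed phrasing; <obj0> refers to the first box
    )
    print({k: tuple(v.shape) for k, v in data.items()})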