import math
import os
import re
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_align
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    Qwen2Config,
    Qwen2ForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.generation.utils import GenerateOutput
from transformers.utils import logging, strtobool

from .clip import CLIPVisionTower
from .convnext import ConvNextVisionEncoder

logger = logging.get_logger(__name__)
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN_INDEX = 0
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"

# For Objects
DEFAULT_OBJECT_TOKEN = "<obj<i>>"
DEFAULT_OBJECT_FEATURE_TOKEN = "<objfeat>"
DEFAULT_OBJECT_INDEX = -300

# For Grounding
DEFAULT_GROUNDING_START = "<ground>"
DEFAULT_GROUNDING_END = "</ground>"
DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
DEFAULT_GROUNDING_OBJECTS_END = "</objects>"
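# A hedged sketch of how these placeholders are expected to show up in `input_ids`
# (the exact tokenization pipeline lives outside this file): text tokens keep their
# ordinary non-negative ids, every <image> placeholder is replaced by
# IMAGE_TOKEN_INDEX (-200), and every <objfeat> placeholder by DEFAULT_OBJECT_INDEX
# (-300), e.g.
#   [..., t1, t2, -200, t3, t4, -300, t5, ...]
# `prepare_inputs_labels_for_multimodal` later splices image / object features into
# those negative-id positions.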
def is_fsdp_enabled():
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
        and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
    )


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    def config(self):
        return {"mm_projector_type": "identity"}


class SimpleResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        self.proj = nn.Sequential(
            nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)
def build_vision_projector(config, start_hidden_size, delay_load=False, **kwargs):
    projector_type = "mlp2x_gelu"
    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(start_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)
    if projector_type == "identity":
        return IdentityMap()
    raise ValueError(f"Unknown projector type: {projector_type}")
def get_token_slices(input_ids: torch.Tensor):
    """Get slices of tokens based on special markers in the input tensor.

    Args:
        input_ids (torch.Tensor): A 1D tensor of token IDs where IMAGE_TOKEN_INDEX marks an
            image token, DEFAULT_OBJECT_INDEX marks an object token, and all other values are
            text tokens.

    Returns:
        Tuple[List[Dict[str, Any]], bool]: A list of dictionaries, each containing the type of
            the token slice ('text', 'image', 'object') and its span as [start, end) indices,
            plus a flag indicating whether any object token is present.
    """
    # define type markers and corresponding types
    type_map = {IMAGE_TOKEN_INDEX: "image", DEFAULT_OBJECT_INDEX: "object"}
    # find the positions of special markers
    image_indices = torch.where(input_ids == IMAGE_TOKEN_INDEX)[0]
    object_indices = torch.where(input_ids == DEFAULT_OBJECT_INDEX)[0]
    has_object = len(object_indices) > 0
    # merge and sort all the positions of special markers
    special_indices = torch.cat((image_indices, object_indices))
    special_indices, _ = torch.sort(special_indices)
    special_tokens = input_ids[special_indices]
    slices = []
    start_idx = 0
    for i, idx in enumerate(special_indices):
        idx = idx.item()
        if start_idx < idx:
            slices.append({"type": "text", "span": [start_idx, idx]})
        token_type = type_map[special_tokens[i].item()]
        slices.append({"type": token_type, "span": [idx, idx + 1]})
        start_idx = idx + 1
    if start_idx < len(input_ids):
        slices.append({"type": "text", "span": [start_idx, len(input_ids)]})
    return slices, has_object
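# Worked example (ids are schematic): for
#   input_ids = torch.tensor([11, 12, -200, 13, -300, 14])
# get_token_slices returns
#   ([{'type': 'text',   'span': [0, 2]},
#     {'type': 'image',  'span': [2, 3]},
#     {'type': 'text',   'span': [3, 4]},
#     {'type': 'object', 'span': [4, 5]},
#     {'type': 'text',   'span': [5, 6]}], True)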
class StopWordStoppingCriteria(StoppingCriteria):
    """StopWord stopping criteria."""

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(self.stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        cur_text = self.tokenizer.decode(input_ids[0])
        cur_text = cur_text.replace("\r", "").replace("\n", "")
        return cur_text[-self.length :] == self.stop_word


def get_stop_criteria(tokenizer, stop_words=[]):
    stop_criteria = StoppingCriteriaList()
    for word in stop_words:
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria
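# Typical usage (a sketch; the stop word depends on the chat template of the tokenizer
# actually paired with this model):
#   stop_criteria = get_stop_criteria(tokenizer, stop_words=["<|im_end|>"])
#   model.generate(..., stopping_criteria=stop_criteria)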
def gen_sineembed_for_position(pos_tensor, dim_of_pos_feats):
    """Generate a sine position embedding from a position tensor.

    Args:
        pos_tensor (torch.Tensor): shape [batch_size, N, 2] or [batch_size, N, 4]; the last
            dimension is [cx, cy] or [cx, cy, w, h] in normalized coordinates in range [0, 1].
        dim_of_pos_feats (int): number of embedding dimensions per coordinate.

    Returns:
        pos (torch.Tensor): shape [batch_size, N, 2 * dim_of_pos_feats] for points or
            [batch_size, N, 4 * dim_of_pos_feats] for boxes.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(
        dim_of_pos_feats, dtype=torch.float32, device=pos_tensor.device
    )
    dim_t = 10000 ** (2 * (dim_t // 2) / dim_of_pos_feats)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack(
        (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    pos_y = torch.stack(
        (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
    ).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack(
            (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
        ).flatten(2)
        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack(
            (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
        ).flatten(2)
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError(f"Unknown pos_tensor shape(-1): {pos_tensor.size(-1)}")
    return pos
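# Shape check (illustrative): with pos_tensor of shape (1, 8, 4) holding normalized
# (cx, cy, w, h) boxes and dim_of_pos_feats = 720,
# gen_sineembed_for_position(pos_tensor, 720) returns a (1, 8, 2880) tensor, which is
# how the 2880-d positional term for object tokens is produced below.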
class MultiLevelROIVisualPrompt(nn.Module):
    """Multi-level RoI visual prompt encoder.

    Args:
        output_size (Optional[int]): Output size of the RoI Align operator. Default is None.
        channel_per_level (List[int]): Channels per feature level. Default is [192, 384, 768, 1536].
        spatial_scale (Optional[float]): Scale factor mapping box coordinates to the feature map.
            Default is None.
        add_pos_embedding (bool): Whether to add a sine position embedding of the box. Default is False.
        pos_embedding_dim (int): The dimension of the position embedding. Default is 1024.
    """

    def __init__(
        self,
        output_size: int = None,
        channel_per_level: List[int] = [192, 384, 768, 1536],
        spatial_scale: float = None,
        add_pos_embedding: bool = False,
        pos_embedding_dim: int = 1024,
    ):
        super().__init__()
        self.output_size = output_size
        self.channel_per_level = channel_per_level
        self.spatial_scale = spatial_scale
        self.add_pos_embedding = add_pos_embedding
        self.pos_embedding_dim = pos_embedding_dim
    def forward(
        self,
        multi_level_features: List[torch.Tensor],
        boxes: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Run the RoI Align operator on multi-level features.

        The feature maps from all levels are resized to a common resolution, concatenated
        along the channel dimension, and pooled with RoI Align; the pooled RoI features are
        then average-pooled over the spatial dimensions.

        Args:
            multi_level_features (List[Tensor[N, C, H, W]]): Feature maps from different levels.
            boxes (Tensor[K, 5] or List[Tensor[L, 4]]): Box coordinates in (x1, y1, x2, y2)
                format from which the regions will be taken.

        Returns:
            Tensor[1, K, C]: The pooled RoI features, where K is the number of RoIs.
        """
        boxes[0] = boxes[0].float()
        concat_multi_level_feature = []
        max_height = max([feature.shape[2] for feature in multi_level_features])
        max_width = max([feature.shape[3] for feature in multi_level_features])
        # interpolate every level to the same spatial size
        for level, feature in enumerate(multi_level_features):
            if level != 0:
                concat_multi_level_feature.append(
                    F.interpolate(
                        feature.float(),
                        size=(max_height, max_width),
                        mode="bilinear",
                        align_corners=False,
                    )
                )
            else:
                concat_multi_level_feature.append(feature.float())
        concat_multi_level_feature = torch.cat(concat_multi_level_feature, dim=1)
        out_box_feat = roi_align(
            concat_multi_level_feature,
            boxes,
            output_size=self.output_size,
            spatial_scale=self.spatial_scale,
        )
        # Average pooling over the RoI grid: (K, C, output_size, output_size) -> (1, K, C)
        out_box_feat = out_box_feat.mean(dim=(2, 3)).reshape(
            1, out_box_feat.shape[0], out_box_feat.shape[1]
        )
        if self.add_pos_embedding:
            # these boxes are in unnormalized xyxy format, so normalize them first;
            # clone so the caller's tensor is not modified in place
            boxes = boxes[0].to(out_box_feat.dtype).clone()  # (N, 4)
            original_img_width = max_width / self.spatial_scale
            original_img_height = max_height / self.spatial_scale
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / original_img_width
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / original_img_height
            # convert from (x1, y1, x2, y2) to (cx, cy, w, h)
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
            boxes[:, 0] = boxes[:, 0] + boxes[:, 2] / 2
            boxes[:, 1] = boxes[:, 1] + boxes[:, 3] / 2
            pos_embed = gen_sineembed_for_position(
                boxes.unsqueeze(0), self.pos_embedding_dim // 4
            )
            out_box_feat = out_box_feat + pos_embed
        return out_box_feat
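# A minimal usage sketch (shapes follow the ConvNeXt-Large settings used below;
# `feats` and `boxes_xyxy` are hypothetical stand-ins for real encoder outputs):
#   box_encoder = MultiLevelROIVisualPrompt(
#       output_size=7, spatial_scale=192 / 768,
#       add_pos_embedding=True, pos_embedding_dim=2880,
#   )
#   feats = [torch.randn(1, c, 192 // 2**i, 192 // 2**i)
#            for i, c in enumerate([192, 384, 768, 1536])]
#   boxes_xyxy = torch.tensor([[10.0, 20.0, 200.0, 300.0]])
#   out = box_encoder(feats, [boxes_xyxy])  # -> (1, 1, 2880)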
class RexSeekQwenConfig(Qwen2Config):
    model_type = "rexseek_qwen"


class RexSeekQwenForCausalLM(Qwen2ForCausalLM):
    config_class = RexSeekQwenConfig

    def __init__(self, config):
        super().__init__(config)
        # low-resolution vision encoder
        vision_tower = getattr(
            config,
            "mm_vision_tower",
            getattr(config, "vision_tower", None),
        )
        self.vision_tower = CLIPVisionTower(
            vision_tower,
            args=config,
        )
        # high-resolution vision encoder
        self.vision_tower_aux = ConvNextVisionEncoder()
        # projector for the image tokens produced by the vision towers
        self.mm_projector = build_vision_projector(config, start_hidden_size=2560)
        # projector for object tokens
        self.mm_object_projector = build_vision_projector(config, start_hidden_size=2880)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # visual prompt encoder
        self.box_encoder = MultiLevelROIVisualPrompt(
            output_size=7,
            channel_per_level=[192, 384, 768, 1536],  # ConvNeXt-Large
            spatial_scale=192 / 768,
            add_pos_embedding=True,
            pos_embedding_dim=2880,
        )
        # Initialize weights and apply final processing
        self.post_init()
        logger.info("RexSeekQwenForCausalLM initialized")
    def get_vision_tower(self):
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def get_vision_tower_aux(self):
        vision_tower_aux = getattr(self, "vision_tower_aux", None)
        if type(vision_tower_aux) is list:
            vision_tower_aux = vision_tower_aux[0]
        return vision_tower_aux

    def get_model(self):
        return self.model
    def encode_images(self, images, images_aux):
        low_res_feat = self.get_vision_tower()(images)
        aux_output = self.get_vision_tower_aux()(images_aux)
        visual_outputs_aux = aux_output["image_features"]
        high_res_feat = aux_output["last_feat"]  # (B, 1536, 24, 24)
        # concat the low-res features with the high-res features
        b, c, h, w = high_res_feat.shape  # e.g. (2, 1536, 24, 24)
        _, _, d = low_res_feat.shape  # e.g. (2, 576, 1024)
        high_res_feat = high_res_feat.view(b, c, h * w).transpose(1, 2)
        image_features = torch.cat((low_res_feat, high_res_feat), dim=-1)
        image_features = self.mm_projector(image_features)
        return image_features, visual_outputs_aux
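    # Shape flow (illustrative, following the comments above): the 576 CLIP patch tokens
    # (1024-d) are concatenated channel-wise with the flattened 24x24 ConvNeXt grid
    # (1536-d) to give (B, 576, 2560), which is what the `start_hidden_size=2560` of
    # `mm_projector` expects before projecting to the language model's hidden size.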
    def encode_objects(
        self, bboxes, visual_outputs_aux, dtype, num_gt_boxes_per_image=None
    ):
        """Encode object features from bounding boxes.

        Args:
            bboxes (List[torch.Tensor]): per-image bounding boxes, each of shape (N, 4).
            visual_outputs_aux (List[torch.Tensor]): multi-level feature maps from the
                auxiliary (high-resolution) vision encoder.
            dtype (torch.dtype): dtype to cast the object features to.
            num_gt_boxes_per_image (Optional[List[int]]): number of valid boxes per image.

        Returns:
            List[Optional[torch.Tensor]]: per-image object features of shape (N, hidden_size),
                or None for images without boxes.
        """
        bbox_visual_outputs = []
        for batch_idx, boxes in enumerate(bboxes):
            num_box = (
                num_gt_boxes_per_image[batch_idx]
                if num_gt_boxes_per_image is not None
                else len(boxes)
            )
            boxes = boxes[:num_box]
            if len(boxes) == 0:
                bbox_visual_outputs.append(None)
                continue
            multi_level_aux_features = [
                visual_output_aux[batch_idx].unsqueeze(0)
                for visual_output_aux in visual_outputs_aux
            ]
            out_vp_feat = self.box_encoder(
                multi_level_aux_features,
                [boxes],
            ).squeeze(0)
            out_vp_feat = out_vp_feat.to(dtype)
            out_vp_feat = self.mm_object_projector(out_vp_feat)
            bbox_visual_outputs.append(out_vp_feat)
        return bbox_visual_outputs
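    # With the encoders above, each pooled box feature is 2880-d
    # (192 + 384 + 768 + 1536 ConvNeXt channels), which is why `mm_object_projector` is
    # built with `start_hidden_size=2880`; after the projector every object token has the
    # same width as a text embedding.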
    def prepare_inputs_labels_for_multimodal(
        self,
        input_ids,
        position_ids,
        attention_mask,
        past_key_values,
        labels,
        pixel_values=None,
        pixel_values_aux=None,
        gt_boxes=None,
        num_gt_boxes_per_image=None,
    ):
        if pixel_values is None:
            return (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                None,
                labels,
            )
        pixel_values, visual_outputs_aux = self.encode_images(
            pixel_values, pixel_values_aux
        )  # (B, 576, hidden_size)
        if gt_boxes is not None:
            bbox_feats = self.encode_objects(
                gt_boxes, visual_outputs_aux, pixel_values.dtype, num_gt_boxes_per_image
            )
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()  # padding mask of shape (B, L)
        if position_ids is None:
            position_ids = torch.arange(
                0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
            )
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)
        # drop padded positions before splicing in the multimodal features
        input_ids = [
            cur_input_ids[cur_attention_mask]
            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
        labels = [
            cur_labels[cur_attention_mask]
            for cur_labels, cur_attention_mask in zip(labels, attention_mask)
        ]
        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        cur_object_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                # no image token: append a zero-length slice of the image features so the
                # vision towers still participate in the computation graph
                cur_image_features = pixel_values[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat(
                    [cur_input_embeds_1, cur_image_features[0:0]], dim=0
                )
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                cur_object_idx += 1
                continue
            cur_labels = labels[batch_idx]
            token_slices, has_object = get_token_slices(cur_input_ids)
            result_input_embeddings = []
            result_output_labels = []
            cur_gt_box_idx = 0
            for token_slice in token_slices:
                slice_type = token_slice["type"]
                slice_span = token_slice["span"]
                if slice_type == "text":
                    cur_input_ids_noim = cur_input_ids[slice_span[0] : slice_span[1]]
                    cur_labels_noim = cur_labels[slice_span[0] : slice_span[1]]
                    cur_input_embeds = self.get_model().embed_tokens(cur_input_ids_noim)
                    result_input_embeddings.append(cur_input_embeds)
                    result_output_labels.append(cur_labels_noim)
                elif slice_type == "image":
                    cur_input_embeds = pixel_values[cur_image_idx]
                    result_input_embeddings.append(cur_input_embeds)
                    result_output_labels.append(
                        torch.full(
                            (cur_input_embeds.shape[0],),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )
                    cur_image_idx += 1
                elif slice_type == "object":
                    try:
                        result_input_embeddings.append(
                            bbox_feats[cur_object_idx][cur_gt_box_idx].unsqueeze(0)
                        )
                    except (IndexError, TypeError) as exc:
                        raise ValueError(
                            f"mismatch between object tokens and box features: tried to "
                            f"take box feature {cur_gt_box_idx} from "
                            f"bbox_feats[{cur_object_idx}]"
                        ) from exc
                    cur_gt_box_idx += 1
                    result_output_labels.append(
                        torch.full(
                            (1,),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )
            cur_object_idx += 1
            result_input_embeddings = torch.cat(result_input_embeddings)
            result_output_labels = torch.cat(result_output_labels)
            assert len(result_output_labels) == len(result_input_embeddings)
            new_input_embeds.append(result_input_embeddings)
            new_labels.append(result_output_labels)
        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(
            self.config, "tokenizer_model_max_length", None
        )
        if tokenizer_model_max_length is not None:
            new_input_embeds = [
                x[:tokenizer_model_max_length] for x in new_input_embeds
            ]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)
        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len),
            IGNORE_INDEX,
            dtype=new_labels[0].dtype,
            device=new_labels[0].device,
        )
        attention_mask = torch.zeros(
            (batch_size, max_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        position_ids = torch.zeros(
            (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
        )
        for i, (cur_new_embed, cur_new_labels) in enumerate(
            zip(new_input_embeds, new_labels)
        ):
            cur_len = cur_new_embed.shape[0]
            new_input_embeds_padded.append(
                torch.cat(
                    (
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                    ),
                    dim=0,
                )
            )
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                )
        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded
        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
        if _position_ids is None:
            position_ids = None
        return (
            None,
            position_ids,
            attention_mask,
            past_key_values,
            new_input_embeds,
            new_labels,
        )
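    # The tuple above mirrors the (input_ids, position_ids, attention_mask,
    # past_key_values, inputs_embeds, labels) arguments of Qwen2ForCausalLM.forward:
    # input_ids comes back as None because the multimodal sequence only exists as
    # `inputs_embeds` once image and object features have been spliced in.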
    def generate(
        self,
        inputs: Optional[torch.Tensor],
        pixel_values: Optional[torch.Tensor],
        pixel_values_aux: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        if inputs_embeds is None:
            gt_boxes = kwargs.pop("gt_boxes", None)
            num_gt_boxes_per_image = kwargs.pop("num_gt_boxes_per_image", None)
            if pixel_values is not None:
                (inputs, position_ids, attention_mask, _, inputs_embeds, _) = (
                    self.prepare_inputs_labels_for_multimodal(
                        inputs,
                        position_ids,
                        attention_mask,
                        past_key_values=None,
                        labels=None,
                        pixel_values=pixel_values,
                        pixel_values_aux=pixel_values_aux,
                        gt_boxes=gt_boxes,
                        num_gt_boxes_per_image=num_gt_boxes_per_image,
                    )
                )
            else:
                inputs_embeds = self.get_model().embed_tokens(inputs)
        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
AutoConfig.register("rexseek_qwen", RexSeekQwenConfig)
AutoModelForCausalLM.register(RexSeekQwenConfig, RexSeekQwenForCausalLM)
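# A hedged loading sketch (the repo id and dtype are illustrative, not fixed by this
# file; with the registrations above the model resolves through the Auto classes when
# this module is importable, or via trust_remote_code when shipped with a checkpoint):
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("IDEA-Research/RexSeek-3B")
#   model = AutoModelForCausalLM.from_pretrained(
#       "IDEA-Research/RexSeek-3B", torch_dtype=torch.float16, trust_remote_code=True
#   )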