|
from typing import List, Optional, Tuple, Union |
|
|
|
import os |
|
import torch |
|
import numpy as np |
|
import torch.nn as nn |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import torch.nn.functional as F |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
from model.IXC.modeling_internlm_xcomposer2 import InternLMXComposer2ForCausalLM |
|
from model.IXC.modeling_internlm2 import InternLM2Model |
|
from model.sam2.build_sam import build_sam2_hf |
|
from model.sam2.utils.transforms import SAM2Transforms |
|
from transformers import TextStreamer |
|
try:
    from transformers.generation.streamers import BaseStreamer
except ImportError:
    # Optional dependency: presumably absent in transformers versions that do
    # not ship the streamers module. Only a missing import is tolerated here;
    # any other error during import should surface.
    BaseStreamer = None
|
|
|
|
|
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,
    eps=1e-6,
):
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: A float tensor of shape (N, ...) holding raw logits, one
            entry per mask along dim 0. All remaining dims are flattened, so
            (N, H, W), (N, 1, H, W) and (N, D) inputs are all supported.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: Normaliser applied to the summed per-mask losses.
        scale: Both the intersection and the area terms are divided by this
            before summing (presumably to keep sums small in reduced
            precision -- TODO confirm); it cancels out in the ratio except
            relative to ``eps``.
        eps: Smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor: sum of the N per-mask dice losses divided by
        ``num_masks``.
    """
    inputs = inputs.sigmoid()
    # Flatten every non-batch dim so each per-mask sum covers the whole mask.
    # For the (N, H, W) case this is identical to flatten(1, 2).
    inputs = inputs.flatten(1)
    targets = targets.flatten(1)
    numerator = 2 * (inputs / scale * targets).sum(-1)
    denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
    loss = 1 - (numerator + eps) / (denominator + eps)
    loss = loss.sum() / (num_masks + 1e-8)
    return loss
|
|
|
|
|
def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Per-mask mean binary cross-entropy on logits, summed and normalised.

    Args:
        inputs: A float tensor of shape (N, ...) holding raw logits, one
            entry per mask along dim 0. All remaining dims are flattened, so
            (N, H, W), (N, 1, H, W) and (N, D) inputs are all supported.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: Normaliser applied to the summed per-mask losses.

    Returns:
        Loss tensor (scalar): sum over masks of the element-wise BCE averaged
        over each mask, divided by ``num_masks``.
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # Average over every non-batch element of each mask (identical to
    # flatten(1, 2).mean(1) for (N, H, W) inputs).
    loss = loss.flatten(1).mean(1).sum() / (num_masks + 1e-8)
    return loss
|
|
|
|
|
class GeoPixelMetaModel:
    """Mixin that adds the SAM2 visual model and the text projection head.

    It calls ``super().__init__(config)`` first, so it is meant to precede a
    model class that accepts ``config`` in the MRO (as ``GeoPixelModel`` does
    with ``InternLM2Model``).

    Recognised kwargs: ``train_mask_decoder`` and ``out_dim`` (used only when
    the config lacks them) and ``vision_pretrained`` (passed to
    ``build_sam2_hf``).
    """

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.config = config
        # An attribute already present on the config takes precedence; the
        # kwarg (or the hard-coded fallback) only fills in missing values.
        for attr_name, fallback in (("train_mask_decoder", False), ("out_dim", 256)):
            setattr(
                self.config,
                attr_name,
                getattr(self.config, attr_name, kwargs.get(attr_name, fallback)),
            )
        self.vision_pretrained = kwargs.get("vision_pretrained", None)
        self.initialize_geopixel_modules(self.config)

    def initialize_geopixel_modules(self, config):
        """Build SAM2, its input/output transforms and the text projection head.

        All SAM2 parameters are frozen; the mask decoder is unfrozen (and put
        in train mode) only when ``config.train_mask_decoder`` is truthy. The
        projection head maps ``config.hidden_size`` to ``config.out_dim`` and
        is always trainable.
        """
        sam = build_sam2_hf(self.vision_pretrained)
        self.visual_model = sam

        self._transform = SAM2Transforms(
            resolution=sam.image_size,
            mask_threshold=0.0,
            max_hole_area=0.0,
            max_sprinkle_area=0.0,
        )

        # Spatial sizes used to reshape the backbone feature levels, largest first.
        self._bb_feat_sizes = [(side, side) for side in (256, 128, 64)]

        sam.requires_grad_(False)
        if config.train_mask_decoder:
            mask_decoder = sam.sam_mask_decoder
            mask_decoder.train()
            mask_decoder.requires_grad_(True)

        hidden_dim = config.hidden_size
        projection = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, config.out_dim),
            nn.Dropout(0.0),
        )
        self.text_hidden_fcs = nn.ModuleList([projection])
        self.text_hidden_fcs.train()
        self.text_hidden_fcs.requires_grad_(True)
|
|
|
|
|
class GeoPixelModel(GeoPixelMetaModel, InternLM2Model):
    """InternLM2 backbone combined with the GeoPixel SAM2 modules.

    ``GeoPixelMetaModel`` is first in the MRO, so its ``__init__`` receives
    ``config`` and ``kwargs`` and forwards ``config`` to ``InternLM2Model``.
    """

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # Turn off KV caching. GeoPixelMetaModel stores the passed-in config as
        # self.config, so this also mutates the caller's config object.
        self.config.use_cache = False
|
|
|
|
|
class GeoPixelForCausalLM(InternLMXComposer2ForCausalLM):
    """InternLM-XComposer2 causal LM with a SAM2-based mask-decoding head.

    Hidden states at segmentation-token positions (``seg_token_idx``) are
    projected by ``model.text_hidden_fcs`` and fed to the SAM2 prompt encoder
    as text prompts; the SAM2 mask decoder then yields one mask per token.
    """

    def __init__(self, config, **kwargs):
        # GeoPixel-only options are popped before kwargs go to GeoPixelModel.
        self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
        self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
        self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
        self.seg_token_idx = kwargs.pop("seg_token_idx")

        super().__init__(config)
        self.model = GeoPixelModel(config, **kwargs)
        self.vocab_size = config.vocab_size
        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def encode_g_img(self, image):
        """
        Calculates the SAM2 image embeddings for the provided image.

        Arguments:
            image (str | np.ndarray | torch.Tensor): path to an image file with
                a supported extension, or an array whose first two dims are
                (height, width). Tensors are still accepted for backward
                compatibility; whether ``SAM2Transforms`` handles them is not
                visible here -- TODO confirm.

        Returns:
            ``(features, orig_hw)``: ``features`` is the dict returned by
            ``get_visual_embs`` and ``orig_hw`` is a one-element list holding
            the original ``(height, width)``. Returns ``None`` when ``image`` is
            None, or (after printing a message) for an unsupported extension.
        """
        if image is None:
            return None

        if isinstance(image, str):
            _, ext = os.path.splitext(image)
            if ext.lower() not in {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.tif'}:
                print('Unknow input format', image)
                return None
            # Load eagerly so the file handle is closed, and force 3 channels
            # so RGBA / palette / grayscale files pass the 1x3xHxW check below.
            with Image.open(image) as pil_image:
                image = pil_image.convert("RGB")
            w, h = image.size
            orig_hw = [(h, w)]
        else:
            assert isinstance(image, (np.ndarray, torch.Tensor)), (
                f"image must be a path, np.ndarray or torch.Tensor, got {type(image)}"
            )
            orig_hw = [image.shape[:2]]

        image = self.model._transform(image)
        image = image[None, ...].to(self.device)
        assert len(image.shape) == 4 and image.shape[1] == 3, (
            f"image must be of size 1x3xHxW, got {image.shape}"
        )
        features = self.get_visual_embs(image)
        return features, orig_hw

    def get_visual_embs(self, img_batch: torch.FloatTensor):
        """Run the SAM2 image encoder (without gradients) on a Bx3xHxW batch.

        Returns:
            dict with ``"image_embed"`` -- the level reshaped to the last entry
            of ``model._bb_feat_sizes`` -- and ``"high_res_feats"`` -- the
            remaining levels, in ``_bb_feat_sizes`` order.
        """
        with torch.no_grad():
            torch.cuda.empty_cache()
            img_batch = img_batch.to(self.device)
            batch_size = img_batch.shape[0]
            assert (
                len(img_batch.shape) == 4 and img_batch.shape[1] == 3
            ), f"grounding_img_batch must be of size Bx3xHxW, got {img_batch.shape}"
            visual_model = self.model.visual_model
            backbone_out = visual_model.forward_image(img_batch)
            _, vision_feats, _, _ = visual_model._prepare_backbone_features(backbone_out)
            if visual_model.directly_add_no_mem_embed:
                vision_feats[-1] = vision_feats[-1] + visual_model.no_mem_embed
            # Pair levels with sizes from the smallest end, then restore order.
            feats = [
                feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
                for feat, feat_size in zip(vision_feats[::-1], self.model._bb_feat_sizes[::-1])
            ][::-1]
            features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        return features

    def forward(self, **kwargs):
        # Incremental decoding (past_key_values present) bypasses mask decoding.
        return super().forward(**kwargs) if "past_key_values" in kwargs else self.model_forward(**kwargs)

    def _seg_token_embeddings(self, last_layer_states, seg_token_mask):
        """Project hidden states and gather them at segmentation-token positions.

        Args:
            last_layer_states: hidden states fed to ``text_hidden_fcs[0]``,
                presumably (batch, seq_len, hidden_size) -- TODO confirm.
            seg_token_mask: boolean mask aligned with ``last_layer_states``
                along the first two dims.

        Returns:
            List with one ``(num_seg_tokens, out_dim)`` tensor per batch entry.
        """
        assert len(self.model.text_hidden_fcs) == 1
        projected = self.model.text_hidden_fcs[0](last_layer_states)
        return [states[masks] for states, masks in zip(projected, seg_token_mask)]

    def _decode_masks(self, pred_embeddings, image_g_features, ori_hw):
        """Decode one SAM2 mask per segmentation-token embedding.

        Args:
            pred_embeddings: per-image list from ``_seg_token_embeddings``.
            image_g_features: dict from ``get_visual_embs``; entry ``i`` of each
                level belongs to ``pred_embeddings[i]``.
            ori_hw: per-image original ``(height, width)`` for postprocessing.

        Returns:
            List with, per image, the mask logits ``pred_masks[:, 0]`` for its
            tokens, or an empty list if the image has no segmentation token.
        """
        visual_model = self.model.visual_model
        prompt_encoder = visual_model.sam_prompt_encoder
        all_pred_masks = []
        for i, embeds in enumerate(pred_embeddings):
            if embeds.numel() == 0:
                all_pred_masks.append([])
                continue
            sparse_embeddings, dense_embeddings = prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
                text_embeds=embeds.unsqueeze(1),
            )
            sparse_embeddings = sparse_embeddings.to(embeds.dtype)
            high_res_features = [
                feat_level[i].unsqueeze(0)
                for feat_level in image_g_features["high_res_feats"]
            ]
            image_g_embeds = image_g_features['image_embed'][i].unsqueeze(0).to(torch.bfloat16)
            low_res_masks, _, _, _ = visual_model.sam_mask_decoder(
                image_embeddings=image_g_embeds,
                image_pe=prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                # Several prompts share one image embedding.
                repeat_image=embeds.shape[0] > 1,
                multimask_output=False,
                high_res_features=high_res_features,
            )
            pred_masks = self.model._transform.postprocess_masks(low_res_masks, ori_hw[i])
            all_pred_masks.append(pred_masks[:, 0])
        return all_pred_masks

    def model_forward(
        self,
        inference: bool = False,
        **kwargs,
    ):
        """Forward pass that adds mask decoding and mask losses for grounding data.

        Non-grounding batches (no ``samples``, or ``samples['data_type'][0]``
        other than ``'grounding'``) go straight to the parent ``forward``.
        For grounding batches this reads ``samples['image_g']``,
        ``samples['ori_hw']`` and ``samples['masks']`` (each indexed ``[0]``).

        Returns:
            ``{"pred_masks", "gt_masks"}`` when ``inference`` is True; otherwise
            a ``CausalLMOutputWithPast`` whose loss is the weighted LM loss plus
            the weighted BCE and dice mask losses (also attached as
            ``ce_loss``, ``mask_bce_loss``, ``mask_dice_loss``, ``mask_loss``).

        Raises:
            ValueError: a training sample has ground-truth masks but no
                segmentation-token embeddings.
        """
        samples = kwargs.get('samples', None)
        if not (samples and samples['data_type'][0] == 'grounding'):
            return super().forward(**kwargs)

        # Mask decoding needs the per-layer hidden states.
        kwargs['output_hidden_states'] = True
        kwargs['use_cache'] = False
        torch.cuda.empty_cache()
        model_output = super().forward(**kwargs)

        if inference:
            assert len(samples['text_input']) == 1 and len(samples['image'][0]) == 1

        # Both training and inference take the last layer and the
        # segmentation-token mask (presumably set by the parent forward) from
        # the same LM output.
        output_hidden_states = model_output.hidden_states
        pred_embeddings = self._seg_token_embeddings(
            output_hidden_states[-1], model_output.seg_token_mask
        )

        image_g_batch = torch.cat(samples['image_g'][0], dim=0)
        image_g_features = self.get_visual_embs(image_g_batch)
        pred_masks = self._decode_masks(pred_embeddings, image_g_features, samples['ori_hw'][0])
        gt_masks = samples['masks'][0]

        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
            }

        ce_loss = model_output.loss * self.ce_loss_weight
        mask_bce_loss = 0
        mask_dice_loss = 0
        num_masks = 0

        for batch_idx, cur_pred_masks in enumerate(pred_masks):
            cur_gt_list = gt_masks[batch_idx]
            if not torch.is_tensor(cur_pred_masks):
                # No segmentation token for this sample, so nothing was decoded.
                if len(cur_gt_list) == 0:
                    continue
                raise ValueError(
                    f"sample {batch_idx} has {len(cur_gt_list)} ground-truth masks "
                    "but no segmentation-token embeddings"
                )
            cur_gt_masks = torch.stack(
                [
                    torch.from_numpy(gt_mask).to(dtype=cur_pred_masks.dtype, device=cur_pred_masks.device)
                    for gt_mask in cur_gt_list
                ],
                dim=0,
            )
            assert (
                cur_gt_masks.shape[0] == cur_pred_masks.shape[0]
            ), "gt_masks.shape: {}, pred_masks.shape: {}".format(
                cur_gt_masks.shape, cur_pred_masks.shape
            )
            n_cur = cur_gt_masks.shape[0]
            # Each helper averages over this sample's masks; multiplying back by
            # the count makes the final division a mean over all batch masks.
            mask_bce_loss += sigmoid_ce_loss(cur_pred_masks, cur_gt_masks, num_masks=n_cur) * n_cur
            mask_dice_loss += dice_loss(cur_pred_masks, cur_gt_masks, num_masks=n_cur) * n_cur
            num_masks += n_cur

        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss

        outputs = CausalLMOutputWithPast(
            loss=ce_loss + mask_loss,
            logits=model_output.logits,
            past_key_values=model_output.past_key_values,
            hidden_states=output_hidden_states,
            attentions=model_output.attentions,
        )
        outputs.ce_loss = ce_loss
        outputs.mask_bce_loss = mask_bce_loss
        outputs.mask_dice_loss = mask_dice_loss
        outputs.mask_loss = mask_loss
        return outputs

    def evaluate(
        self,
        tokenizer,
        query: str,
        images: Optional[List] = None,
        hd_num: int = 9,
        history: Optional[List[Tuple[str, str]]] = None,
        max_new_tokens: int = 1024,
        stream: bool = False,
        **kwargs,
    ):
        """Greedy-generate a response and decode masks for its segmentation tokens.

        Args:
            tokenizer: tokenizer used for EOS, streaming and decoding.
            query: user prompt.
            images: images passed to ``interleav_wrap_chat``; masks are decoded
                only when this is a single file path. Defaults to no images.
            hd_num: forwarded to ``interleav_wrap_chat``.
            history: prior ``(query, response)`` turns; defaults to none.
            max_new_tokens: generation budget.
            stream: print tokens as they are generated via ``TextStreamer``.
            **kwargs: forwarded to ``generate``.

        Returns:
            ``(response, all_pred_masks)`` where ``all_pred_masks`` has one
            entry per batch row (empty if no masks were decoded).
        """
        images = [] if images is None else images
        history = [] if history is None else history

        with torch.no_grad():
            inputs, im_mask, _ = self.interleav_wrap_chat(query, images, history=history, hd_num=hd_num)
            inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
            eos_token_id = [tokenizer.eos_token_id]
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) if stream else None

            outputs = self.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                im_mask=im_mask,
                input_ids=None,
                streamer=streamer,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
                top_p=1.0,
                top_k=0,
                eos_token_id=eos_token_id,
                repetition_penalty=1.0,
                infer_mode='base',
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            output_ids = outputs['sequences']
            response = tokenizer.decode(output_ids[0].cpu().tolist(), skip_special_tokens=True)
            response = response.replace("[UNUSED_TOKEN_145]", "")

            all_pred_masks = []
            if len(images) == 1 and isinstance(images[0], str):
                # NOTE(review): the [1:-1] slice presumably aligns generated ids
                # with the returned hidden states -- TODO confirm.
                seg_token_mask = output_ids[:, 1:-1] == self.seg_token_idx
                inputs_embeds_len = inputs['inputs_embeds'].size(1)
                # Prompt positions never hold a generated segmentation token.
                prompt_pad = torch.zeros(
                    (seg_token_mask.shape[0], inputs_embeds_len),
                    dtype=torch.bool,
                    device=seg_token_mask.device,
                )
                seg_token_mask = torch.cat([prompt_pad, seg_token_mask], dim=1)
                pred_embeddings = self._seg_token_embeddings(outputs.hidden_states[-1], seg_token_mask)
                image_g_features, ori_hw = self.encode_g_img(images[0])
                all_pred_masks = self._decode_masks(pred_embeddings, image_g_features, ori_hw)

        return response, all_pred_masks