import glob
import json
import math
import os
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import BitsAndBytesConfig, CLIPVisionModel, LlamaForCausalLM

from .llava.model.llava import LlavaLlamaForCausalLM
from .segment_anything import build_sam_vit_h, build_sam_vit_l

# LLaVA special tokens. They must be distinct, non-empty strings because
# LISA.__init__ registers them with the tokenizer via ``add_tokens``.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


def find_all_linear_names(model):
    """Return the sorted, de-duplicated leaf names of every ``nn.Linear`` in *model*.

    For a nested module only the last dotted component is kept
    (e.g. ``layers.0.self_attn.q_proj`` -> ``q_proj``). The names
    ``lm_head`` and ``mm_projector`` are excluded from the result.
    """
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # rsplit yields the whole name when it contains no dot.
            lora_module_names.add(name.rsplit(".", 1)[-1])
    lora_module_names.discard("lm_head")  # needed for 16-bit
    lora_module_names.discard("mm_projector")
    return sorted(lora_module_names)


class LISA(nn.Module):
    """LLaVA language model combined with a SAM mask decoder.

    Hidden states at positions whose token id equals ``seg_token_idx`` are
    projected by ``text_hidden_fcs`` and passed to SAM's prompt encoder as
    ``text_embeds`` (see ``evaluate``).
    """

    def __init__(
        self,
        local_rank,
        seg_token_idx,
        tokenizer,
        llm_version,
        lora_r,
        precision,
        load_in_4bit=False,
        load_in_8bit=False,
        lora_target_modules=("q_proj", "v_proj"),
        lora_alpha=16,
        lora_dropout=0.05,
        vision_tower="openai/clip-vit-large-patch14",
        mm_vision_select_layer=-2,
        freeze_lm=True,
        train_mask_decoder=True,
        out_dim=256,
    ):
        """Build the language model, vision tower, SAM model and projection head.

        Args:
            local_rank: CUDA device index used when the vision tower has to be
                reloaded and passed to ``initialize_vision_tokenizer``.
            seg_token_idx: token id marking positions whose hidden state is
                used as a SAM prompt.
            tokenizer: tokenizer; special tokens are added to it in place.
            llm_version: model name/path passed to ``from_pretrained``.
            precision: ``"bf16"``, ``"fp16"``; anything else means float32.
            load_in_4bit / load_in_8bit: quantized loading, honoured only for
                ``precision == "fp16"``; 4-bit takes precedence.
            lora_r, lora_target_modules, lora_alpha, lora_dropout: accepted
                but not used by this constructor.
            vision_tower, mm_vision_select_layer: forwarded to
                ``initialize_vision_modules``.
            freeze_lm: if true, all LM parameters are frozen first (the
                vocab-sized embedding / lm_head weights are re-enabled below).
            train_mask_decoder: unfreeze SAM's mask decoder.
            out_dim: output size of the text projection head.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.image_token = tokenizer.cls_token_id
        self.precision = precision
        dtype = self._precision_to_dtype(precision)

        # LLaVA
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        num_new_tokens = tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

        self.lm = self._load_lm(llm_version, precision, load_in_4bit, load_in_8bit)
        self.lm.enable_input_require_grads()
        self.lm.gradient_checkpointing_enable()
        self.lm.config.use_cache = False

        model_vision_dict = self.lm.get_model().initialize_vision_modules(
            vision_tower=vision_tower,
            mm_vision_select_layer=mm_vision_select_layer,
            precision=precision,
        )
        vision_config = model_vision_dict["vision_config"]
        vision_tower = self.lm.get_model().vision_tower[0]
        self.lm.model.config.eos_token_id = tokenizer.eos_token_id
        self.lm.model.config.bos_token_id = tokenizer.bos_token_id
        self.lm.model.config.pad_token_id = tokenizer.pad_token_id

        if vision_tower.device.type == "meta":
            # A meta-device module holds no weight data: reload the CLIP
            # encoder from its checkpoint and swap it into the model.
            vision_tower = CLIPVisionModel.from_pretrained(
                vision_tower.config._name_or_path,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            ).cuda(local_rank)
            self.lm.get_model().vision_tower[0] = vision_tower
        else:
            vision_tower.to(device="cuda", dtype=dtype)

        self.lm.config.tune_mm_mlp_adapter = False
        self.lm.config.freeze_mm_mlp_adapter = False
        self.lm.config.mm_use_im_start_end = True
        vision_config.use_im_start_end = True
        self.lm.config.sep_image_conv_front = False
        self.lm.initialize_vision_tokenizer(
            mm_use_im_start_end=True,
            tokenizer=tokenizer,
            num_new_tokens=num_new_tokens,
            device=local_rank,
            tune_mm_mlp_adapter=False,
        )

        if freeze_lm:
            for param in self.lm.parameters():
                param.requires_grad = False

        self.llm_version = llm_version
        self.seg_token_idx = seg_token_idx
        self.lm.resize_token_embeddings(len(tokenizer))
        # Re-enable gradients on the vocab-sized embed_tokens / lm_head
        # weights regardless of freeze_lm. The name test runs first so the
        # shape is only read for matching parameters.
        vocab_size = len(tokenizer)
        for name, param in self.lm.named_parameters():
            if any(key in name for key in ("lm_head", "embed_tokens")) and param.shape[0] == vocab_size:
                param.requires_grad = True

        # SAM: fully frozen except (optionally) the mask decoder.
        self.visual_model = build_sam_vit_h(None)
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if train_mask_decoder:
            self.visual_model.mask_decoder.train()
            for param in self.visual_model.mask_decoder.parameters():
                param.requires_grad = True

        # Projection from the LM hidden size to out_dim; evaluate() feeds its
        # output to SAM's prompt encoder as text_embeds.
        in_dim = self.lm.config.hidden_size
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])

    @staticmethod
    def _precision_to_dtype(precision):
        """Map a precision string to a torch dtype (unknown values -> float32)."""
        if precision == "bf16":
            return torch.bfloat16
        if precision == "fp16":
            return torch.half
        return torch.float32

    @classmethod
    def _load_lm(cls, llm_version, precision, load_in_4bit, load_in_8bit):
        """Load the LLaVA language model for the requested precision.

        Quantized loading is only used when ``precision == "fp16"``; 4-bit
        takes precedence over 8-bit, and quantized models use
        ``device_map="auto"``.
        """
        common = dict(cache_dir=None, low_cpu_mem_usage=True)
        if precision == "fp16" and load_in_4bit:
            # quantization_config alone carries load_in_4bit; newer
            # transformers rejects passing the kwarg alongside it.
            return LlavaLlamaForCausalLM.from_pretrained(
                llm_version,
                device_map="auto",
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                ),
                **common,
            )
        if precision == "fp16" and load_in_8bit:
            return LlavaLlamaForCausalLM.from_pretrained(
                llm_version,
                device_map="auto",
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                **common,
            )
        return LlavaLlamaForCausalLM.from_pretrained(
            llm_version, torch_dtype=cls._precision_to_dtype(precision), **common
        )

    def get_visual_embs(self, pixel_values: torch.FloatTensor):
        """Run SAM's image encoder on *pixel_values* and return its output."""
        image_embeddings = self.visual_model.image_encoder(pixel_values)
        return image_embeddings

    def evaluate(self, images_clip, images, input_ids, resize_list, original_size_list, max_new_tokens=32, tokenizer=None):
        """Generate text, then predict masks from the generated segmentation tokens.

        Everything runs under ``torch.no_grad()``.

        Args:
            images_clip: images passed to the LM's ``generate`` call.
            images: images passed to SAM's image encoder.
            input_ids: prompt token ids.
            resize_list, original_size_list: per-sample sizes forwarded to
                ``postprocess_masks``.
            max_new_tokens: generation length limit.
            tokenizer: unused; kept for API compatibility.

        Returns:
            ``(output_ids, pred_masks)`` where ``pred_masks[i]`` holds the masks
            predicted from the segmentation tokens of sample ``i``.
        """
        with torch.no_grad():
            outputs = self.lm.generate(
                images=images_clip,
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            output_hidden_states = outputs.hidden_states[-1]
            output_ids = outputs.sequences

            # NOTE(review): the [:, 1:] slice presumably aligns token positions
            # with the hidden-state positions; confirm against the LM output.
            seg_token_mask = output_ids[:, 1:] == self.seg_token_idx

            assert len(self.text_hidden_fcs) == 1
            last_hidden_state = self.text_hidden_fcs[0](output_hidden_states)
            pred_embeddings = last_hidden_state[seg_token_mask]

            # Split the flat selection back into one chunk per sample; split()
            # keeps the result on the input's device (no hard-coded .cuda()).
            seg_token_counts = seg_token_mask.int().sum(-1)  # [bs, ]
            pred_embeddings = list(pred_embeddings.split(seg_token_counts.tolist()))

            image_embeddings = self.get_visual_embs(images)

            multimask_output = False
            pred_masks = []
            for i, embeds in enumerate(pred_embeddings):
                sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                    text_embeds=embeds.unsqueeze(1),
                )
                sparse_embeddings = sparse_embeddings.to(embeds.dtype)
                low_res_masks, _ = self.visual_model.mask_decoder(
                    image_embeddings=image_embeddings[i].unsqueeze(0),
                    image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=multimask_output,
                )
                pred_mask = self.visual_model.postprocess_masks(
                    low_res_masks,
                    input_size=resize_list[i],
                    original_size=original_size_list[i],
                )
                pred_masks.append(pred_mask[:, 0])

        return output_ids, pred_masks