Spaces:
Paused
Paused
File size: 7,276 Bytes
5885496 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
from typing import Callable, List, Optional, Tuple, Union
import json
import glob
import math
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import LlamaForCausalLM, CLIPVisionModel
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
from .llava.model.llava import LlavaLlamaForCausalLM
from .segment_anything import build_sam_vit_l, build_sam_vit_h
# LLaVA image-token strings.  The patch and start/end tokens are added to the
# tokenizer as special tokens by LISA.__init__ in this file.
DEFAULT_IMAGE_TOKEN = "<image>"  # generic image placeholder; presumably expanded elsewhere
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"  # stands in for visual patch embeddings (presumably one per patch)
DEFAULT_IM_START_TOKEN = "<im_start>"  # marks the start of an image span
DEFAULT_IM_END_TOKEN = "<im_end>"  # marks the end of an image span
def find_all_linear_names(model, *, exclude=('lm_head', 'mm_projector')):
    """Return the sorted, de-duplicated leaf names of every ``nn.Linear`` in *model*.

    Only the last dotted component of each submodule's qualified name is kept
    (``"model.layers.0.self_attn.q_proj"`` -> ``"q_proj"``). That short form is
    what ``peft.LoraConfig(target_modules=...)`` accepts; PEFT matches module
    names by suffix.

    Args:
        model: Any ``torch.nn.Module``.
        exclude: Iterable of leaf names to drop from the result. Defaults to
            ``lm_head`` ("needed for 16-bit", per the original note) and
            ``mm_projector`` (presumably the LLaVA vision-to-text projection).

    Returns:
        A sorted list of unique leaf names. The root module (which
        ``named_modules`` yields with an empty name) is never included, even
        when *model* is itself an ``nn.Linear``.
    """
    linear_names = {
        qualified_name.rsplit('.', 1)[-1]
        for qualified_name, module in model.named_modules()
        # Skip the root module's empty name: "" is not a usable LoRA target.
        if qualified_name and isinstance(module, torch.nn.Linear)
    }
    return sorted(linear_names.difference(exclude))
class LISA(nn.Module):
    """LLaVA language model paired with a SAM mask decoder for text-prompted segmentation.

    ``self.lm`` (a LlavaLlamaForCausalLM with a CLIP vision tower) generates
    text. In :meth:`evaluate`, some hidden states are projected by
    ``self.text_hidden_fcs``: those at positions selected by occurrences of
    ``seg_token_idx`` in the generated ids. The projections are fed as text
    prompts to ``self.visual_model`` (SAM ViT-H), which yields masks for each
    batch element.
    """

    def __init__(self,
                 local_rank,
                 seg_token_idx,
                 tokenizer,
                 llm_version,
                 lora_r,
                 precision,
                 lora_target_modules=('q_proj', 'v_proj'),
                 lora_alpha=16,
                 lora_dropout=0.05,
                 vision_tower='openai/clip-vit-large-patch14',
                 mm_vision_select_layer=-2,
                 freeze_lm=True,
                 train_mask_decoder=True,
                 out_dim=256,
                 ):
        """Build the LLaVA + SAM model.

        Args:
            local_rank: CUDA device (index) for the vision tower and for
                ``initialize_vision_tokenizer``.
            seg_token_idx: Token id whose occurrences in generated ids select
                the hidden states used as SAM prompts.
            tokenizer: HF tokenizer; LLaVA image tokens are added to it in place.
            llm_version: Model name or path passed to ``from_pretrained``.
            lora_r, lora_target_modules, lora_alpha, lora_dropout: Accepted for
                interface compatibility; not used by this constructor.
            precision: ``"bf16"`` loads weights in bfloat16; any other value
                loads float32.
            vision_tower: CLIP vision tower name passed to
                ``initialize_vision_modules``.
            mm_vision_select_layer: CLIP layer index forwarded to
                ``initialize_vision_modules``.
            freeze_lm: If true, freeze every LM parameter. Regardless of this
                flag, ``lm_head``/``embed_tokens`` weights whose first dimension
                equals the tokenizer size are made trainable afterwards.
            train_mask_decoder: If true, put SAM's mask decoder in train mode and
                make it trainable; the rest of SAM stays frozen.
            out_dim: Output width of the text projection (presumably SAM's prompt
                embedding width — confirm against the prompt encoder).
        """
        super().__init__()
        self.tokenizer = tokenizer
        # NOTE(review): the tokenizer's CLS id is stored as the "image token";
        # it is not read anywhere in this class — confirm what callers expect.
        self.image_token = tokenizer.cls_token_id
        self.precision = precision
        # Any precision other than "bf16" falls back to full float32.
        torch_dtype = torch.bfloat16 if precision == "bf16" else torch.float32

        # Order matters: the projection reads the LM's hidden size, and keeping
        # LM -> SAM -> projection preserves module registration/initialisation order.
        self._init_language_model(
            local_rank=local_rank,
            tokenizer=tokenizer,
            llm_version=llm_version,
            precision=precision,
            torch_dtype=torch_dtype,
            vision_tower_name=vision_tower,
            mm_vision_select_layer=mm_vision_select_layer,
            freeze_lm=freeze_lm,
        )
        self.llm_version = llm_version
        self.seg_token_idx = seg_token_idx
        self._init_segmentation_model(train_mask_decoder)
        self._init_text_projection(out_dim)

    def _init_language_model(self, local_rank, tokenizer, llm_version, precision,
                             torch_dtype, vision_tower_name, mm_vision_select_layer,
                             freeze_lm):
        """Load LLaVA into ``self.lm``, attach its vision tower and set up image tokens."""
        # Register LLaVA's image placeholder tokens. Only the count from the
        # second call (start/end tokens) is forwarded to initialize_vision_tokenizer.
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        num_new_tokens = tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

        self.lm = LlavaLlamaForCausalLM.from_pretrained(
            llm_version, torch_dtype=torch_dtype, cache_dir=None, low_cpu_mem_usage=True
        )
        self.lm.enable_input_require_grads()
        self.lm.gradient_checkpointing_enable()
        # KV caching is disabled, presumably because it conflicts with
        # gradient checkpointing.
        self.lm.config.use_cache = False

        model_vision_dict = self.lm.get_model().initialize_vision_modules(
            vision_tower=vision_tower_name,
            mm_vision_select_layer=mm_vision_select_layer,
            precision=precision,
        )
        vision_config = model_vision_dict['vision_config']

        self.lm.model.config.eos_token_id = tokenizer.eos_token_id
        self.lm.model.config.bos_token_id = tokenizer.bos_token_id
        self.lm.model.config.pad_token_id = tokenizer.pad_token_id

        vision_tower = self.lm.get_model().vision_tower[0]
        if vision_tower.device.type == 'meta':
            # Meta tensors hold no data: reload real CLIP weights from the same
            # checkpoint and swap the loaded module into the model.
            vision_tower = CLIPVisionModel.from_pretrained(
                vision_tower.config._name_or_path,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
            ).cuda(local_rank)
            self.lm.get_model().vision_tower[0] = vision_tower
        else:
            # Target ``local_rank`` explicitly, matching the branch above. Bare
            # 'cuda' would mean the current default device, which can differ
            # from ``local_rank`` in multi-GPU runs.
            vision_tower.cuda(local_rank)
            vision_tower.to(dtype=torch_dtype)

        self.lm.config.tune_mm_mlp_adapter = False
        self.lm.config.freeze_mm_mlp_adapter = False
        self.lm.config.mm_use_im_start_end = True
        vision_config.use_im_start_end = True
        self.lm.config.sep_image_conv_front = False
        self.lm.initialize_vision_tokenizer(
            mm_use_im_start_end=True,
            tokenizer=tokenizer,
            num_new_tokens=num_new_tokens,
            device=local_rank,
            tune_mm_mlp_adapter=False,
        )

        if freeze_lm:
            self.lm.requires_grad_(False)

        vocab_size = len(tokenizer)
        self.lm.resize_token_embeddings(vocab_size)
        # Re-enable training for the resized token-embedding / output layers.
        # The name test comes first so ``shape[0]`` is only read on those
        # parameters (a 0-dim parameter elsewhere would raise IndexError).
        for name, param in self.lm.named_parameters():
            if ('lm_head' in name or 'embed_tokens' in name) and param.shape[0] == vocab_size:
                param.requires_grad = True

    def _init_segmentation_model(self, train_mask_decoder):
        """Build SAM ViT-H into ``self.visual_model``; freeze all but (optionally) the mask decoder."""
        # NOTE(review): no checkpoint is passed here; SAM weights are presumably
        # loaded elsewhere — confirm.
        self.visual_model = build_sam_vit_h(None)
        self.visual_model.requires_grad_(False)
        if train_mask_decoder:
            self.visual_model.mask_decoder.train()
            self.visual_model.mask_decoder.requires_grad_(True)

    def _init_text_projection(self, out_dim):
        """Create the MLP mapping LM hidden states (``hidden_size``) to ``out_dim``."""
        in_dim = self.lm.config.hidden_size
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),  # p=0.0, so this layer is a no-op
        ]
        # A one-element ModuleList; evaluate() asserts exactly one head.
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])

    def get_visual_embs(self, pixel_values: torch.FloatTensor):
        """Return the SAM image-encoder embeddings for ``pixel_values``."""
        image_embeddings = self.visual_model.image_encoder(pixel_values)
        return image_embeddings

    def evaluate(self, images_clip, images, input_ids, resize_list, original_size_list, max_new_tokens=32, tokenizer=None):
        """Generate text greedily and decode a SAM mask for every ``seg_token_idx`` occurrence.

        Everything runs under ``torch.no_grad()``; returned tensors carry no
        autograd history.

        Args:
            images_clip: Images for the LLaVA vision tower, passed to
                ``self.lm.generate(images=...)``.
            images: Images for SAM's image encoder; row ``i`` pairs with batch
                element ``i``.
            input_ids: Prompt token ids passed to ``generate``.
            resize_list: Per-sample ``input_size`` for ``postprocess_masks``.
            original_size_list: Per-sample ``original_size`` for
                ``postprocess_masks``.
            max_new_tokens: Generation length cap.
            tokenizer: Unused; kept for interface compatibility.

        Returns:
            ``(output_ids, pred_masks)``. ``output_ids`` is
            ``generate(...).sequences``. ``pred_masks`` is a list with one entry
            per batch element: ``postprocess_masks(...)[:, 0]``, covering every
            seg token of that sample (possibly empty when it has none).
        """
        with torch.no_grad():
            outputs = self.lm.generate(
                images=images_clip,
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            # NOTE(review): assumes the project's LlavaLlamaForCausalLM reports a
            # single tensor per generation step, so [-1] is the last step's
            # (batch, seq, hidden) final hidden states — confirm.
            output_hidden_states = outputs.hidden_states[-1]
            output_ids = outputs.sequences

            # Drop the first id: mask index j flags output_ids[:, j + 1], assumed
            # to line up with hidden-state position j (TODO confirm alignment).
            seg_token_mask = output_ids[:, 1:] == self.seg_token_idx

            assert len(self.text_hidden_fcs) == 1
            last_hidden_state = self.text_hidden_fcs[0](output_hidden_states)

            # The mask selection flattens across the batch; split it back into
            # one (num_seg_tokens_i, out_dim) view per sample.
            seg_token_counts = seg_token_mask.sum(dim=-1).tolist()
            pred_embeddings = last_hidden_state[seg_token_mask].split(seg_token_counts, dim=0)

            image_embeddings = self.get_visual_embs(images)

            pred_masks = []
            for i, sample_embeddings in enumerate(pred_embeddings):
                sparse_embeddings, dense_embeddings = self.visual_model.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                    text_embeds=sample_embeddings.unsqueeze(1),
                )
                sparse_embeddings = sparse_embeddings.to(sample_embeddings.dtype)
                low_res_masks, _ = self.visual_model.mask_decoder(
                    image_embeddings=image_embeddings[i].unsqueeze(0),
                    image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                )
                pred_mask = self.visual_model.postprocess_masks(
                    low_res_masks,
                    input_size=resize_list[i],
                    original_size=original_size_list[i],
                )
                pred_masks.append(pred_mask[:, 0])
        return output_ids, pred_masks
|