| import timm | |
| from timm.models._factory import load_checkpoint | |
| import torch | |
| import os | |
| from typing import List, Union, Optional, Tuple | |
| from torch import nn | |
| from torch.jit import Final | |
| from einops import rearrange, repeat | |
| from einops.layers.torch import Rearrange | |
| from utils.dl.common.model import get_model_device, set_module, get_module, get_model_latency, get_model_size, LayerActivation3 | |
| import torch.nn.functional as F | |
| from utils.common.log import logger | |
| from transformers import AutoTokenizer | |
| from maskrcnn_benchmark.modeling.detector.generalized_vl_rcnn import GeneralizedVLRCNN | |
| from maskrcnn_benchmark.config import cfg | |
| from maskrcnn_benchmark.structures.bounding_box import BoxList | |
| from torchvision import transforms as T | |
| import matplotlib.pyplot as plt | |
| import nltk | |
| import re | |
| from copy import deepcopy | |
| from abc import ABC, abstractmethod | |
| from methods.elasticdnn.pipeline.offline.fm_to_md.base import FM_to_MD_Util | |
| from methods.elasticdnn.pipeline.offline.fm_lora.base import FMLoRA_Util, LoRA | |
| from new_impl.cv.elasticdnn.api.model import ElasticDNN_OfflineFMModel, ElasticDNN_OfflineMDModel | |
| from methods.elasticdnn.model.base import Abs, KTakesAll, ElasticDNNUtil, Layer_WrappedWithFBS | |
| from transformers.models.bert.modeling_bert import BertSelfAttention | |
| from transformers import BertConfig | |
| import math | |
| import time | |
| from timm.models.layers import DropPath, to_2tuple, trunc_normal_ | |
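| # Collate function for the multi-modal detection dataloader: groups images/targets (and optional info_imgs/ids) into a dict and skips samples with empty targets. | |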
| def collect_mm_fn(batch): | |
| if len(batch[0]) == 2: | |
| collated = {'images': [], 'targets': []} | |
| else: | |
| collated = {'images': [], 'targets': [], 'info_imgs': [], 'ids': []} | |
| for item in batch: | |
| if len(item) == 2: | |
| img, new_target = item | |
| if len(new_target) == 0: | |
| continue | |
| collated['images'].append(img) | |
| collated['targets'].append(new_target) | |
| else: | |
| img, new_target, info_imgs, ids = item | |
| if len(new_target) == 0: | |
| continue | |
| collated['images'].append(img) | |
| collated['targets'].append(new_target) | |
| collated['info_imgs'].append(info_imgs) | |
| collated['ids'].append(ids) | |
| return collated, torch.Tensor([0]) | |
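| # Find noun phrases in a caption and return their character spans; every occurrence of a phrase is treated as a separate entity. | |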
| def run_ner(caption): | |
| noun_phrases = find_noun_phrases(caption) | |
| noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases] | |
| noun_phrases = [phrase for phrase in noun_phrases if phrase != ''] | |
| relevant_phrases = noun_phrases | |
| labels = noun_phrases | |
| tokens_positive = [] | |
| for entity, label in zip(relevant_phrases, labels): | |
| try: | |
| # search all occurrences and mark them as different entities | |
| for m in re.finditer(entity, caption.lower()): | |
| tokens_positive.append([[m.start(), m.end()]]) | |
| except Exception: | |
| print("noun entities:", noun_phrases) | |
| print("entity:", entity) | |
| print("caption:", caption.lower()) | |
| return tokens_positive | |
| def build_transform(cfg, min_image_size): | |
| """ | |
| Creates a basic transformation that was used to train the models | |
| """ | |
| # we are loading images with OpenCV, so we don't need to convert them | |
| # to BGR, they are already! So all we need to do is to normalize | |
| # by 255 if we want to convert to BGR255 format, or flip the channels | |
| # if we want it to be in RGB in [0-1] range. | |
| if cfg.INPUT.TO_BGR255: | |
| to_bgr_transform = T.Lambda(lambda x: x * 255) | |
| else: | |
| to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]]) | |
| normalize_transform = T.Normalize( | |
| mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD | |
| ) | |
| transform = T.Compose( | |
| [ | |
| T.ToPILImage(), | |
| T.Resize(min_image_size) if min_image_size is not None else lambda x: x, | |
| T.ToTensor(), | |
| to_bgr_transform, | |
| normalize_transform, | |
| ] | |
| ) | |
| return transform | |
| def remove_punctuation(text: str) -> str: | |
| punct = ['|', ':', ';', '@', '(', ')', '[', ']', '{', '}', '^', | |
| '\'', '\"', '’', '`', '?', '$', '%', '#', '!', '&', '*', '+', ',', '.' | |
| ] | |
| for p in punct: | |
| text = text.replace(p, '') | |
| return text.strip() | |
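| # Convert a (box x token) positive map into a dict mapping each label index (offset by `plus`) to the token indices it grounds to. | |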
| def create_positive_map_label_to_token_from_positive_map(positive_map, plus=0): | |
| positive_map_label_to_token = {} | |
| for i in range(len(positive_map)): | |
| positive_map_label_to_token[i + plus] = torch.nonzero(positive_map[i], as_tuple=True)[0].tolist() | |
| return positive_map_label_to_token | |
| def create_positive_map(tokenized, tokens_positive): | |
| """construct a map such that positive_map[i,j] = True iff box i is associated to token j""" | |
| positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float) | |
| for j, tok_list in enumerate(tokens_positive): | |
| for (beg, end) in tok_list: | |
| try: | |
| beg_pos = tokenized.char_to_token(beg) | |
| end_pos = tokenized.char_to_token(end - 1) | |
| except Exception as e: | |
| print("beg:", beg, "end:", end) | |
| print("token_positive:", tokens_positive) | |
| # print("beg_pos:", beg_pos, "end_pos:", end_pos) | |
| raise e | |
| if beg_pos is None: | |
| try: | |
| beg_pos = tokenized.char_to_token(beg + 1) | |
| if beg_pos is None: | |
| beg_pos = tokenized.char_to_token(beg + 2) | |
| except Exception: | |
| beg_pos = None | |
| if end_pos is None: | |
| try: | |
| end_pos = tokenized.char_to_token(end - 2) | |
| if end_pos is None: | |
| end_pos = tokenized.char_to_token(end - 3) | |
| except Exception: | |
| end_pos = None | |
| if beg_pos is None or end_pos is None: | |
| continue | |
| assert beg_pos is not None and end_pos is not None | |
| positive_map[j, beg_pos: end_pos + 1].fill_(1) | |
| return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) | |
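| # Noun-phrase chunking with NLTK (requires the 'punkt' tokenizer and 'averaged_perceptron_tagger' data packages). | |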
| def find_noun_phrases(caption: str) -> List[str]: | |
| caption = caption.lower() | |
| tokens = nltk.word_tokenize(caption) | |
| pos_tags = nltk.pos_tag(tokens) | |
| grammar = "NP: {<DT>?<JJ.*>*<NN.*>+}" | |
| cp = nltk.RegexpParser(grammar) | |
| result = cp.parse(pos_tags) | |
| noun_phrases = list() | |
| for subtree in result.subtrees(): | |
| if subtree.label() == 'NP': | |
| noun_phrases.append(' '.join(t[0] for t in subtree.leaves())) | |
| return noun_phrases | |
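| # Wrapper around GLIP's GeneralizedVLRCNN: loads a checkpoint, tokenizes captions, builds positive maps, and post-processes predictions. | |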
| class Glip(nn.Module): | |
| def __init__(self, config, pretrain_path, min_image_size=None, confidence_threshold=0.7): | |
| super(Glip, self).__init__() | |
| state_dict = torch.load(pretrain_path)['model'] | |
| self.min_image_size = min_image_size | |
| self.cfg = config | |
| self.confidence_threshold = confidence_threshold | |
| self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL.LANGUAGE_BACKBONE.MODEL_PATH) | |
| self.device = torch.device(config.MODEL.DEVICE) | |
| for k in list(state_dict.keys()): | |
| if k.startswith('module'): | |
| new_k = k.replace('module.', '') | |
| state_dict[new_k] = state_dict.pop(k) | |
| self.model = GeneralizedVLRCNN(config) | |
| self.model.load_state_dict(state_dict, strict=False) | |
| # self.transform = build_transform(config, min_image_size) | |
| def forward(self, images, targets, for_training=None): | |
| # img_list = [] | |
| # for image in images: | |
| # img_list.append(self.transform(image).to(self.device)) | |
| # if isinstance(texts, list): | |
| # # we directly provided a list of category names | |
| # caption_string = "" | |
| # tokens_positive = [] | |
| # seperation_tokens = " . " | |
| # for word in texts: | |
| # tokens_positive.append([len(caption_string), len(caption_string) + len(word)]) | |
| # caption_string += word | |
| # caption_string += seperation_tokens | |
| # tokenized = self.tokenizer([caption_string], return_tensors="pt") | |
| # tokens_positive = [tokens_positive] | |
| # texts = [caption_string] | |
| # print(tokens_positive) | |
| # else: | |
| device = self.device | |
| images = [image.to(device) for image in images] | |
| targets = [target.to(device) for target in targets] | |
| texts = [t.get_field("caption") for t in targets if "caption" in t.fields()] | |
| positive_map = [] | |
| # if custom_entity is None: | |
| # tokens_positive = self.run_ner(texts) | |
| # print(tokens_positive) | |
| # process positive map | |
| if not self.training: | |
| # at inference time, derive grounding spans (character offsets of noun phrases) from the caption | |
| tokens_positive = run_ner(texts[0]) | |
| tokenized = self.tokenizer(texts, return_tensors="pt") | |
| positive_map = create_positive_map(tokenized, tokens_positive) | |
| if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD": | |
| plus = 1 | |
| else: | |
| plus = 0 | |
| positive_map = create_positive_map_label_to_token_from_positive_map(positive_map, plus=plus) | |
| else: | |
| for i, text in enumerate(texts): | |
| tokenized = self.tokenizer(text, return_tensors="pt") | |
| tokens_positive = targets[i].get_field('tokens_positive') | |
| positive_map.append(create_positive_map(tokenized, tokens_positive)) | |
| positive_map = torch.cat(positive_map, dim=0).to(device) | |
| if self.training: | |
| proposal_losses = self.model(images, targets, texts, positive_map=positive_map) | |
| return proposal_losses | |
| else: | |
| proposals, token_logits, dot_product_logits = self.model(images, targets, texts, positive_map=positive_map) | |
| proposal = self._post_process(proposals[0]) | |
| return proposal, token_logits, dot_product_logits | |
| def _post_process_fixed_thresh(self, predictions): | |
| scores = predictions.get_field("scores") | |
| labels = predictions.get_field("labels").tolist() | |
| thresh = scores.clone() | |
| for i, lb in enumerate(labels): | |
| if isinstance(self.confidence_threshold, float): | |
| thresh[i] = self.confidence_threshold | |
| elif len(self.confidence_threshold) == 1: | |
| thresh[i] = self.confidence_threshold[0] | |
| else: | |
| thresh[i] = self.confidence_threshold[lb - 1] | |
| keep = torch.nonzero(scores > thresh).squeeze(1) | |
| predictions = predictions[keep] | |
| scores = predictions.get_field("scores") | |
| _, idx = scores.sort(0, descending=True) | |
| return predictions[idx] | |
| def _post_process(self, predictions, threshold=0.5): | |
| scores = predictions.get_field("scores") | |
| labels = predictions.get_field("labels").tolist() | |
| thresh = scores.clone() | |
| for i, lb in enumerate(labels): | |
| if isinstance(self.confidence_threshold, float): | |
| thresh[i] = threshold | |
| elif len(self.confidence_threshold) == 1: | |
| thresh[i] = threshold | |
| else: | |
| thresh[i] = self.confidence_threshold[lb - 1] | |
| keep = torch.nonzero(scores > thresh).squeeze(1) | |
| predictions = predictions[keep] | |
| scores = predictions.get_field("scores") | |
| _, idx = scores.sort(0, descending=True) | |
| return predictions[idx] | |
| def glip_model(config_path, pretrain_path): | |
| # merge the YAML config into the global GLIP cfg, then build the wrapped detector | |
| cfg.merge_from_file(config_path) | |
| return cfg, Glip(cfg, pretrain_path) | |
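| # Wraps a linear projection with a LoRA branch: forward returns fc(x) + B(A(x)); A is Kaiming-initialized and B is zero-initialized, so the wrapped layer initially behaves exactly like the original. | |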
| class ToQKV_WrappedWithLoRA(nn.Module): | |
| def __init__(self, fc: nn.Linear, ab_r: int): | |
| super(ToQKV_WrappedWithLoRA, self).__init__() | |
| self.fc = fc | |
| self.ab = self.create_ab_as_linear(fc.weight.data, ab_r) | |
| def create_ab_as_linear(self, fc_weight: torch.Tensor, ab_r: int): | |
| res = nn.Sequential( | |
| LoRA(fc_weight.size(1), fc_weight.size(0) // ab_r, bias=False), | |
| LoRA(fc_weight.size(0) // ab_r, fc_weight.size(0), bias=False) | |
| ).to(fc_weight.device) | |
| nn.init.kaiming_uniform_(res[0].weight, a=5 ** 0.5) | |
| nn.init.zeros_(res[1].weight) | |
| return res | |
| def forward(self, x): | |
| x1 = self.fc(x) | |
| x2 = self.ab(x) | |
| return x1 + x2 | |
| def get_model_latency_2(model: torch.nn.Module, sample: dict, sample_num: int, | |
| device: str, warmup_sample_num: int, return_detail=False): | |
| """Get the latency (inference time) of a PyTorch model. | |
| Reference: https://deci.ai/resources/blog/measure-inference-time-deep-neural-networks/ | |
| Args: | |
| model (torch.nn.Module): A PyTorch model. | |
| sample (dict): Keyword arguments passed to the model as `model(**sample)`. | |
| sample_num (int): Number of timed inference runs; their average is returned as the latency. | |
| device (str): Typically 'cpu' or 'cuda'. | |
| warmup_sample_num (int): Number of untimed dummy runs used to warm up the device before measuring. | |
| return_detail (bool, optional): Besides the average latency, also return every individual measurement. Defaults to False. | |
| Returns: | |
| Union[float, Tuple[float, List[float]]]: The average latency (and, optionally, all latency measurements) of :attr:`model`. | |
| """ | |
| # if isinstance(model_input_size, tuple): | |
| # dummy_input = torch.rand(model_input_size).to(device) | |
| # else: | |
| # dummy_input = model_input_size | |
| model = model.to(device) | |
| model.eval() | |
| # warm up | |
| with torch.no_grad(): | |
| for _ in range(warmup_sample_num): | |
| model(**sample) | |
| infer_time_list = [] | |
| if device == 'cuda' or 'cuda' in str(device): | |
| with torch.no_grad(): | |
| for _ in range(sample_num): | |
| s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) | |
| s.record() | |
| model(**sample) | |
| e.record() | |
| torch.cuda.synchronize() | |
| cur_model_infer_time = s.elapsed_time(e) / 1000. | |
| infer_time_list += [cur_model_infer_time] | |
| else: | |
| with torch.no_grad(): | |
| for _ in range(sample_num): | |
| start = time.time() | |
| model(**sample) | |
| cur_model_infer_time = time.time() - start | |
| infer_time_list += [cur_model_infer_time] | |
| avg_infer_time = sum(infer_time_list) / sample_num | |
| if return_detail: | |
| return avg_infer_time, infer_time_list | |
| return avg_infer_time | |
| class WindowAttention(nn.Module): | |
| """ Window based multi-head self attention (W-MSA) module with relative position bias. | |
| It supports both of shifted and non-shifted window. | |
| Args: | |
| dim (int): Number of input channels. | |
| window_size (tuple[int]): The height and width of the window. | |
| num_heads (int): Number of attention heads. | |
| qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |
| qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set | |
| attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 | |
| proj_drop (float, optional): Dropout ratio of output. Default: 0.0 | |
| """ | |
| def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): | |
| super().__init__() | |
| self.dim = dim | |
| self.window_size = window_size # Wh, Ww | |
| self.num_heads = num_heads | |
| head_dim = dim // num_heads | |
| self.scale = qk_scale or head_dim ** -0.5 | |
| # define a parameter table of relative position bias | |
| self.relative_position_bias_table = nn.Parameter( | |
| torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH | |
| # get pair-wise relative position index for each token inside the window | |
| coords_h = torch.arange(self.window_size[0]) | |
| coords_w = torch.arange(self.window_size[1]) | |
| coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww | |
| coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |
| relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww | |
| relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 | |
| relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 | |
| relative_coords[:, :, 1] += self.window_size[1] - 1 | |
| relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 | |
| relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww | |
| self.register_buffer("relative_position_index", relative_position_index) | |
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |
| self.attn_drop = nn.Dropout(attn_drop) | |
| self.proj = nn.Linear(dim, dim) | |
| self.proj_drop = nn.Dropout(proj_drop) | |
| trunc_normal_(self.relative_position_bias_table, std=.02) | |
| self.softmax = nn.Softmax(dim=-1) | |
| def forward(self, x, mask=None): | |
| """ Forward function. | |
| Args: | |
| x: input features with shape of (num_windows*B, N, C) | |
| mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None | |
| """ | |
| B_, N, C = x.shape | |
| qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) | |
| q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) | |
| q = q * self.scale | |
| attn = (q @ k.transpose(-2, -1)) | |
| relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( | |
| self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH | |
| relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww | |
| attn = attn + relative_position_bias.unsqueeze(0) | |
| if mask is not None: | |
| nW = mask.shape[0] | |
| attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) | |
| attn = attn.view(-1, self.num_heads, N, N) | |
| attn = self.softmax(attn) | |
| else: | |
| attn = self.softmax(attn) | |
| attn = self.attn_drop(attn) | |
| x = (attn @ v).transpose(1, 2).reshape(B_, N, -1) | |
| x = self.proj(x) | |
| x = self.proj_drop(x) | |
| return x | |
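| # Bidirectional vision-language cross-attention (GLIP-style fusion): visual features attend to language features and vice versa, with optional clamping for fp16 numerical stability. | |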
| class BiMultiHeadAttention(nn.Module): | |
| def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None): | |
| super(BiMultiHeadAttention, self).__init__() | |
| self.embed_dim = embed_dim | |
| self.num_heads = num_heads | |
| self.head_dim = embed_dim // num_heads | |
| self.v_dim = v_dim | |
| self.l_dim = l_dim | |
| assert ( | |
| self.head_dim * self.num_heads == self.embed_dim | |
| ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." | |
| self.scale = self.head_dim ** (-0.5) | |
| self.dropout = dropout | |
| self.v_proj = nn.Linear(self.v_dim, self.embed_dim) | |
| self.l_proj = nn.Linear(self.l_dim, self.embed_dim) | |
| self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim) | |
| self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim) | |
| self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim) | |
| self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim) | |
| self.stable_softmax_2d = cfg.MODEL.DYHEAD.FUSE_CONFIG.STABLE_SOFTMAX_2D | |
| self.clamp_min_for_underflow = cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MIN_FOR_UNDERFLOW | |
| self.clamp_max_for_overflow = cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MAX_FOR_OVERFLOW | |
| self._reset_parameters() | |
| def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): | |
| return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() | |
| def _reset_parameters(self): | |
| nn.init.xavier_uniform_(self.v_proj.weight) | |
| self.v_proj.bias.data.fill_(0) | |
| nn.init.xavier_uniform_(self.l_proj.weight) | |
| self.l_proj.bias.data.fill_(0) | |
| nn.init.xavier_uniform_(self.values_v_proj.weight) | |
| self.values_v_proj.bias.data.fill_(0) | |
| nn.init.xavier_uniform_(self.values_l_proj.weight) | |
| self.values_l_proj.bias.data.fill_(0) | |
| nn.init.xavier_uniform_(self.out_v_proj.weight) | |
| self.out_v_proj.bias.data.fill_(0) | |
| nn.init.xavier_uniform_(self.out_l_proj.weight) | |
| self.out_l_proj.bias.data.fill_(0) | |
| def forward(self, v, l, attention_mask_l=None): | |
| bsz, tgt_len, embed_dim = v.size() | |
| query_states = self.v_proj(v) * self.scale | |
| key_states = self._shape(self.l_proj(l), -1, bsz) | |
| value_v_states = self._shape(self.values_v_proj(v), -1, bsz) | |
| value_l_states = self._shape(self.values_l_proj(l), -1, bsz) | |
| proj_shape = (bsz * self.num_heads, -1, self.head_dim) | |
| query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) | |
| key_states = key_states.view(*proj_shape) | |
| value_v_states = value_v_states.view(*proj_shape) | |
| value_l_states = value_l_states.view(*proj_shape) | |
| src_len = key_states.size(1) | |
| attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) | |
| if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): | |
| raise ValueError( | |
| f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" | |
| ) | |
| # attn_weights_l = nn.functional.softmax(attn_weights.transpose(1, 2), dim=-1) | |
| if self.stable_softmax_2d: | |
| attn_weights = attn_weights - attn_weights.max() | |
| if self.clamp_min_for_underflow: | |
| attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range | |
| if self.clamp_max_for_overflow: | |
| attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range | |
| attn_weights_T = attn_weights.transpose(1, 2) | |
| attn_weights_l = (attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[ | |
| 0]) | |
| if self.clamp_min_for_underflow: | |
| attn_weights_l = torch.clamp(attn_weights_l, min=-50000) # Do not increase -50000, data type half has quite limited range | |
| if self.clamp_max_for_overflow: | |
| attn_weights_l = torch.clamp(attn_weights_l, max=50000) # Do not increase 50000, data type half has quite limited range | |
| attn_weights_l = attn_weights_l.softmax(dim=-1) | |
| if attention_mask_l is not None: | |
| assert (attention_mask_l.dim() == 2) | |
| attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1) | |
| attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len) | |
| attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15) | |
| if attention_mask.size() != (bsz, 1, tgt_len, src_len): | |
| raise ValueError( | |
| f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}" | |
| ) | |
| attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask | |
| attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) | |
| attn_weights_v = nn.functional.softmax(attn_weights, dim=-1) | |
| attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training) | |
| attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training) | |
| attn_output_v = torch.bmm(attn_probs_v, value_l_states) | |
| attn_output_l = torch.bmm(attn_probs_l, value_v_states) | |
| if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim): | |
| raise ValueError( | |
| f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}" | |
| ) | |
| if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim): | |
| raise ValueError( | |
| f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}" | |
| ) | |
| attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim) | |
| attn_output_v = attn_output_v.transpose(1, 2) | |
| attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim) | |
| attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim) | |
| attn_output_l = attn_output_l.transpose(1, 2) | |
| attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim) | |
| attn_output_v = self.out_v_proj(attn_output_v) | |
| attn_output_l = self.out_l_proj(attn_output_l) | |
| return attn_output_v, attn_output_l | |
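| # BertSelfAttention variant whose output reshape uses self.query.out_features instead of a fixed hidden size, so the q/k/v projections can be width-pruned. | |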
| class BertSelfAttentionPrunable(BertSelfAttention): | |
| def __init__(self): | |
| config = BertConfig.from_pretrained('new_impl/cv/glip/object_detection/bert-base-uncased') | |
| super(BertSelfAttentionPrunable, self).__init__(config) | |
| def transpose_for_scores(self, x): | |
| new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1) | |
| x = x.view(new_x_shape) | |
| return x.permute(0, 2, 1, 3) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| attention_mask: Optional[torch.FloatTensor] = None, | |
| head_mask: Optional[torch.FloatTensor] = None, | |
| encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
| encoder_attention_mask: Optional[torch.FloatTensor] = None, | |
| past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, | |
| output_attentions: Optional[bool] = False, | |
| ) -> Tuple[torch.Tensor]: | |
| mixed_query_layer = self.query(hidden_states) | |
| # If this is instantiated as a cross-attention module, the keys | |
| # and values come from an encoder; the attention mask needs to be | |
| # such that the encoder's padding tokens are not attended to. | |
| is_cross_attention = encoder_hidden_states is not None | |
| if is_cross_attention and past_key_value is not None: | |
| # reuse k,v, cross_attentions | |
| key_layer = past_key_value[0] | |
| value_layer = past_key_value[1] | |
| attention_mask = encoder_attention_mask | |
| elif is_cross_attention: | |
| key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) | |
| value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) | |
| attention_mask = encoder_attention_mask | |
| elif past_key_value is not None: | |
| key_layer = self.transpose_for_scores(self.key(hidden_states)) | |
| value_layer = self.transpose_for_scores(self.value(hidden_states)) | |
| key_layer = torch.cat([past_key_value[0], key_layer], dim=2) | |
| value_layer = torch.cat([past_key_value[1], value_layer], dim=2) | |
| else: | |
| key_layer = self.transpose_for_scores(self.key(hidden_states)) | |
| value_layer = self.transpose_for_scores(self.value(hidden_states)) | |
| query_layer = self.transpose_for_scores(mixed_query_layer) | |
| use_cache = past_key_value is not None | |
| if self.is_decoder: | |
| # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. | |
| # Further calls to cross_attention layer can then reuse all cross-attention | |
| # key/value_states (first "if" case) | |
| # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of | |
| # all previous decoder key/value_states. Further calls to uni-directional self-attention | |
| # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) | |
| # if encoder bi-directional self-attention `past_key_value` is always `None` | |
| past_key_value = (key_layer, value_layer) | |
| # Take the dot product between "query" and "key" to get the raw attention scores. | |
| attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) | |
| if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": | |
| query_length, key_length = query_layer.shape[2], key_layer.shape[2] | |
| if use_cache: | |
| position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( | |
| -1, 1 | |
| ) | |
| else: | |
| position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) | |
| position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) | |
| distance = position_ids_l - position_ids_r | |
| positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) | |
| positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility | |
| if self.position_embedding_type == "relative_key": | |
| relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) | |
| attention_scores = attention_scores + relative_position_scores | |
| elif self.position_embedding_type == "relative_key_query": | |
| relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) | |
| relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) | |
| attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key | |
| attention_scores = attention_scores / math.sqrt(self.attention_head_size) | |
| if attention_mask is not None: | |
| # Apply the attention mask is (precomputed for all layers in BertModel forward() function) | |
| attention_scores = attention_scores + attention_mask | |
| # Normalize the attention scores to probabilities. | |
| attention_probs = nn.functional.softmax(attention_scores, dim=-1) | |
| # This is actually dropping out entire tokens to attend to, which might | |
| # seem a bit unusual, but is taken from the original Transformer paper. | |
| attention_probs = self.dropout(attention_probs) | |
| # Mask heads if we want to | |
| if head_mask is not None: | |
| attention_probs = attention_probs * head_mask | |
| context_layer = torch.matmul(attention_probs, value_layer) | |
| context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |
| new_context_layer_shape = context_layer.size()[:-2] + (self.query.out_features,) # NOTE: modified | |
| context_layer = context_layer.view(new_context_layer_shape) | |
| outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) | |
| if self.is_decoder: | |
| outputs = outputs + (past_key_value,) | |
| return outputs | |
| @staticmethod | |
| def init_from_exist_self_attn(attn: BertSelfAttention): | |
| # print(attn) | |
| res = BertSelfAttentionPrunable() | |
| for attr in dir(attn): | |
| # if str(attr) in ['transpose_for_scores'] or str(attr).startswith('_'): | |
| # continue | |
| # if isinstance(getattr(attn, attr), nn.Module): | |
| # print(attr) | |
| if isinstance(getattr(attn, attr), nn.Module): | |
| try: | |
| # print(attr, 'ok') | |
| setattr(res, attr, getattr(attn, attr)) | |
| except Exception as e: | |
| print(attr, str(e)) | |
| return res | |
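| # Builds a narrower "master DNN" from the GLIP foundation model by keeping the highest-L1-norm rows/columns of the attention and MLP projections in the Swin backbone, the BERT language backbone, and the VL fusion (dyhead) blocks. | |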
| class FM_to_MD_GLIP_Util(FM_to_MD_Util): | |
| def init_md_from_fm_by_reducing_width_with_perf_test_2(self, fm: nn.Module, reducing_width_ratio: int, | |
| samples: torch.Tensor) -> nn.Module: | |
| fm_size = get_model_size(fm, True) | |
| fm_latency = get_model_latency_2(fm, samples, 20, | |
| get_model_device(fm), 20, False) | |
| master_dnn = self.init_md_from_fm_by_reducing_width(fm, reducing_width_ratio) | |
| master_dnn_size = get_model_size(master_dnn, True) | |
| logger.debug(f'inited master DNN: {master_dnn}') | |
| # from utils.dl.common.model import get_module | |
| # print('after generating') | |
| # get_module(fm, 'head').debug() | |
| # get_module(master_dnn, 'head').debug() | |
| # print('test master latency') | |
| master_dnn_latency = get_model_latency_2(master_dnn, samples, 20, | |
| get_model_device(master_dnn), 20, False) | |
| logger.info(f'init master DNN (w/o FBS yet) by reducing foundation model\'s width (by {reducing_width_ratio:d}x)') | |
| logger.info(f'foundation model ({fm_size:.3f}MB, {fm_latency:.4f}s/sample) -> ' | |
| f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample)\n' | |
| f'(model size: ↓ {(fm_size / master_dnn_size):.2f}x, ' | |
| f'latency: ↓ {(fm_latency / master_dnn_latency):.2f}x)') | |
| return master_dnn | |
| def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int, sparsity=0.0) -> nn.Module: | |
| # sparsity: mainly used to build the distilled model for the baseline algorithm, ensuring it has the same size as the model used by the online algorithm. | |
| fm_vit = deepcopy(fm) | |
| def _f(n): | |
| return int(n // reducing_width_ratio) | |
| # def _rand_indexes(n): | |
| # return torch.randperm(n)[0: int(n // reducing_width_ratio)] | |
| def l1_max_indexes(p: torch.Tensor, dim=0): | |
| assert dim in [0, 1] | |
| assert p.dim() in [1, 2, 4] | |
| if dim == 1: | |
| p = p.T | |
| p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1) | |
| n = p.size(0) | |
| return p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio)].sort()[0] | |
| def l1_max_indexes_with_sparsity(p: torch.Tensor, dim=0): | |
| assert dim in [0, 1] | |
| assert p.dim() in [1, 2, 4] | |
| if dim == 1: | |
| p = p.T | |
| p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1) | |
| n = p.size(0) | |
| return p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio * (1 - sparsity))].sort()[0] | |
| for layer_i, layer in enumerate(fm_vit.model.backbone.body.layers): | |
| for block in layer.blocks: | |
| ori_attn = block.attn | |
| new_attn = WindowAttention(ori_attn.dim, ori_attn.window_size, ori_attn.num_heads, True, ori_attn.scale, 0., 0.) | |
| new_attn.relative_position_index = ori_attn.relative_position_index | |
| new_attn.relative_position_bias_table = ori_attn.relative_position_bias_table | |
| new_attn.qkv = ori_attn.qkv | |
| new_attn.attn_drop = ori_attn.attn_drop | |
| new_attn.proj = ori_attn.proj | |
| new_attn.proj_drop = ori_attn.proj_drop | |
| set_module(block, 'attn', new_attn) | |
| # first_attn = True | |
| for layer_i, layer in enumerate(fm_vit.model.backbone.body.layers): | |
| for block_i, block in enumerate(layer.blocks): | |
| qkv = block.attn.qkv | |
| new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features), | |
| qkv.bias is not None, qkv.weight.device) | |
| indexes = l1_max_indexes(qkv.weight.data, 0) | |
| new_qkv.weight.data.copy_(qkv.weight.data[indexes]) | |
| if qkv.bias is not None: | |
| new_qkv.bias.data.copy_(qkv.bias.data[indexes]) | |
| # fm_vit.model.backbone.body.layers[0].blocks.0.attn.qkv | |
| set_module(fm_vit, f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.attn.qkv', new_qkv) | |
| proj = block.attn.proj | |
| new_proj = nn.Linear(_f(proj.in_features), proj.out_features, | |
| proj.bias is not None, proj.weight.device) | |
| new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)]) | |
| if proj.bias is not None: | |
| new_proj.bias.data.copy_(proj.bias.data) | |
| set_module(fm_vit, f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.attn.proj', new_proj) | |
| fc1 = block.mlp.fc1 | |
| new_fc1 = nn.Linear(fc1.in_features, int(_f(fc1.out_features) * (1 - sparsity)), | |
| fc1.bias is not None, fc1.weight.device) | |
| indexes = l1_max_indexes_with_sparsity(fc1.weight.data, 0) | |
| new_fc1.weight.data.copy_(fc1.weight.data[indexes]) | |
| if fc1.bias is not None: | |
| new_fc1.bias.data.copy_(fc1.bias.data[indexes]) | |
| set_module(fm_vit, f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.mlp.fc1', new_fc1) | |
| fc2 = block.mlp.fc2 | |
| new_fc2 = nn.Linear(int(_f(fc2.in_features) * (1 - sparsity)), fc2.out_features, | |
| fc2.bias is not None, fc2.weight.device) | |
| new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes_with_sparsity(fc2.weight.data, 1)]) | |
| if fc2.bias is not None: | |
| new_fc2.bias.data.copy_(fc2.bias.data) | |
| set_module(fm_vit, f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.mlp.fc2', new_fc2) | |
| for block in fm_vit.model.language_backbone.body.model.encoder.layer: | |
| set_module(block, 'attention.self', BertSelfAttentionPrunable.init_from_exist_self_attn(block.attention.self)) | |
| for block_i, block in enumerate(fm_vit.model.language_backbone.body.model.encoder.layer): | |
| for k in ['query', 'key', 'value']: | |
| qkv = get_module(block, f'attention.self.{k}') | |
| new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features), | |
| qkv.bias is not None, qkv.weight.device) | |
| indexes = l1_max_indexes(qkv.weight.data, 0) | |
| new_qkv.weight.data.copy_(qkv.weight.data[indexes]) | |
| if qkv.bias is not None: | |
| new_qkv.bias.data.copy_(qkv.bias.data[indexes]) | |
| set_module(block, f'attention.self.{k}', new_qkv) | |
| proj = get_module(block, f'attention.output.dense') | |
| new_proj = nn.Linear(_f(proj.in_features), proj.out_features, | |
| proj.bias is not None, proj.weight.device) | |
| new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)]) | |
| if proj.bias is not None: | |
| new_proj.bias.data.copy_(proj.bias.data) | |
| set_module(block, f'attention.output.dense', new_proj) | |
| fc1 = get_module(block, f'intermediate.dense') | |
| new_fc1 = nn.Linear(fc1.in_features, int(_f(fc1.out_features) * (1 - sparsity)), | |
| fc1.bias is not None, fc1.weight.device) | |
| indexes = l1_max_indexes_with_sparsity(fc1.weight.data, 0) | |
| new_fc1.weight.data.copy_(fc1.weight.data[indexes]) | |
| if fc1.bias is not None: | |
| new_fc1.bias.data.copy_(fc1.bias.data[indexes]) | |
| set_module(block, f'intermediate.dense', new_fc1) | |
| fc2 = get_module(block, f'output.dense') | |
| new_fc2 = nn.Linear(int(_f(fc2.in_features) * (1 - sparsity)), fc2.out_features, | |
| fc2.bias is not None, fc2.weight.device) | |
| new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes_with_sparsity(fc2.weight.data, 1)]) | |
| if fc2.bias is not None: | |
| new_fc2.bias.data.copy_(fc2.bias.data) | |
| set_module(block, f'output.dense', new_fc2) | |
| for block_i, block in enumerate(fm_vit.model.rpn.head.dyhead_tower): | |
| if block_i % 3 == 0: | |
| tmp = block.b_attn.attn | |
| tmp.head_dim = int(tmp.head_dim // reducing_width_ratio) | |
| tmp.embed_dim = int(tmp.embed_dim // reducing_width_ratio) | |
| set_module(block, 'b_attn.attn', tmp) | |
| for k in ['v_proj', 'l_proj', 'values_v_proj', 'values_l_proj']: | |
| qkv = get_module(block, f'b_attn.attn.{k}') | |
| new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features), | |
| qkv.bias is not None, qkv.weight.device) | |
| indexes = l1_max_indexes(qkv.weight.data, 0) | |
| new_qkv.weight.data.copy_(qkv.weight.data[indexes]) | |
| if qkv.bias is not None: | |
| new_qkv.bias.data.copy_(qkv.bias.data[indexes]) | |
| set_module(block, f'b_attn.attn.{k}', new_qkv) | |
| for k in ['out_v_proj', 'out_l_proj']: | |
| qkv = get_module(block, f'b_attn.attn.{k}') | |
| new_qkv = nn.Linear(_f(qkv.in_features), qkv.out_features, | |
| qkv.bias is not None, qkv.weight.device) | |
| new_qkv.weight.data.copy_(qkv.weight.data[:, l1_max_indexes(qkv.weight.data, 1)]) | |
| if qkv.bias is not None: | |
| new_qkv.bias.data.copy_(qkv.bias.data) | |
| set_module(block, f'b_attn.attn.{k}', new_qkv) | |
| elif block_i % 3 == 1: | |
| set_module(block, 'attention.self', BertSelfAttentionPrunable.init_from_exist_self_attn(block.attention.self)) | |
| for k in ['query', 'key', 'value']: | |
| qkv = get_module(block, f'attention.self.{k}') | |
| new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features), | |
| qkv.bias is not None, qkv.weight.device) | |
| indexes = l1_max_indexes(qkv.weight.data, 0) | |
| new_qkv.weight.data.copy_(qkv.weight.data[indexes]) | |
| if qkv.bias is not None: | |
| new_qkv.bias.data.copy_(qkv.bias.data[indexes]) | |
| set_module(block, f'attention.self.{k}', new_qkv) | |
| proj = get_module(block, f'attention.output.dense') | |
| new_proj = nn.Linear(_f(proj.in_features), proj.out_features, | |
| proj.bias is not None, proj.weight.device) | |
| new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)]) | |
| if proj.bias is not None: | |
| new_proj.bias.data.copy_(proj.bias.data) | |
| set_module(block, f'attention.output.dense', new_proj) | |
| fc1 = get_module(block, f'intermediate.dense') | |
| new_fc1 = nn.Linear(fc1.in_features, int(_f(fc1.out_features) * (1 - sparsity)), | |
| fc1.bias is not None, fc1.weight.device) | |
| indexes = l1_max_indexes_with_sparsity(fc1.weight.data, 0) | |
| new_fc1.weight.data.copy_(fc1.weight.data[indexes]) | |
| if fc1.bias is not None: | |
| new_fc1.bias.data.copy_(fc1.bias.data[indexes]) | |
| set_module(block, f'intermediate.dense', new_fc1) | |
| fc2 = get_module(block, f'output.dense') | |
| new_fc2 = nn.Linear(int(_f(fc2.in_features) * (1 - sparsity)), fc2.out_features, | |
| fc2.bias is not None, fc2.weight.device) | |
| new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes_with_sparsity(fc2.weight.data, 1)]) | |
| if fc2.bias is not None: | |
| new_fc2.bias.data.copy_(fc2.bias.data) | |
| set_module(block, f'output.dense', new_fc2) | |
| # reduce dim_embedding | |
| # if name.endswith('patch_embed.proj'): | |
| # continue | |
| # new_layer = nn.Conv2d(module.in_channels, _f(module.out_channels), module.kernel_size, module.stride, | |
| # module.padding, module.dilation, module.groups, module.bias is not None, module.padding_mode, | |
| # module.weight.device) | |
| # rand_indexes = l1_max_indexes(module.weight.data) | |
| # new_layer.weight.data.copy_(module.weight.data[rand_indexes]) | |
| # if new_layer.bias is not None: | |
| # new_layer.bias.data.copy_(module.bias.data[rand_indexes]) | |
| # fm_vit.cls_token.data = fm_vit.cls_token.data[:, :, rand_indexes] | |
| # fm_vit.pos_embed.data = fm_vit.pos_embed.data[:, :, rand_indexes] | |
| # elif isinstance(module, nn.Linear): | |
| # if 'head' in name: | |
| # continue | |
| # new_layer = nn.Linear(_f(module.in_features), module.out_features, | |
| # module.bias is not None, module.weight.device) | |
| # new_layer.weight.data.copy_(module.weight.data[:, l1_max_indexes(module.weight.data, 1)]) | |
| # if new_layer.bias is not None: | |
| # new_layer.bias.data.copy_(module.bias.data) | |
| # else: | |
| # first_attn = False | |
| # if first_attn: | |
| # first_attn = False | |
| # new_layer = nn.Linear(module.in_features, _f(module.out_features), | |
| # module.bias is not None, module.weight.device) | |
| # rand_indexes = l1_max_indexes(module.weight.data) | |
| # new_layer.weight.data.copy_(module.weight.data[rand_indexes]) | |
| # if new_layer.bias is not None: | |
| # new_layer.bias.data.copy_(module.bias.data[rand_indexes]) | |
| # else: | |
| # new_layer = nn.Linear(_f(module.in_features), _f(module.out_features), | |
| # module.bias is not None, module.weight.device) | |
| # rand_indexes = l1_max_indexes(module.weight.data) | |
| # new_layer.weight.data.copy_(module.weight.data[rand_indexes][:, l1_max_indexes(module.weight.data, 1)]) | |
| # if new_layer.bias is not None: | |
| # new_layer.bias.data.copy_(module.bias.data[rand_indexes]) | |
| # elif isinstance(module, nn.LayerNorm) and ('block' in name or name == 'norm' or name == 'norm.0'): | |
| # new_layer = nn.LayerNorm(_f(module.normalized_shape[0]), eps=module.eps, device=module.weight.device) | |
| # rand_indexes = l1_max_indexes(module.weight.data) | |
| # new_layer.weight.data.copy_(module.weight.data[rand_indexes]) | |
| # new_layer.bias.data.copy_(module.bias.data[rand_indexes]) | |
| # else: | |
| # continue | |
| # original_layer_str = str(module) | |
| # set_module(fm_vit, name, new_layer) | |
| # logger.debug(f'set_module, {name}, {new_layer}') | |
| # logger.debug(f'slim {name} from {original_layer_str} to {new_layer}') | |
| return fm_vit | |
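| # Injects LoRA branches into the query/key/value (and fusion) projections of the foundation model and later absorbs them back into the plain linear weights, asserting that the model outputs stay numerically unchanged. | |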
| class FMLoRA_GLIP_Util(FMLoRA_Util): | |
| def train_only_lora_and_conv(self, fm: nn.Module): | |
| res = [] | |
| for n, m in fm.named_modules(): | |
| if isinstance(m, LoRA) or isinstance(m, nn.Conv2d): | |
| for p in m.parameters(): | |
| p.requires_grad = True | |
| res += [p] | |
| else: | |
| for p in m.parameters(): | |
| p.requires_grad = False | |
| return res | |
| def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples): | |
| fm.eval() | |
| # samples = {'images' : samples[0], 'targets' : samples[1]} | |
| for k, v in samples.items(): | |
| if isinstance(v, torch.Tensor) or isinstance(v, BoxList): | |
| samples[k] = v.to(get_model_device(fm)) | |
| print(k) | |
| _, o1_token_logits, o1_dot_product_logits = fm(**samples) | |
| mo_list = {k:v for k, v in fm.named_modules()} | |
| for name, module in fm.named_modules(): | |
| if '.proj' in name or 'out' in name: | |
| continue | |
| if name.endswith(('k_proj', 'q_proj', 'v_proj', 'qkv', 'attn.proj', 'l_proj', 'query', 'key', 'value')): | |
| set_module(fm, name, ToQKV_WrappedWithLoRA(module, ab_r)) | |
| _, o2_token_logits, o2_dot_product_logits = fm(**samples) | |
| output_diff = 0. | |
| for o1, o2 in list(zip(o1_dot_product_logits, o2_dot_product_logits)): | |
| output_diff += ((o1 - o2) ** 2).sum() | |
| if o1_token_logits is not None: | |
| output_diff += ((o1_token_logits - o2_token_logits) ** 2).sum() | |
| assert output_diff < 1e-5 | |
| return fm | |
| def absorb_lora_and_recover_net_structure(self, fm: nn.Module, samples: dict): | |
| fm.eval() | |
| # print('absorb lora before') | |
| for k, v in samples.items(): | |
| if isinstance(v, torch.Tensor): | |
| samples[k] = v.to(get_model_device(fm)) | |
| print(k) | |
| _, o1_token_logits, o1_dot_product_logits = fm(**samples) | |
| for name, module in fm.named_modules(): | |
| if not isinstance(module, ToQKV_WrappedWithLoRA): | |
| continue | |
| fc = module.fc | |
| ab = module.ab | |
| fc.weight.add_(ab[1].weight @ ab[0].weight) | |
| set_module(fm, name, fc) | |
| # print('absorb lora after') | |
| _, o2_token_logits, o2_dot_product_logits = fm(**samples) | |
| output_diff = 0. | |
| for o1, o2 in list(zip(o1_dot_product_logits, o2_dot_product_logits)): | |
| output_diff += ((o1 - o2) ** 2).sum() | |
| if o1_token_logits is not None: | |
| output_diff += ((o1_token_logits - o2_token_logits) ** 2).sum() | |
| assert output_diff < 1e-3, output_diff | |
| return fm | |
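| # Offline foundation-model wrapper for multi-modal detection: accuracy is reported as mAP@50 via COCOEvaluator (averaged over datasets when evaluating a MergedDataset). | |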
| class ElasticDNN_OfflineMMDetFMModel(ElasticDNN_OfflineFMModel): | |
| def __init__(self, name: str, models_dict_path: str, device: str, num_classes=10, collate_fn=None): | |
| super().__init__(name, models_dict_path, device) | |
| self.num_classes = num_classes | |
| self.collate_fn = collate_fn | |
| def get_accuracy(self, test_loader, *args, **kwargs): | |
| # print('DeeplabV3: start test acc') | |
| _d = test_loader.dataset | |
| from data import build_dataloader | |
| if _d.__class__.__name__ == 'MergedDataset': | |
| # print('\neval on merged datasets') | |
| datasets = _d.datasets | |
| if self.collate_fn is None: | |
| test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None, collate_fn=None) for d in datasets] | |
| else: | |
| test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None, collate_fn=self.collate_fn) for d in datasets] | |
| accs = [self.get_accuracy(loader) for loader in test_loaders] | |
| # print(accs) | |
| return sum(accs) / len(accs) | |
| # print('dataset len', len(test_loader.dataset)) | |
| model = self.models_dict['main'] | |
| device = self.device | |
| model.eval() | |
| # print('# classes', model.num_classes) | |
| model = model.to(device) | |
| from evaluator import COCOEvaluator, MMCOCODecoder | |
| from utils.common.others import HiddenPrints | |
| with torch.no_grad(): | |
| with HiddenPrints(): | |
| evaluator = COCOEvaluator( | |
| dataloader=test_loader, | |
| img_size=(416, 416), | |
| confthre=0.01, | |
| nmsthre=0.65, | |
| num_classes=len(test_loader.dataset.classes), | |
| testdev=True | |
| ) | |
| res = evaluator.evaluate(model, False, False, decoder=MMCOCODecoder) | |
| map50 = res[1] | |
| # print('eval info', res[-1]) | |
| return map50 | |
| def infer(self, x, *args, **kwargs): | |
| if len(args) > 0: | |
| print(args, len(args)) | |
| return self.models_dict['main'](x, *args) # forward(x, label) | |
| return self.models_dict['main'](**x) | |
| class ElasticDNN_OfflineMMDetMDModel(ElasticDNN_OfflineMDModel): | |
| def __init__(self, name: str, models_dict_path: str, device: str, num_classes=10, collate_fn=None): | |
| super().__init__(name, models_dict_path, device) | |
| self.num_classes = num_classes | |
| self.collate_fn = collate_fn | |
| def get_accuracy(self, test_loader, *args, **kwargs): | |
| # print('DeeplabV3: start test acc') | |
| _d = test_loader.dataset | |
| from data import build_dataloader | |
| if _d.__class__.__name__ == 'MergedDataset': | |
| # print('\neval on merged datasets') | |
| datasets = _d.datasets | |
| if self.collate_fn is None: | |
| test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None, collate_fn=None) for d in datasets] | |
| else: | |
| test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None, collate_fn=self.collate_fn) for d in datasets] | |
| accs = [self.get_accuracy(loader) for loader in test_loaders] | |
| # print(accs) | |
| return sum(accs) / len(accs) | |
| # print('dataset len', len(test_loader.dataset)) | |
| model = self.models_dict['main'] | |
| device = self.device | |
| model.eval() | |
| # print('# classes', model.num_classes) | |
| model = model.to(device) | |
| from evaluator import COCOEvaluator, MMCOCODecoder | |
| from utils.common.others import HiddenPrints | |
| with torch.no_grad(): | |
| with HiddenPrints(): | |
| evaluator = COCOEvaluator( | |
| dataloader=test_loader, | |
| img_size=(416, 416), | |
| confthre=0.01, | |
| nmsthre=0.65, | |
| num_classes=len(test_loader.dataset.classes), | |
| testdev=True | |
| ) | |
| res = evaluator.evaluate(model, False, False, decoder=MMCOCODecoder) | |
| map50 = res[1] | |
| # print('eval info', res[-1]) | |
| return map50 | |
| def infer(self, x, *args, **kwargs): | |
| if len(args) > 0: | |
| return self.models_dict['main'](x, *args) # forward(x, label) | |
| return self.models_dict['main'](**x) | |
| class SqueezeLast(nn.Module): | |
| def __init__(self): | |
| super(SqueezeLast, self).__init__() | |
| def forward(self, x): | |
| return x.squeeze(-1) | |
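| # Conv2d wrapped with an FBS channel gate: a small MLP predicts per-channel saliency from the input, k_takes_all keeps the top-k channels, and the conv output is scaled channel-wise. | |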
| class ProjConv_WrappedWithFBS(Layer_WrappedWithFBS): | |
| def __init__(self, raw_conv2d: nn.Conv2d, r): | |
| super(ProjConv_WrappedWithFBS, self).__init__() | |
| self.fbs = nn.Sequential( | |
| Abs(), | |
| nn.AdaptiveAvgPool2d(1), | |
| nn.Flatten(), | |
| nn.Linear(raw_conv2d.in_channels, raw_conv2d.out_channels // r), | |
| nn.ReLU(), | |
| nn.Linear(raw_conv2d.out_channels // r, raw_conv2d.out_channels), | |
| nn.ReLU() | |
| ) | |
| self.raw_conv2d = raw_conv2d | |
| # self.raw_bn = raw_bn # remember clear the original BNs in the network | |
| nn.init.constant_(self.fbs[5].bias, 1.) | |
| nn.init.kaiming_normal_(self.fbs[5].weight) | |
| def forward(self, x): | |
| raw_x = self.raw_conv2d(x) | |
| if self.use_cached_channel_attention and self.cached_channel_attention is not None: | |
| channel_attention = self.cached_channel_attention | |
| else: | |
| self.cached_raw_channel_attention = self.fbs(x) | |
| self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention) | |
| channel_attention = self.cached_channel_attention | |
| return raw_x * channel_attention.unsqueeze(2).unsqueeze(3) | |
| class Linear_WrappedWithFBS(Layer_WrappedWithFBS): | |
| def __init__(self, linear: nn.Linear, r): | |
| super(Linear_WrappedWithFBS, self).__init__() | |
| self.linear = linear | |
| # for conv: (B, C_in, H, W) -> (B, C_in) -> (B, C_out) | |
| # for mlp in ViT: (B, #patches, D: dim of patches embedding) -> (B, D) -> (B, C_out) | |
| self.fbs = nn.Sequential( | |
| Rearrange('b n d -> b d n'), | |
| Abs(), | |
| nn.AdaptiveAvgPool1d(1), | |
| SqueezeLast(), | |
| nn.Linear(linear.in_features, max(linear.out_features // r, 36)), | |
| nn.ReLU(), | |
| nn.Linear(max(linear.out_features // r, 36), linear.out_features), | |
| nn.ReLU() | |
| ) | |
| nn.init.constant_(self.fbs[6].bias, 1.) | |
| nn.init.kaiming_normal_(self.fbs[6].weight) | |
| def forward(self, x): | |
| if self.use_cached_channel_attention and self.cached_channel_attention is not None: | |
| channel_attention = self.cached_channel_attention | |
| else: | |
| self.cached_raw_channel_attention = self.fbs(x) | |
| self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention) | |
| channel_attention = self.cached_channel_attention | |
| raw_res = self.linear(x) | |
| return channel_attention.unsqueeze(1) * raw_res | |
| class ToQKV_WrappedWithFBS(Layer_WrappedWithFBS): | |
| """ | |
| This regards to_q/to_k/to_v as a whole (in fact it consists of multiple heads) and prunes it. | |
| It seems different channels of different heads are pruned according to the input. | |
| This is different from "removing some head" or "removing the same channels in each head". | |
| """ | |
| def __init__(self, to_qkv: nn.Linear, r): | |
| super(ToQKV_WrappedWithFBS, self).__init__() | |
| # self.to_qkv = to_qkv | |
| self.to_qk = nn.Linear(to_qkv.in_features, to_qkv.out_features // 3 * 2, bias=to_qkv.bias is not None) | |
| self.to_v = nn.Linear(to_qkv.in_features, to_qkv.out_features // 3, bias=to_qkv.bias is not None) | |
| self.to_qk.weight.data.copy_(to_qkv.weight.data[0: to_qkv.out_features // 3 * 2]) | |
| if to_qkv.bias is not None: | |
| self.to_qk.bias.data.copy_(to_qkv.bias.data[0: to_qkv.out_features // 3 * 2]) | |
| self.to_v.weight.data.copy_(to_qkv.weight.data[to_qkv.out_features // 3 * 2: ]) | |
| if to_qkv.bias is not None: | |
| self.to_v.bias.data.copy_(to_qkv.bias.data[to_qkv.out_features // 3 * 2: ]) | |
| self.fbs = nn.Sequential( | |
| Rearrange('b n d -> b d n'), | |
| Abs(), | |
| nn.AdaptiveAvgPool1d(1), | |
| SqueezeLast(), | |
| nn.Linear(to_qkv.in_features, to_qkv.out_features // 3 // r), | |
| nn.ReLU(), | |
| # nn.Linear(to_qkv.out_features // 3 // r, to_qkv.out_features // 3), | |
| nn.Linear(to_qkv.out_features // 3 // r, self.to_v.out_features), | |
| nn.ReLU() | |
| ) | |
| nn.init.constant_(self.fbs[6].bias, 1.) | |
| nn.init.kaiming_normal_(self.fbs[6].weight) | |
| def forward(self, x): | |
| if self.use_cached_channel_attention and self.cached_channel_attention is not None: | |
| channel_attention = self.cached_channel_attention | |
| else: | |
| self.cached_raw_channel_attention = self.fbs(x) | |
| # print() | |
| # for attn in self.cached_raw_channel_attention.chunk(3, dim=1)[0: 1]: | |
| # print(self.cached_raw_channel_attention.size(), attn.size()) | |
| # print(self.k_takes_all.k) | |
| # print(attn[0].nonzero(as_tuple=True)[0].size(), attn[0]) | |
| self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention) | |
| # for attn in self.cached_channel_attention.chunk(3, dim=1)[0: 1]: | |
| # print(self.cached_channel_attention.size(), attn.size()) | |
| # print(self.k_takes_all.k) | |
| # print(attn[0].nonzero(as_tuple=True)[0].size(), attn[0]) | |
| # print() | |
| channel_attention = self.cached_channel_attention | |
| qk = self.to_qk(x) | |
| v = channel_attention.unsqueeze(1) * self.to_v(x) | |
| return torch.cat([qk, v], dim=-1) | |
| # qkv = raw_res.chunk(3, dim = -1) | |
| # raw_v = qkv[2] | |
| # print('raw_k, raw_v', qkv[0].sum((0, 1))[0: 10], qkv[0].sum((0, 1)).nonzero(as_tuple=True)[0].size(), | |
| # qkv[1].sum((0, 1))[0: 10], qkv[1].sum((0, 1)).nonzero(as_tuple=True)[0].size(),) | |
| # print('raw_v', raw_v.size(), raw_v.sum((0, 1))[0: 10], raw_v.sum((0, 1)).nonzero(as_tuple=True)[0].size()) | |
| # qkv_attn = channel_attention.chunk(3, dim=-1) | |
| # print('attn', [attn[0][0: 10] for attn in qkv_attn]) | |
| # print(channel_attention.unsqueeze(1).size(), raw_res.size()) | |
| # print('fbs', channel_attention.size(), raw_res.size()) | |
| # return channel_attention.unsqueeze(1) * raw_res | |
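| # Applies a fixed (pre-computed) channel attention; used to replace the dynamic FBS gate when a static surrogate DNN is extracted. | |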
| class StaticFBS(nn.Module): | |
| def __init__(self, static_channel_attention): | |
| super(StaticFBS, self).__init__() | |
| assert static_channel_attention.dim() == 2 and static_channel_attention.size(0) == 1 | |
| self.static_channel_attention = nn.Parameter(static_channel_attention, requires_grad=False) # (1, dim) | |
| def forward(self, x): | |
| # print('staticfbs', x, self.static_channel_attention.unsqueeze(1)) | |
| return x * self.static_channel_attention.unsqueeze(1) | |
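| # ElasticDNNUtil specialization for GLIP: converts the raw model into a master DNN by wrapping the MLP/intermediate linear layers with FBS gates, and extracts a surrogate DNN from channel attentions computed on a representative sample. | |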
| class ElasticGLIPUtil(ElasticDNNUtil): | |
| def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]): | |
| assert len(ignore_layers) == 0, 'not supported yet' | |
| raw_vit = deepcopy(raw_dnn) | |
| for name, module in raw_vit.named_modules(): | |
| # if name.endswith('patch_embed'): | |
| # set_module(module, 'proj', ProjConv_WrappedWithFBS(module.proj, r)) | |
| # if name.endswith('attn') and not name.endswith('b_attn.attn') and not name.endswith('b_attn'): | |
| # set_module(module, 'qkv', ToQKV_WrappedWithFBS(module.qkv, r)) | |
| if name.endswith('intermediate'): # FFN of the BERT-style language-branch layers | |
| set_module(module, 'dense', Linear_WrappedWithFBS(module.dense, r)) | |
| elif name.endswith('mlp'): # FFN of the Swin-style vision-branch blocks | |
| set_module(module, 'fc1', Linear_WrappedWithFBS(module.fc1, r)) | |
| return raw_vit | |
| def set_master_dnn_sparsity(self, master_dnn: nn.Module, sparsity: float): | |
| # for name, module in master_dnn.named_modules(): | |
| # if not name.endswith('attn'): | |
| # continue | |
| # q_features = module.qkv.to_qk.out_features // 2 | |
| # if (q_features - int(q_features * sparsity)) % module.num_heads != 0: | |
| # # tune sparsity to ensure #unpruned channel % num_heads == 0 | |
| # # so that the pruning seems to reduce the dim_head of each head | |
| # tuned_sparsity = 1. - int((q_features - int(q_features * sparsity)) / module.num_heads) * module.num_heads / q_features | |
| # logger.debug(f'tune sparsity from {sparsity:.2f} to {tuned_sparsity}') | |
| # sparsity = tuned_sparsity | |
| # break | |
| return super().set_master_dnn_sparsity(master_dnn, sparsity) | |
| def select_most_rep_sample(self, master_dnn: nn.Module, samples: dict): | |
| # print(samples) | |
| sample = {} | |
| sample['images'] = [samples['images'][0]] | |
| sample['targets'] = [samples['targets'][0]] | |
| # return samples[0].unsqueeze(0) | |
| # res = {k: v[0: 1] for k, v in samples.items()} | |
| return sample | |
| def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: dict, return_detail=False): # steps for generating the surrogate (small) model | |
| sample = self.select_most_rep_sample(master_dnn, samples) | |
| # assert sample.dim() == 4 and sample.size(0) == 1 | |
| # print('before') | |
| master_dnn.eval() | |
| self.clear_cached_channel_attention_in_master_dnn(master_dnn) | |
| with torch.no_grad(): | |
| _, o1_token_logits, o1_dot_product_logits = master_dnn(**sample) | |
| # print('after') | |
| boosted_vit = deepcopy(master_dnn) | |
| def get_unpruned_indexes_from_channel_attn(channel_attn: torch.Tensor, k): | |
| assert channel_attn.size(0) == 1, 'use a single representative sample to generate channel attentions' | |
| # print('attn_in_unpruned', channel_attn[0][0: 10]) | |
| res = channel_attn[0].nonzero(as_tuple=True)[0] # should be one-dim | |
| # res = channel_attn[0].argsort(descending=True)[0: -int(channel_attn.size(1) * k)].sort()[0] | |
| # g = channel_attn | |
| # k = g.size(1) - int(g.size(1) * k) | |
| # res = g.topk(k, 1)[1][0].sort()[0] | |
| return res | |
| unpruned_indexes_of_layers = {} | |
| # for attn, ff in boosted_vit.transformer.layers: | |
| # for block_i, block in enumerate(boosted_vit.blocks): | |
| for layer_i, layer in enumerate(boosted_vit.model.backbone.body.layers): | |
| for block_i, block in enumerate(layer.blocks): | |
| # attn = block.attn | |
| # ff = block.mlp | |
| ff_0 = get_module(block, 'mlp.fc1') | |
| # ff_0_unpruned_indexes = get_unpruned_indexes_from_channel_attn(ff_0.cached_channel_attention, k) | |
| ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0] | |
| ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes]) | |
| new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None) | |
| new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes]) | |
| if ff_0.linear.bias is not None: | |
| new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes]) | |
| set_module(block, 'mlp.fc1', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes]))) | |
| ff_1 = get_module(block, 'mlp.fc2') | |
| new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None) | |
| new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes]) | |
| if ff_1.bias is not None: | |
| new_ff_1.bias.data.copy_(ff_1.bias.data) | |
| set_module(block, 'mlp.fc2', new_ff_1) | |
| unpruned_indexes_of_layers[f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.mlp.fc1.0.weight'] = ff_0_unpruned_indexes | |
| # for block_i,block in enumerate(boosted_vit.vision_model.encoder.layers): | |
| # attn = block.self_attn | |
| # ff = block.mlp | |
| # ff_0 = ff.fc1 | |
| # # ff_0_unpruned_indexes = get_unpruned_indexes_from_channel_attn(ff_0.cached_channel_attention, k) | |
| # ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0] | |
| # ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes]) | |
| # new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None) | |
| # new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes]) | |
| # if ff_0.linear.bias is not None: | |
| # new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes]) | |
| # set_module(ff, 'fc1', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes]))) | |
| # ff_1 = ff.fc2 | |
| # new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None) | |
| # new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes]) | |
| # if ff_1.bias is not None: | |
| # new_ff_1.bias.data.copy_(ff_1.bias.data) | |
| # set_module(ff, 'fc2', new_ff_1) | |
| # unpruned_indexes_of_layers[f'vision_model.encoder.layers.{block_i}.mlp.fc1.0.weight'] = ff_0_unpruned_indexes | |
| # for block_i, block in enumerate(boosted_vit.text_decoder.bert.encoder.layer): | |
| # # attn = block.attn | |
| # # ff = block.mlp | |
| # ff_0 = get_module(block, f'intermediate.dense') | |
| # # ff_0_unpruned_indexes = get_unpruned_indexes_from_channel_attn(ff_0.cached_channel_attention, k) | |
| # ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0] | |
| # ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes]) | |
| # new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None) | |
| # new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes]) | |
| # if ff_0.linear.bias is not None: | |
| # new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes]) | |
| # set_module(block, 'intermediate.dense', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes]))) | |
| # ff_1 = get_module(block, f'output.dense') | |
| # new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None) | |
| # new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes]) | |
| # if ff_1.bias is not None: | |
| # new_ff_1.bias.data.copy_(ff_1.bias.data) | |
| # set_module(block, 'output.dense', new_ff_1) | |
| # unpruned_indexes_of_layers[f'text_decoder.bert.encoder.layer.{block_i}.intermediate.dense.0.weight'] = ff_0_unpruned_indexes | |
| surrogate_dnn = boosted_vit | |
| surrogate_dnn.eval() | |
| surrogate_dnn = surrogate_dnn.to(get_model_device(master_dnn)) | |
| # logger.debug(surrogate_dnn) | |
| with torch.no_grad(): | |
| _, o2_token_logits, o2_dot_product_logits = surrogate_dnn(**sample) | |
| output_diff = 0. | |
| for o1, o2 in list(zip(o1_dot_product_logits, o2_dot_product_logits)): | |
| output_diff += ((o1 - o2) ** 2).sum() | |
| if o1_token_logits is not None: | |
| output_diff += ((o1_token_logits - o2_token_logits) ** 2).sum() | |
| # assert output_diff < 1e-4, output_diff | |
| logger.info(f'output diff of master and surrogate DNN: {output_diff}') | |
| # logger.debug(f'example output of master/surrogate: {master_dnn_output.sum(0)[0: 10]}, {surrogate_dnn_output.sum(0)[0: 10]}') | |
| # logger.info(f'\nonly prune mlp!!!!\n') | |
| if return_detail: | |
| return boosted_vit, unpruned_indexes_of_layers | |
| return boosted_vit | |
| def extract_surrogate_dnn_via_samples_with_perf_test(self, master_dnn: nn.Module, samples, return_detail=False): | |
| master_dnn_size = get_model_size(master_dnn, True) | |
| sample = {} | |
| sample['images'] = [samples['images'][0]] | |
| sample['targets'] = [samples['targets'][0]] | |
| master_dnn_latency = self._get_model_latency(master_dnn, sample, 50, | |
| get_model_device(master_dnn), 50, False) | |
| res = self.extract_surrogate_dnn_via_samples(master_dnn, samples, return_detail) | |
| if not return_detail: | |
| surrogate_dnn = res | |
| else: | |
| surrogate_dnn, unpruned_indexes_of_layers = res | |
| surrogate_dnn_size = get_model_size(surrogate_dnn, True) | |
| surrogate_dnn_latency = self._get_model_latency(surrogate_dnn, sample, 50, | |
| get_model_device(surrogate_dnn), 50, False) | |
| logger.info(f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample) -> ' | |
| f'surrogate DNN ({surrogate_dnn_size:.3f}MB, {surrogate_dnn_latency:.4f}s/sample)\n' | |
| f'(model size: ↓ {(master_dnn_size / surrogate_dnn_size):.2f}x, ' | |
| f'latency: ↓ {(master_dnn_latency / surrogate_dnn_latency):.2f}x)') | |
| return res | |
| def _get_model_latency(self, model: torch.nn.Module, sample, sample_num: int, | |
| device: str, warmup_sample_num: int, return_detail=False): | |
| import time | |
| model = model.to(device) | |
| model.eval() | |
| sample['images'] = [sample['images'][0]] | |
| sample['targets'] = [sample['targets'][0]] | |
| # warm up | |
| with torch.no_grad(): | |
| for _ in range(warmup_sample_num): | |
| model(**sample) | |
| infer_time_list = [] | |
| if 'cuda' in str(device): | |
| with torch.no_grad(): | |
| for _ in range(sample_num): | |
| s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) | |
| s.record() | |
| model(**sample) | |
| e.record() | |
| torch.cuda.synchronize() | |
| cur_model_infer_time = s.elapsed_time(e) / 1000. | |
| infer_time_list += [cur_model_infer_time] | |
| else: | |
| with torch.no_grad(): | |
| for _ in range(sample_num): | |
| start = time.time() | |
| model(**sample) | |
| cur_model_infer_time = time.time() - start | |
| infer_time_list += [cur_model_infer_time] | |
| avg_infer_time = sum(infer_time_list) / sample_num | |
| if return_detail: | |
| return avg_infer_time, infer_time_list | |
| return avg_infer_time | |
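| # A hypothetical end-to-end usage sketch of ElasticGLIPUtil (kept commented out, like the examples at the | |
| # bottom of this file). Values such as r=16 and sparsity=0.5 are illustrative only; `glip_model` and | |
| # `collect_mm_fn` are the helpers used elsewhere in this file, and `train_loader` is assumed to be a | |
| # DataLoader built with collate_fn=collect_mm_fn. | |
| # util = ElasticGLIPUtil() | |
| # fm = glip_model('new_impl/cv/glip/object_detection/pretrained_model/glip_Swin_T_O365_GoldG.yaml', | |
| #                 'new_impl/cv/glip/object_detection/pretrained_model/glip_tiny_model_o365_goldg_cc_sbu.pth').cuda() | |
| # master_dnn = util.convert_raw_dnn_to_master_dnn(fm, r=16)    # wrap mlp.fc1 / intermediate.dense with FBS | |
| # util.set_master_dnn_sparsity(master_dnn, 0.5)                # sparsity = fraction of channels to prune | |
| # samples, _ = next(iter(train_loader))                        # dict with 'images' and 'targets' | |
| # surrogate_dnn = util.extract_surrogate_dnn_via_samples_with_perf_test(master_dnn, samples) | |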
| # from typing import Any, Dict | |
| # from schema import Schema, Or | |
| # import schema | |
| # from data import Scenario, MergedDataset | |
| # from methods.base.alg import BaseAlg | |
| # from data import build_dataloader | |
| # from ..model import ElasticDNN_OfflineFMModel, ElasticDNN_OfflineMDModel | |
| # from ...model.base import ElasticDNNUtil | |
| # import torch.optim | |
| # import tqdm | |
| # import torch.nn.functional as F | |
| # from torch import nn | |
| # from utils.dl.common.env import create_tbwriter | |
| # import os | |
| # import random | |
| # import numpy as np | |
| # from copy import deepcopy | |
| # from utils.dl.common.model import LayerActivation2, get_module | |
| # from utils.common.log import logger | |
| # class ElasticDNN_Det_MDPretrainingWoFBSAlg(BaseAlg): | |
| # """ | |
| # TODO: fine-tuned FM -> init MD -> trained MD -> construct indexes (only between similar weights) and fine-tune | |
| # """ | |
| # def get_required_models_schema(self) -> Schema: | |
| # return Schema({ | |
| # 'fm': ElasticDNN_OfflineFMModel, | |
| # 'md': ElasticDNN_OfflineMDModel | |
| # }) | |
| # def get_required_hyp_schema(self) -> Schema: | |
| # return Schema({ | |
| # 'launch_tbboard': bool, | |
| # 'samples_size': any, | |
| # 'generate_md_width_ratio': int, | |
| # 'train_batch_size': int, | |
| # 'val_batch_size': int, | |
| # 'num_workers': int, | |
| # 'optimizer': str, | |
| # 'optimizer_args': dict, | |
| # 'scheduler': str, | |
| # 'scheduler_args': dict, | |
| # 'num_iters': int, | |
| # 'val_freq': int, | |
| # 'distill_loss_weight': float | |
| # }) | |
| # def run(self, scenario: Scenario, hyps: Dict) -> Dict[str, Any]: | |
| # super().run(scenario, hyps) | |
| # assert isinstance(self.models['md'], ElasticDNN_OfflineMDModel) # for auto completion | |
| # assert isinstance(self.models['fm'], ElasticDNN_OfflineFMModel) # for auto completion | |
| # # 1. add FBS | |
| # device = self.models['md'].device | |
| # if self.models['md'].models_dict['main'] == -1: | |
| # logger.info(f'init master DNN by reducing width of an adapted foundation model (already tuned by LoRA)...') | |
| # before_fm_model = deepcopy(self.models['fm'].models_dict['main']) | |
| # lora_util = self.models['fm'].get_lora_util() | |
| # sample = hyps['samples_size'] | |
| # if isinstance(sample, (tuple, list)) and isinstance(sample[0], int): | |
| # sample = torch.rand(hyps['samples_size']).to(device) | |
| # lora_absorbed_fm_model = lora_util.absorb_lora_and_recover_net_structure(self.models['fm'].models_dict['main'], | |
| # sample) | |
| # self.models['fm'].models_dict['main'] = lora_absorbed_fm_model | |
| # master_dnn = self.models['fm'].generate_md_by_reducing_width(hyps['generate_md_width_ratio'], | |
| # sample) | |
| # self.models['fm'].models_dict['main'] = before_fm_model | |
| # self.models['md'].models_dict['main'] = master_dnn | |
| # self.models['md'].to(device) | |
| # # 2. train (knowledge distillation, index relationship) | |
| # offline_datasets = scenario.get_offline_datasets() | |
| # train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()]) | |
| # val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()]) | |
| # train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'], hyps['num_workers'], | |
| # True, None)) | |
| # val_loader = build_dataloader(val_dataset, hyps['val_batch_size'], hyps['num_workers'], | |
| # False, False) | |
| # # logger.info(f'FM acc: {self.models["fm"].get_accuracy(val_loader):.4f}') | |
| # # 2.1 train whole master DNN (knowledge distillation) | |
| # for p in master_dnn.parameters(): | |
| # p.requires_grad = True | |
| # self.models['md'].to_train_mode() | |
| # optimizer = torch.optim.__dict__[hyps['optimizer']]([ | |
| # {'params': self.models['md'].models_dict['main'].parameters(), **hyps['optimizer_args']} | |
| # ]) | |
| # scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args']) | |
| # tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard']) | |
| # pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True) | |
| # best_avg_val_acc = 0. | |
| # md_output_hook = None | |
| # for iter_index in pbar: | |
| # self.models['md'].to_train_mode() | |
| # self.models['fm'].to_eval_mode() | |
| # # rand_sparsity = random.random() * (hyps['max_sparsity'] - hyps['min_sparsity']) + hyps['min_sparsity'] | |
| # # elastic_dnn_util.set_master_dnn_sparsity(self.models['md'].models_dict['main'], rand_sparsity) | |
| # if md_output_hook is None: | |
| # md_output_hook = self.models['md'].get_feature_hook() | |
| # fm_output_hook = self.models['fm'].get_feature_hook() | |
| # x, y = next(train_loader) | |
| # if isinstance(x, dict): | |
| # for k, v in x.items(): | |
| # if isinstance(v, torch.Tensor): | |
| # x[k] = v.to(device) | |
| # y = y.to(device) | |
| # else: | |
| # x, y = x.to(device), y.to(device) | |
| # with torch.no_grad(): | |
| # fm_output = self.models['fm'].infer(x) | |
| # task_loss = self.models['md'].forward_to_get_task_loss(x, y) | |
| # md_output = md_output_hook.output | |
| # fm_output = fm_output_hook.output | |
| # distill_loss = hyps['distill_loss_weight'] * self.models['md'].get_distill_loss(md_output, fm_output) | |
| # total_loss = task_loss + distill_loss | |
| # optimizer.zero_grad() | |
| # total_loss.backward() | |
| # optimizer.step() | |
| # scheduler.step() | |
| # if (iter_index + 1) % hyps['val_freq'] == 0: | |
| # # elastic_dnn_util.clear_cached_channel_attention_in_master_dnn(self.models['md'].models_dict['main']) | |
| # md_output_hook.remove() | |
| # md_output_hook = None | |
| # fm_output_hook.remove() | |
| # fm_output_hook = None | |
| # cur_md = self.models['md'].models_dict['main'] | |
| # md_for_test = deepcopy(self.models['md'].models_dict['main']) | |
| # val_acc = 0. | |
| # self.models['md'].models_dict['main'] = md_for_test | |
| # self.models['md'].to_eval_mode() | |
| # val_acc = self.models['md'].get_accuracy(val_loader) | |
| # self.models['md'].models_dict['main'] = cur_md | |
| # self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_last.pt')) | |
| # self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt')) | |
| # if val_acc > best_avg_val_acc: | |
| # best_avg_val_acc = val_acc | |
| # self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_best.pt')) | |
| # self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt')) | |
| # tb_writer.add_scalars(f'losses', dict(task=task_loss, distill=distill_loss, total=total_loss), iter_index) | |
| # pbar.set_description(f'loss: {total_loss:.6f}') | |
| # if (iter_index + 1) >= hyps['val_freq']: | |
| # tb_writer.add_scalar(f'accs/val_acc', val_acc, iter_index) | |
| # pbar.set_description(f'loss: {total_loss:.6f}, val_acc: {val_acc:.4f}') | |
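| # A hypothetical `hyps` dict matching get_required_hyp_schema above; every value is illustrative only | |
| # and not taken from any real experiment config: | |
| # hyps = { | |
| #     'launch_tbboard': False, | |
| #     'samples_size': (1, 3, 224, 224), | |
| #     'generate_md_width_ratio': 4, | |
| #     'train_batch_size': 8, | |
| #     'val_batch_size': 16, | |
| #     'num_workers': 4, | |
| #     'optimizer': 'AdamW', | |
| #     'optimizer_args': {'lr': 1e-4, 'weight_decay': 0.01}, | |
| #     'scheduler': 'LambdaLR', | |
| #     'scheduler_args': {'lr_lambda': lambda _: 1.0}, | |
| #     'num_iters': 80000, | |
| #     'val_freq': 4000, | |
| #     'distill_loss_weight': 1.0 | |
| # } | |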
| # if __name__ == '__main__': | |
| # model = glip_model('new_impl/cv/glip/object_detection/pretrained_model/glip_Swin_T_O365_GoldG.yaml','new_impl/cv/glip/object_detection/pretrained_model/glip_tiny_model_o365_goldg_cc_sbu.pth').cuda() | |
| # model.eval() | |
| # # print(model) | |
| # # exit() | |
| # # config = CLIPConfig.from_pretrained('openai/clip-vit-base-patch16') | |
| # # print(config) | |
| # # # test 1: single image inference | |
| # from PIL import Image, ImageDraw | |
| # import requests | |
| # import numpy as np | |
| # ori_image = Image.open('new_impl/cv/glip/object_detection/9472793441_b7822c00de_z.jpg').convert("RGB") | |
| # image = [np.asarray(ori_image)[:, :, [2, 1, 0]]] | |
| # text = 'sofa . remote . dog . person . car . sky . plane .' | |
| # target = torch.Tensor() | |
| # o = model(image, text) | |
| # o = model._post_process(o[0]) | |
| # print(o) | |
| # bboxes = o.bbox.cpu() | |
| # a = ImageDraw.ImageDraw(ori_image) | |
| # for box in bboxes: | |
| # box = box.int() | |
| # a.rectangle(((box[0], box[1]), (box[2], box[3])), fill=None, outline='red', width=2) | |
| # ori_image.save('test.jpg') | |
| # # print(o.logits_per_image.softmax(dim=1)) | |
| # # o = model(image, torch.load('dnns/clip/test_input_embed.pth'), False) | |
| # # # print(o) | |
| # # print(o.logits_per_image.softmax(dim=1)) | |
| # # exit() | |
| # # test 2: normal training using clip loss (batch) | |
| # from data import get_dataset, build_dataloader | |
| # from torchvision.transforms import Compose, ToTensor, Resize | |
| # dataset = get_dataset('Caltech256', '/data/zql/datasets/Caltech-256/data/caltech256/256_ObjectCategories/', 'train', transform=Compose([ | |
| # Resize((32, 32)), ToTensor() | |
| # ])) | |
| # dataloader = build_dataloader(dataset, 8, 0, True, None) | |
| # from PIL import Image | |
| # import requests | |
| # images, labels = next(iter(dataloader)) | |
| # # torch.save(images, 'dnns/clip/test_image.pth') | |
| # classes = dataset.classes | |
| # text = [f"a photo of a {classes[i]}" for i in labels] # should be ground truth | |
| # print(text) | |
| # print(images.size()) | |
| # o = model(images, text, True) | |
| # print(o) | |
| # print(o.logits_per_image.softmax(dim=1)) | |
| # # o = model(image, torch.load('dnns/clip/test_input_embed.pth'), False) | |
| # # # print(o) | |
| # # print(o.logits_per_image.softmax(dim=1)) |