import timm
from timm.models._factory import load_checkpoint
import torch
import os
import time  # needed by get_model_latency_2 for CPU timing
from typing import List, Union, Optional, Tuple
from torch import nn
from torch.jit import Final
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from utils.dl.common.model import get_model_device, set_module, get_module, get_model_latency, get_model_size, LayerActivation3
import torch.nn.functional as F
from utils.common.log import logger
from transformers import AutoTokenizer
from maskrcnn_benchmark.modeling.detector.generalized_vl_rcnn import GeneralizedVLRCNN
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.structures.bounding_box import BoxList
from torchvision import transforms as T
import matplotlib.pyplot as plt
import nltk
import re
from copy import deepcopy
from abc import ABC, abstractmethod
from methods.elasticdnn.pipeline.offline.fm_to_md.base import FM_to_MD_Util
from methods.elasticdnn.pipeline.offline.fm_lora.base import FMLoRA_Util, LoRA
from new_impl.cv.elasticdnn.api.model import ElasticDNN_OfflineFMModel, ElasticDNN_OfflineMDModel
from methods.elasticdnn.model.base import Abs, KTakesAll, ElasticDNNUtil, Layer_WrappedWithFBS
from transformers.models.bert.modeling_bert import BertSelfAttention
from transformers import BertConfig
import math
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
def collect_mm_fn(batch):
    # Collate function for multimodal detection batches; keeps images/targets as lists
    # and skips samples whose target list is empty.
    if len(batch[0]) == 2:
        batch_dict = {'images': [], 'targets': []}
    else:
        batch_dict = {'images': [], 'targets': [], 'info_imgs': [], 'ids': []}
    for item in batch:
        if len(item) == 2:
            img, new_target = item
            if len(new_target) == 0:
                continue
            batch_dict['images'].append(img)
            batch_dict['targets'].append(new_target)
        else:
            img, new_target, info_imgs, ids = item
            if len(new_target) == 0:
                continue
            batch_dict['images'].append(img)
            batch_dict['targets'].append(new_target)
            batch_dict['info_imgs'].append(info_imgs)
            batch_dict['ids'].append(ids)
    return batch_dict, torch.Tensor([0])
def run_ner(caption):
    """Extract noun phrases from a caption and return their character spans as tokens_positive."""
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
    relevant_phrases = noun_phrases
    labels = noun_phrases
    tokens_positive = []
    for entity, label in zip(relevant_phrases, labels):
        try:
            # search all occurrences and mark them as different entities;
            # this may raise if the phrase contains regex metacharacters
            for m in re.finditer(entity, caption.lower()):
                tokens_positive.append([[m.start(), m.end()]])
        except Exception:
            print("noun entities:", noun_phrases)
            print("entity:", entity)
            print("caption:", caption.lower())
    return tokens_positive
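# Minimal usage sketch for run_ner (illustrative; requires the nltk 'punkt' and
# 'averaged_perceptron_tagger' resources to be downloaded beforehand):
#
#   caption = "a dog . a backpack . a cat"
#   tokens_positive = run_ner(caption)
#   # tokens_positive holds one [[start, end]] character span per detected noun phrase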
def build_transform(cfg, min_image_size): | |
""" | |
Creates a basic transformation that was used to train the models | |
""" | |
# we are loading images with OpenCV, so we don't need to convert them | |
# to BGR, they are already! So all we need to do is to normalize | |
# by 255 if we want to convert to BGR255 format, or flip the channels | |
# if we want it to be in RGB in [0-1] range. | |
if cfg.INPUT.TO_BGR255: | |
to_bgr_transform = T.Lambda(lambda x: x * 255) | |
else: | |
to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]]) | |
normalize_transform = T.Normalize( | |
mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD | |
) | |
transform = T.Compose( | |
[ | |
T.ToPILImage(), | |
T.Resize(min_image_size) if min_image_size is not None else lambda x: x, | |
T.ToTensor(), | |
to_bgr_transform, | |
normalize_transform, | |
] | |
) | |
return transform | |
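# Minimal usage sketch for build_transform (illustrative; assumes `config` is a merged
# maskrcnn_benchmark cfg and `image` is a BGR numpy array loaded with OpenCV):
#
#   transform = build_transform(config, min_image_size=800)
#   image_tensor = transform(image)   # normalized tensor ready for the detector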
def remove_punctuation(text: str) -> str: | |
punct = ['|', ':', ';', '@', '(', ')', '[', ']', '{', '}', '^', | |
'\'', '\"', '’', '`', '?', '$', '%', '#', '!', '&', '*', '+', ',', '.' | |
] | |
for p in punct: | |
text = text.replace(p, '') | |
return text.strip() | |
def create_positive_map_label_to_token_from_positive_map(positive_map, plus=0): | |
positive_map_label_to_token = {} | |
for i in range(len(positive_map)): | |
positive_map_label_to_token[i + plus] = torch.nonzero(positive_map[i], as_tuple=True)[0].tolist() | |
return positive_map_label_to_token | |
def create_positive_map(tokenized, tokens_positive): | |
"""construct a map such that positive_map[i,j] = True iff box i is associated to token j""" | |
positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float) | |
for j, tok_list in enumerate(tokens_positive): | |
for (beg, end) in tok_list: | |
try: | |
beg_pos = tokenized.char_to_token(beg) | |
end_pos = tokenized.char_to_token(end - 1) | |
except Exception as e: | |
print("beg:", beg, "end:", end) | |
print("token_positive:", tokens_positive) | |
# print("beg_pos:", beg_pos, "end_pos:", end_pos) | |
raise e | |
            if beg_pos is None:
                try:
                    beg_pos = tokenized.char_to_token(beg + 1)
                    if beg_pos is None:
                        beg_pos = tokenized.char_to_token(beg + 2)
                except Exception:
                    beg_pos = None
            if end_pos is None:
                try:
                    end_pos = tokenized.char_to_token(end - 2)
                    if end_pos is None:
                        end_pos = tokenized.char_to_token(end - 3)
                except Exception:
                    end_pos = None
if beg_pos is None or end_pos is None: | |
continue | |
assert beg_pos is not None and end_pos is not None | |
positive_map[j, beg_pos: end_pos + 1].fill_(1) | |
return positive_map / (positive_map.sum(-1)[:, None] + 1e-6) | |
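# Minimal usage sketch tying the helpers above together (illustrative; assumes a fast
# BERT-style tokenizer so that char_to_token is available):
#
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   caption = "a dog . a cat"
#   tokens_positive = run_ner(caption)                              # character spans of noun phrases
#   tokenized = tokenizer(caption, return_tensors="pt")
#   positive_map = create_positive_map(tokenized, tokens_positive)  # (num_phrases, 256)
#   # plus=1 matches the VLDYHEAD branch used in Glip.forward below
#   label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map, plus=1)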
def find_noun_phrases(caption: str) -> List[str]: | |
caption = caption.lower() | |
tokens = nltk.word_tokenize(caption) | |
pos_tags = nltk.pos_tag(tokens) | |
grammar = "NP: {<DT>?<JJ.*>*<NN.*>+}" | |
cp = nltk.RegexpParser(grammar) | |
result = cp.parse(pos_tags) | |
noun_phrases = list() | |
for subtree in result.subtrees(): | |
if subtree.label() == 'NP': | |
noun_phrases.append(' '.join(t[0] for t in subtree.leaves())) | |
return noun_phrases | |
class Glip(nn.Module): | |
def __init__(self, config, pretrain_path, min_image_size=None,confidence_threshold=0.7): | |
super(Glip, self).__init__() | |
state_dict = torch.load(pretrain_path)['model'] | |
self.min_image_size = min_image_size | |
self.cfg = config | |
self.confidence_threshold = confidence_threshold | |
        self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL.LANGUAGE_BACKBONE.MODEL_PATH)
        self.device = torch.device(config.MODEL.DEVICE)
for k in list(state_dict.keys()): | |
if k.startswith('module'): | |
new_k = k.replace('module.', '') | |
state_dict[new_k] = state_dict.pop(k) | |
self.model = GeneralizedVLRCNN(config) | |
self.model.load_state_dict(state_dict, strict=False) | |
# self.transform = build_transform(config, min_image_size) | |
def forward(self, images, targets, for_training=None): | |
# img_list = [] | |
# for image in images: | |
# img_list.append(self.transform(image).to(self.device)) | |
# if isinstance(texts, list): | |
# # we directly provided a list of category names | |
# caption_string = "" | |
# tokens_positive = [] | |
# seperation_tokens = " . " | |
# for word in texts: | |
# tokens_positive.append([len(caption_string), len(caption_string) + len(word)]) | |
# caption_string += word | |
# caption_string += seperation_tokens | |
# tokenized = self.tokenizer([caption_string], return_tensors="pt") | |
# tokens_positive = [tokens_positive] | |
# texts = [caption_string] | |
# print(tokens_positive) | |
# else: | |
        device = self.device
images = [image.to(device) for image in images] | |
targets = [target.to(device) for target in targets] | |
texts = [t.get_field("caption") for t in targets if "caption" in t.fields()] | |
positive_map = [] | |
# if custom_entity is None: | |
# tokens_positive = self.run_ner(texts) | |
# print(tokens_positive) | |
# process positive map | |
        if not self.training:
            tokens_positive = run_ner(texts[0])
tokenized = self.tokenizer(texts, return_tensors="pt") | |
positive_map = create_positive_map(tokenized, tokens_positive) | |
if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD": | |
plus = 1 | |
else: | |
plus = 0 | |
positive_map = create_positive_map_label_to_token_from_positive_map(positive_map, plus=plus) | |
else: | |
for i, text in enumerate(texts): | |
tokenized = self.tokenizer(text, return_tensors="pt") | |
tokens_positive = targets[i].get_field('tokens_positive') | |
positive_map.append(create_positive_map(tokenized, tokens_positive)) | |
positive_map = torch.cat(positive_map, dim=0).to(device) | |
if self.training: | |
proposal_losses = self.model(images, targets, texts, positive_map=positive_map) | |
return proposal_losses | |
else: | |
proposals, token_logits, dot_product_logits = self.model(images, targets, texts, positive_map=positive_map) | |
proposal = self._post_process(proposals[0]) | |
return proposal, token_logits, dot_product_logits | |
def _post_process_fixed_thresh(self, predictions): | |
scores = predictions.get_field("scores") | |
labels = predictions.get_field("labels").tolist() | |
thresh = scores.clone() | |
for i, lb in enumerate(labels): | |
if isinstance(self.confidence_threshold, float): | |
thresh[i] = self.confidence_threshold | |
elif len(self.confidence_threshold) == 1: | |
thresh[i] = self.confidence_threshold[0] | |
else: | |
thresh[i] = self.confidence_threshold[lb - 1] | |
keep = torch.nonzero(scores > thresh).squeeze(1) | |
predictions = predictions[keep] | |
scores = predictions.get_field("scores") | |
_, idx = scores.sort(0, descending=True) | |
return predictions[idx] | |
def _post_process(self, predictions, threshold=0.5): | |
scores = predictions.get_field("scores") | |
labels = predictions.get_field("labels").tolist() | |
thresh = scores.clone() | |
for i, lb in enumerate(labels): | |
if isinstance(self.confidence_threshold, float): | |
thresh[i] = threshold | |
elif len(self.confidence_threshold) == 1: | |
thresh[i] = threshold | |
else: | |
thresh[i] = self.confidence_threshold[lb - 1] | |
keep = torch.nonzero(scores > thresh).squeeze(1) | |
predictions = predictions[keep] | |
scores = predictions.get_field("scores") | |
_, idx = scores.sort(0, descending=True) | |
return predictions[idx] | |
# @torch.no_grad() | |
# def clip_vit_b_16(): | |
# # https://huggingface.co/openai/clip-vit-base-patch16 | |
# model = CLIPModelCanReceiveTextEmbeds.from_pretrained("openai/clip-vit-base-patch16") | |
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16") | |
# print(model) | |
# from PIL import Image | |
# import requests | |
# image = Image.open('/data/zql/datasets/Caltech-256/data/caltech256/256_ObjectCategories/003.backpack/003_0001.jpg') | |
# inputs = processor(text=["a photo of a dog", "a photo of a backpack", "a photo of a cat"], images=image, return_tensors="pt", padding=True) | |
# print(inputs) | |
# from utils.dl.common.model import LayerActivation2, get_module | |
# input_embed_hook = LayerActivation2(get_module(model, 'text_model.embeddings')) | |
# outputs = model(**inputs) | |
# logits_per_image = outputs.logits_per_image # this is the image-text similarity score | |
# probs = logits_per_image.softmax(dim=1) | |
# print(probs) | |
# input_embed = input_embed_hook.output | |
# input_embed_hook.remove() | |
# torch.save(input_embed, os.path.join(os.path.dirname(__file__), './test_input_embed.pth')) | |
# print('embed', input_embed.size()) | |
# del inputs['input_ids'] | |
# inputs['input_embeds'] = input_embed | |
# outputs = model(**inputs) | |
# logits_per_image = outputs.logits_per_image # this is the image-text similarity score | |
# probs = logits_per_image.softmax(dim=1) | |
# print(probs) | |
def glip_model(config_path, pretrain_path):
    # build the GLIP detector from a maskrcnn_benchmark config file and a pretrained checkpoint
    cfg.merge_from_file(config_path)
    return cfg, Glip(cfg, pretrain_path)
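# Minimal usage sketch for glip_model (the config/checkpoint paths are illustrative
# placeholders, not files shipped with this repo):
#
#   config, glip = glip_model('configs/pretrain/glip_Swin_T_O365_GoldG.yaml',
#                             'pretrained/glip_tiny_model_o365_goldg_cc_sbu.pth')
#   glip.eval()
#   # at inference time, `images` is a list of image tensors and `targets` a list of
#   # BoxList objects whose "caption" field holds the text prompt
#   proposal, token_logits, dot_product_logits = glip(images, targets)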
class ToQKV_WrappedWithLoRA(nn.Module): | |
def __init__(self, fc: nn.Linear, ab_r: int): | |
super(ToQKV_WrappedWithLoRA, self).__init__() | |
self.fc = fc | |
self.ab = self.create_ab_as_linear(fc.weight.data, ab_r) | |
def create_ab_as_linear(self, fc_weight: torch.Tensor, ab_r: int): | |
res = nn.Sequential( | |
LoRA(fc_weight.size(1), fc_weight.size(0) // ab_r, bias=False), | |
LoRA(fc_weight.size(0) // ab_r, fc_weight.size(0), bias=False) | |
).to(fc_weight.device) | |
nn.init.kaiming_uniform_(res[0].weight, a=5 ** 0.5) | |
nn.init.zeros_(res[1].weight) | |
return res | |
def forward(self, x): | |
x1 = self.fc(x) | |
x2 = self.ab(x) | |
return x1 + x2 | |
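# Minimal usage sketch for ToQKV_WrappedWithLoRA (the module path is illustrative;
# ab_r is the rank-reduction factor of the LoRA branch):
#
#   name = 'model.language_backbone.body.model.encoder.layer.0.attention.self.query'
#   set_module(fm, name, ToQKV_WrappedWithLoRA(get_module(fm, name), ab_r=8))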
def get_model_latency_2(model: torch.nn.Module, sample: dict, sample_num: int,
                        device: str, warmup_sample_num: int, return_detail=False):
    """Get the latency (inference time) of a PyTorch model on a keyword-argument sample.
    Reference: https://deci.ai/resources/blog/measure-inference-time-deep-neural-networks/
    Args:
        model (torch.nn.Module): A PyTorch model.
        sample (dict): Keyword arguments passed to the model as `model(**sample)`.
        sample_num (int): Number of timed inferences; their average latency is returned.
        device (str): Typically 'cpu' or 'cuda'.
        warmup_sample_num (int): Number of dummy inferences run first to warm up the device and avoid measuring cold-start overhead.
        return_detail (bool, optional): Besides the average latency, also return every measured value. Defaults to False.
    Returns:
        Union[float, Tuple[float, List[float]]]: The average latency (and all latency measurements) of :attr:`model`.
    """
# if isinstance(model_input_size, tuple): | |
# dummy_input = torch.rand(model_input_size).to(device) | |
# else: | |
# dummy_input = model_input_size | |
model = model.to(device) | |
model.eval() | |
# warm up | |
with torch.no_grad(): | |
for _ in range(warmup_sample_num): | |
model(**sample) | |
infer_time_list = [] | |
if device == 'cuda' or 'cuda' in str(device): | |
with torch.no_grad(): | |
for _ in range(sample_num): | |
s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) | |
s.record() | |
model(**sample) | |
e.record() | |
torch.cuda.synchronize() | |
cur_model_infer_time = s.elapsed_time(e) / 1000. | |
infer_time_list += [cur_model_infer_time] | |
else: | |
with torch.no_grad(): | |
for _ in range(sample_num): | |
start = time.time() | |
model(**sample) | |
cur_model_infer_time = time.time() - start | |
infer_time_list += [cur_model_infer_time] | |
avg_infer_time = sum(infer_time_list) / sample_num | |
if return_detail: | |
return avg_infer_time, infer_time_list | |
return avg_infer_time | |
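# Minimal usage sketch for get_model_latency_2 (illustrative; the sample dict mirrors the
# keyword arguments of Glip.forward):
#
#   sample = {'images': images, 'targets': targets}
#   avg_s = get_model_latency_2(glip, sample, sample_num=20, device='cuda', warmup_sample_num=20)
#   avg_s, all_s = get_model_latency_2(glip, sample, 20, 'cuda', 20, return_detail=True)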
class WindowAttention(nn.Module): | |
""" Window based multi-head self attention (W-MSA) module with relative position bias. | |
It supports both of shifted and non-shifted window. | |
Args: | |
dim (int): Number of input channels. | |
window_size (tuple[int]): The height and width of the window. | |
num_heads (int): Number of attention heads. | |
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True | |
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set | |
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 | |
proj_drop (float, optional): Dropout ratio of output. Default: 0.0 | |
""" | |
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): | |
super().__init__() | |
self.dim = dim | |
self.window_size = window_size # Wh, Ww | |
self.num_heads = num_heads | |
head_dim = dim // num_heads | |
self.scale = qk_scale or head_dim ** -0.5 | |
# define a parameter table of relative position bias | |
self.relative_position_bias_table = nn.Parameter( | |
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH | |
# get pair-wise relative position index for each token inside the window | |
coords_h = torch.arange(self.window_size[0]) | |
coords_w = torch.arange(self.window_size[1]) | |
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww | |
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww | |
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 | |
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 | |
relative_coords[:, :, 1] += self.window_size[1] - 1 | |
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 | |
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww | |
self.register_buffer("relative_position_index", relative_position_index) | |
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |
self.attn_drop = nn.Dropout(attn_drop) | |
self.proj = nn.Linear(dim, dim) | |
self.proj_drop = nn.Dropout(proj_drop) | |
trunc_normal_(self.relative_position_bias_table, std=.02) | |
self.softmax = nn.Softmax(dim=-1) | |
def forward(self, x, mask=None): | |
""" Forward function. | |
Args: | |
x: input features with shape of (num_windows*B, N, C) | |
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None | |
""" | |
B_, N, C = x.shape | |
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) | |
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) | |
q = q * self.scale | |
attn = (q @ k.transpose(-2, -1)) | |
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( | |
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH | |
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww | |
attn = attn + relative_position_bias.unsqueeze(0) | |
if mask is not None: | |
nW = mask.shape[0] | |
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) | |
attn = attn.view(-1, self.num_heads, N, N) | |
attn = self.softmax(attn) | |
else: | |
attn = self.softmax(attn) | |
attn = self.attn_drop(attn) | |
x = (attn @ v).transpose(1, 2).reshape(B_, N, -1) | |
x = self.proj(x) | |
x = self.proj_drop(x) | |
return x | |
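# Minimal usage sketch for WindowAttention (dimensions are illustrative; the input has
# shape (num_windows*B, N, C) with N = window_size[0] * window_size[1]):
#
#   attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
#   x = torch.randn(8, 49, 96)
#   out = attn(x)   # -> (8, 49, 96)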
class BiMultiHeadAttention(nn.Module): | |
def __init__(self, v_dim, l_dim, embed_dim, num_heads, dropout=0.1, cfg=None): | |
super(BiMultiHeadAttention, self).__init__() | |
self.embed_dim = embed_dim | |
self.num_heads = num_heads | |
self.head_dim = embed_dim // num_heads | |
self.v_dim = v_dim | |
self.l_dim = l_dim | |
assert ( | |
self.head_dim * self.num_heads == self.embed_dim | |
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." | |
self.scale = self.head_dim ** (-0.5) | |
self.dropout = dropout | |
self.v_proj = nn.Linear(self.v_dim, self.embed_dim) | |
self.l_proj = nn.Linear(self.l_dim, self.embed_dim) | |
self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim) | |
self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim) | |
self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim) | |
self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim) | |
self.stable_softmax_2d = cfg.MODEL.DYHEAD.FUSE_CONFIG.STABLE_SOFTMAX_2D | |
self.clamp_min_for_underflow = cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MIN_FOR_UNDERFLOW | |
self.clamp_max_for_overflow = cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MAX_FOR_OVERFLOW | |
self._reset_parameters() | |
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): | |
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() | |
def _reset_parameters(self): | |
nn.init.xavier_uniform_(self.v_proj.weight) | |
self.v_proj.bias.data.fill_(0) | |
nn.init.xavier_uniform_(self.l_proj.weight) | |
self.l_proj.bias.data.fill_(0) | |
nn.init.xavier_uniform_(self.values_v_proj.weight) | |
self.values_v_proj.bias.data.fill_(0) | |
nn.init.xavier_uniform_(self.values_l_proj.weight) | |
self.values_l_proj.bias.data.fill_(0) | |
nn.init.xavier_uniform_(self.out_v_proj.weight) | |
self.out_v_proj.bias.data.fill_(0) | |
nn.init.xavier_uniform_(self.out_l_proj.weight) | |
self.out_l_proj.bias.data.fill_(0) | |
def forward(self, v, l, attention_mask_l=None): | |
bsz, tgt_len, embed_dim = v.size() | |
query_states = self.v_proj(v) * self.scale | |
key_states = self._shape(self.l_proj(l), -1, bsz) | |
value_v_states = self._shape(self.values_v_proj(v), -1, bsz) | |
value_l_states = self._shape(self.values_l_proj(l), -1, bsz) | |
proj_shape = (bsz * self.num_heads, -1, self.head_dim) | |
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) | |
key_states = key_states.view(*proj_shape) | |
value_v_states = value_v_states.view(*proj_shape) | |
value_l_states = value_l_states.view(*proj_shape) | |
src_len = key_states.size(1) | |
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) | |
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): | |
raise ValueError( | |
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" | |
) | |
# attn_weights_l = nn.functional.softmax(attn_weights.transpose(1, 2), dim=-1) | |
if self.stable_softmax_2d: | |
attn_weights = attn_weights - attn_weights.max() | |
if self.clamp_min_for_underflow: | |
attn_weights = torch.clamp(attn_weights, min=-50000) # Do not increase -50000, data type half has quite limited range | |
if self.clamp_max_for_overflow: | |
attn_weights = torch.clamp(attn_weights, max=50000) # Do not increase 50000, data type half has quite limited range | |
attn_weights_T = attn_weights.transpose(1, 2) | |
attn_weights_l = (attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[ | |
0]) | |
if self.clamp_min_for_underflow: | |
attn_weights_l = torch.clamp(attn_weights_l, min=-50000) # Do not increase -50000, data type half has quite limited range | |
if self.clamp_max_for_overflow: | |
attn_weights_l = torch.clamp(attn_weights_l, max=50000) # Do not increase 50000, data type half has quite limited range | |
attn_weights_l = attn_weights_l.softmax(dim=-1) | |
if attention_mask_l is not None: | |
assert (attention_mask_l.dim() == 2) | |
attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1) | |
attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len) | |
attention_mask = attention_mask.masked_fill(attention_mask == 0, -9e15) | |
if attention_mask.size() != (bsz, 1, tgt_len, src_len): | |
raise ValueError( | |
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}" | |
) | |
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask | |
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) | |
attn_weights_v = nn.functional.softmax(attn_weights, dim=-1) | |
attn_probs_v = F.dropout(attn_weights_v, p=self.dropout, training=self.training) | |
attn_probs_l = F.dropout(attn_weights_l, p=self.dropout, training=self.training) | |
attn_output_v = torch.bmm(attn_probs_v, value_l_states) | |
attn_output_l = torch.bmm(attn_probs_l, value_v_states) | |
if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim): | |
raise ValueError( | |
f"`attn_output_v` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output_v.size()}" | |
) | |
if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim): | |
raise ValueError( | |
f"`attn_output_l` should be of size {(bsz, self.num_heads, src_len, self.head_dim)}, but is {attn_output_l.size()}" | |
) | |
attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim) | |
attn_output_v = attn_output_v.transpose(1, 2) | |
attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim) | |
attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim) | |
attn_output_l = attn_output_l.transpose(1, 2) | |
attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim) | |
attn_output_v = self.out_v_proj(attn_output_v) | |
attn_output_l = self.out_l_proj(attn_output_l) | |
return attn_output_v, attn_output_l | |
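# Minimal usage sketch for BiMultiHeadAttention (dimensions are illustrative; `config`
# must provide the MODEL.DYHEAD.FUSE_CONFIG.* flags read in __init__):
#
#   bi_attn = BiMultiHeadAttention(v_dim=256, l_dim=768, embed_dim=2048, num_heads=8, cfg=config)
#   v = torch.randn(2, 100, 256)   # visual tokens
#   l = torch.randn(2, 77, 768)    # language tokens
#   delta_v, delta_l = bi_attn(v, l)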
class BertSelfAttentionPrunable(BertSelfAttention): | |
def __init__(self): | |
config = BertConfig.from_pretrained('new_impl/cv/glip/object_detection/bert-base-uncased') | |
super(BertSelfAttentionPrunable, self).__init__(config) | |
def transpose_for_scores(self, x): | |
new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1) | |
x = x.view(new_x_shape) | |
return x.permute(0, 2, 1, 3) | |
def forward( | |
self, | |
hidden_states: torch.Tensor, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
head_mask: Optional[torch.FloatTensor] = None, | |
encoder_hidden_states: Optional[torch.FloatTensor] = None, | |
encoder_attention_mask: Optional[torch.FloatTensor] = None, | |
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, | |
output_attentions: Optional[bool] = False, | |
) -> Tuple[torch.Tensor]: | |
mixed_query_layer = self.query(hidden_states) | |
# If this is instantiated as a cross-attention module, the keys | |
# and values come from an encoder; the attention mask needs to be | |
# such that the encoder's padding tokens are not attended to. | |
is_cross_attention = encoder_hidden_states is not None | |
if is_cross_attention and past_key_value is not None: | |
# reuse k,v, cross_attentions | |
key_layer = past_key_value[0] | |
value_layer = past_key_value[1] | |
attention_mask = encoder_attention_mask | |
elif is_cross_attention: | |
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) | |
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) | |
attention_mask = encoder_attention_mask | |
elif past_key_value is not None: | |
key_layer = self.transpose_for_scores(self.key(hidden_states)) | |
value_layer = self.transpose_for_scores(self.value(hidden_states)) | |
key_layer = torch.cat([past_key_value[0], key_layer], dim=2) | |
value_layer = torch.cat([past_key_value[1], value_layer], dim=2) | |
else: | |
key_layer = self.transpose_for_scores(self.key(hidden_states)) | |
value_layer = self.transpose_for_scores(self.value(hidden_states)) | |
query_layer = self.transpose_for_scores(mixed_query_layer) | |
use_cache = past_key_value is not None | |
if self.is_decoder: | |
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. | |
# Further calls to cross_attention layer can then reuse all cross-attention | |
# key/value_states (first "if" case) | |
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of | |
# all previous decoder key/value_states. Further calls to uni-directional self-attention | |
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) | |
# if encoder bi-directional self-attention `past_key_value` is always `None` | |
past_key_value = (key_layer, value_layer) | |
# Take the dot product between "query" and "key" to get the raw attention scores. | |
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) | |
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": | |
query_length, key_length = query_layer.shape[2], key_layer.shape[2] | |
if use_cache: | |
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( | |
-1, 1 | |
) | |
else: | |
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) | |
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) | |
distance = position_ids_l - position_ids_r | |
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) | |
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility | |
if self.position_embedding_type == "relative_key": | |
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) | |
attention_scores = attention_scores + relative_position_scores | |
elif self.position_embedding_type == "relative_key_query": | |
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) | |
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) | |
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key | |
attention_scores = attention_scores / math.sqrt(self.attention_head_size) | |
if attention_mask is not None: | |
# Apply the attention mask is (precomputed for all layers in BertModel forward() function) | |
attention_scores = attention_scores + attention_mask | |
# Normalize the attention scores to probabilities. | |
attention_probs = nn.functional.softmax(attention_scores, dim=-1) | |
# This is actually dropping out entire tokens to attend to, which might | |
# seem a bit unusual, but is taken from the original Transformer paper. | |
attention_probs = self.dropout(attention_probs) | |
# Mask heads if we want to | |
if head_mask is not None: | |
attention_probs = attention_probs * head_mask | |
context_layer = torch.matmul(attention_probs, value_layer) | |
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |
new_context_layer_shape = context_layer.size()[:-2] + (self.query.out_features,) # NOTE: modified | |
context_layer = context_layer.view(new_context_layer_shape) | |
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) | |
if self.is_decoder: | |
outputs = outputs + (past_key_value,) | |
return outputs | |
    @staticmethod
    def init_from_exist_self_attn(attn: BertSelfAttention):
        # copy every nn.Module attribute (query/key/value, dropout, etc.) from an existing BertSelfAttention
        res = BertSelfAttentionPrunable()
for attr in dir(attn): | |
# if str(attr) in ['transpose_for_scores'] or str(attr).startswith('_'): | |
# continue | |
# if isinstance(getattr(attn, attr), nn.Module): | |
# print(attr) | |
if isinstance(getattr(attn, attr), nn.Module): | |
try: | |
# print(attr, 'ok') | |
setattr(res, attr, getattr(attn, attr)) | |
except Exception as e: | |
print(attr, str(e)) | |
return res | |
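# Minimal usage sketch: swap an existing BERT self-attention for the prunable variant
# (the layer index is illustrative):
#
#   bert_layer = fm.model.language_backbone.body.model.encoder.layer[0]
#   set_module(bert_layer, 'attention.self',
#              BertSelfAttentionPrunable.init_from_exist_self_attn(bert_layer.attention.self))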
class FM_to_MD_GLIP_Util(FM_to_MD_Util): | |
def init_md_from_fm_by_reducing_width_with_perf_test_2(self, fm: nn.Module, reducing_width_ratio: int, | |
samples: torch.Tensor) -> nn.Module: | |
fm_size = get_model_size(fm, True) | |
fm_latency = get_model_latency_2(fm, samples, 20, | |
get_model_device(fm), 20, False) | |
master_dnn = self.init_md_from_fm_by_reducing_width(fm, reducing_width_ratio) | |
master_dnn_size = get_model_size(master_dnn, True) | |
logger.debug(f'inited master DNN: {master_dnn}') | |
# from utils.dl.common.model import get_module | |
# print('after generating') | |
# get_module(fm, 'head').debug() | |
# get_module(master_dnn, 'head').debug() | |
# print('test master latency') | |
        master_dnn_latency = get_model_latency_2(master_dnn, samples, 20,
                                                 get_model_device(master_dnn), 20, False)
logger.info(f'init master DNN (w/o FBS yet) by reducing foundation model\'s width (by {reducing_width_ratio:d}x)') | |
logger.info(f'foundation model ({fm_size:.3f}MB, {fm_latency:.4f}s/sample) -> ' | |
f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample)\n' | |
f'(model size: ↓ {(fm_size / master_dnn_size):.2f}x, ' | |
f'latency: ↓ {(fm_latency / master_dnn_latency):.2f}x)') | |
return master_dnn | |
def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int, sparsity=0.0) -> nn.Module: | |
        # sparsity: mainly used to build a distilled model for the baseline algorithms; it ensures that
        # the distilled model has the same size as the model used by the online algorithm.
fm_vit = deepcopy(fm) | |
def _f(n): | |
return int(n // reducing_width_ratio) | |
# def _rand_indexes(n): | |
# return torch.randperm(n)[0: int(n // reducing_width_ratio)] | |
        def l1_max_indexes(p: torch.Tensor, dim=0):
            assert dim in [0, 1]
            assert p.dim() in [1, 2, 4]
            if dim == 1:
                p = p.T
            p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1)
            n = p.size(0)
            # keep the channels with the largest L1 norm, in their original order
            return p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio)].sort()[0]
def l1_max_indexes_with_sparsity(p: torch.Tensor, dim=0): | |
assert dim in [0, 1] | |
assert p.dim() in [1, 2, 4] | |
if dim == 1: | |
p = p.T | |
p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1) | |
n = p.size(0) | |
return p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio * (1 - sparsity))].sort()[0] | |
for layer_i, layer in enumerate(fm_vit.model.backbone.body.layers): | |
for block in layer.blocks: | |
ori_attn = block.attn | |
new_attn = WindowAttention(ori_attn.dim, ori_attn.window_size, ori_attn.num_heads, True, ori_attn.scale, 0., 0.) | |
new_attn.relative_position_index = ori_attn.relative_position_index | |
new_attn.relative_position_bias_table = ori_attn.relative_position_bias_table | |
new_attn.qkv = ori_attn.qkv | |
new_attn.attn_drop = ori_attn.attn_drop | |
new_attn.proj = ori_attn.proj | |
new_attn.proj_drop = ori_attn.proj_drop | |
set_module(block, 'attn', new_attn) | |
# first_attn = True | |
for layer_i, layer in enumerate(fm_vit.model.backbone.body.layers): | |
for block_i, block in enumerate(layer.blocks): | |
qkv = block.attn.qkv | |
new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features), | |
qkv.bias is not None, qkv.weight.device) | |
indexes = l1_max_indexes(qkv.weight.data, 0) | |
new_qkv.weight.data.copy_(qkv.weight.data[indexes]) | |
if qkv.bias is not None: | |
new_qkv.bias.data.copy_(qkv.bias.data[indexes]) | |
# fm_vit.model.backbone.body.layers[0].blocks.0.attn.qkv | |
set_module(fm_vit, f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.attn.qkv', new_qkv) | |
proj = block.attn.proj | |
new_proj = nn.Linear(_f(proj.in_features), proj.out_features, | |
proj.bias is not None, proj.weight.device) | |
new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)]) | |
if proj.bias is not None: | |
new_proj.bias.data.copy_(proj.bias.data) | |
set_module(fm_vit, f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.attn.proj', new_proj) | |
fc1 = block.mlp.fc1 | |
new_fc1 = nn.Linear(fc1.in_features, int(_f(fc1.out_features) * (1 - sparsity)), | |
fc1.bias is not None, fc1.weight.device) | |
indexes = l1_max_indexes_with_sparsity(fc1.weight.data, 0) | |
new_fc1.weight.data.copy_(fc1.weight.data[indexes]) | |
if fc1.bias is not None: | |
new_fc1.bias.data.copy_(fc1.bias.data[indexes]) | |
set_module(fm_vit, f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.mlp.fc1', new_fc1) | |
fc2 = block.mlp.fc2 | |
new_fc2 = nn.Linear(int(_f(fc2.in_features) * (1 - sparsity)), fc2.out_features, | |
fc2.bias is not None, fc2.weight.device) | |
new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes_with_sparsity(fc2.weight.data, 1)]) | |
if fc2.bias is not None: | |
new_fc2.bias.data.copy_(fc2.bias.data) | |
set_module(fm_vit, f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.mlp.fc2', new_fc2) | |
for block in fm_vit.model.language_backbone.body.model.encoder.layer: | |
set_module(block, 'attention.self', BertSelfAttentionPrunable.init_from_exist_self_attn(block.attention.self)) | |
for block_i, block in enumerate(fm_vit.model.language_backbone.body.model.encoder.layer): | |
for k in ['query', 'key', 'value']: | |
qkv = get_module(block, f'attention.self.{k}') | |
new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features), | |
qkv.bias is not None, qkv.weight.device) | |
indexes = l1_max_indexes(qkv.weight.data, 0) | |
new_qkv.weight.data.copy_(qkv.weight.data[indexes]) | |
if qkv.bias is not None: | |
new_qkv.bias.data.copy_(qkv.bias.data[indexes]) | |
set_module(block, f'attention.self.{k}', new_qkv) | |
proj = get_module(block, f'attention.output.dense') | |
new_proj = nn.Linear(_f(proj.in_features), proj.out_features, | |
proj.bias is not None, proj.weight.device) | |
new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)]) | |
if proj.bias is not None: | |
new_proj.bias.data.copy_(proj.bias.data) | |
set_module(block, f'attention.output.dense', new_proj) | |
fc1 = get_module(block, f'intermediate.dense') | |
new_fc1 = nn.Linear(fc1.in_features, int(_f(fc1.out_features) * (1 - sparsity)), | |
fc1.bias is not None, fc1.weight.device) | |
indexes = l1_max_indexes_with_sparsity(fc1.weight.data, 0) | |
new_fc1.weight.data.copy_(fc1.weight.data[indexes]) | |
if fc1.bias is not None: | |
new_fc1.bias.data.copy_(fc1.bias.data[indexes]) | |
set_module(block, f'intermediate.dense', new_fc1) | |
fc2 = get_module(block, f'output.dense') | |
new_fc2 = nn.Linear(int(_f(fc2.in_features) * (1 - sparsity)), fc2.out_features, | |
fc2.bias is not None, fc2.weight.device) | |
new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes_with_sparsity(fc2.weight.data, 1)]) | |
if fc2.bias is not None: | |
new_fc2.bias.data.copy_(fc2.bias.data) | |
set_module(block, f'output.dense', new_fc2) | |
for block_i, block in enumerate(fm_vit.model.rpn.head.dyhead_tower): | |
if block_i % 3 == 0: | |
tmp = block.b_attn.attn | |
tmp.head_dim = int(tmp.head_dim // reducing_width_ratio) | |
tmp.embed_dim = int(tmp.embed_dim // reducing_width_ratio) | |
set_module(block, 'b_attn.attn', tmp) | |
for k in ['v_proj', 'l_proj', 'values_v_proj', 'values_l_proj']: | |
qkv = get_module(block, f'b_attn.attn.{k}') | |
new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features), | |
qkv.bias is not None, qkv.weight.device) | |
indexes = l1_max_indexes(qkv.weight.data, 0) | |
new_qkv.weight.data.copy_(qkv.weight.data[indexes]) | |
if qkv.bias is not None: | |
new_qkv.bias.data.copy_(qkv.bias.data[indexes]) | |
set_module(block, f'b_attn.attn.{k}', new_qkv) | |
for k in ['out_v_proj', 'out_l_proj']: | |
qkv = get_module(block, f'b_attn.attn.{k}') | |
new_qkv = nn.Linear(_f(qkv.in_features), qkv.out_features, | |
qkv.bias is not None, qkv.weight.device) | |
new_qkv.weight.data.copy_(qkv.weight.data[:, l1_max_indexes(qkv.weight.data, 1)]) | |
if qkv.bias is not None: | |
new_qkv.bias.data.copy_(qkv.bias.data) | |
set_module(block, f'b_attn.attn.{k}', new_qkv) | |
elif block_i % 3 == 1: | |
set_module(block, 'attention.self', BertSelfAttentionPrunable.init_from_exist_self_attn(block.attention.self)) | |
for k in ['query', 'key', 'value']: | |
qkv = get_module(block, f'attention.self.{k}') | |
new_qkv = nn.Linear(qkv.in_features, _f(qkv.out_features), | |
qkv.bias is not None, qkv.weight.device) | |
indexes = l1_max_indexes(qkv.weight.data, 0) | |
new_qkv.weight.data.copy_(qkv.weight.data[indexes]) | |
if qkv.bias is not None: | |
new_qkv.bias.data.copy_(qkv.bias.data[indexes]) | |
set_module(block, f'attention.self.{k}', new_qkv) | |
proj = get_module(block, f'attention.output.dense') | |
new_proj = nn.Linear(_f(proj.in_features), proj.out_features, | |
proj.bias is not None, proj.weight.device) | |
new_proj.weight.data.copy_(proj.weight.data[:, l1_max_indexes(proj.weight.data, 1)]) | |
if proj.bias is not None: | |
new_proj.bias.data.copy_(proj.bias.data) | |
set_module(block, f'attention.output.dense', new_proj) | |
fc1 = get_module(block, f'intermediate.dense') | |
new_fc1 = nn.Linear(fc1.in_features, int(_f(fc1.out_features) * (1 - sparsity)), | |
fc1.bias is not None, fc1.weight.device) | |
indexes = l1_max_indexes_with_sparsity(fc1.weight.data, 0) | |
new_fc1.weight.data.copy_(fc1.weight.data[indexes]) | |
if fc1.bias is not None: | |
new_fc1.bias.data.copy_(fc1.bias.data[indexes]) | |
set_module(block, f'intermediate.dense', new_fc1) | |
fc2 = get_module(block, f'output.dense') | |
new_fc2 = nn.Linear(int(_f(fc2.in_features) * (1 - sparsity)), fc2.out_features, | |
fc2.bias is not None, fc2.weight.device) | |
new_fc2.weight.data.copy_(fc2.weight.data[:, l1_max_indexes_with_sparsity(fc2.weight.data, 1)]) | |
if fc2.bias is not None: | |
new_fc2.bias.data.copy_(fc2.bias.data) | |
set_module(block, f'output.dense', new_fc2) | |
# reduce dim_embedding | |
# if name.endswith('patch_embed.proj'): | |
# continue | |
# new_layer = nn.Conv2d(module.in_channels, _f(module.out_channels), module.kernel_size, module.stride, | |
# module.padding, module.dilation, module.groups, module.bias is not None, module.padding_mode, | |
# module.weight.device) | |
# rand_indexes = l1_max_indexes(module.weight.data) | |
# new_layer.weight.data.copy_(module.weight.data[rand_indexes]) | |
# if new_layer.bias is not None: | |
# new_layer.bias.data.copy_(module.bias.data[rand_indexes]) | |
# fm_vit.cls_token.data = fm_vit.cls_token.data[:, :, rand_indexes] | |
# fm_vit.pos_embed.data = fm_vit.pos_embed.data[:, :, rand_indexes] | |
# elif isinstance(module, nn.Linear): | |
# if 'head' in name: | |
# continue | |
# new_layer = nn.Linear(_f(module.in_features), module.out_features, | |
# module.bias is not None, module.weight.device) | |
# new_layer.weight.data.copy_(module.weight.data[:, l1_max_indexes(module.weight.data, 1)]) | |
# if new_layer.bias is not None: | |
# new_layer.bias.data.copy_(module.bias.data) | |
# else: | |
# first_attn = False | |
# if first_attn: | |
# first_attn = False | |
# new_layer = nn.Linear(module.in_features, _f(module.out_features), | |
# module.bias is not None, module.weight.device) | |
# rand_indexes = l1_max_indexes(module.weight.data) | |
# new_layer.weight.data.copy_(module.weight.data[rand_indexes]) | |
# if new_layer.bias is not None: | |
# new_layer.bias.data.copy_(module.bias.data[rand_indexes]) | |
# else: | |
# new_layer = nn.Linear(_f(module.in_features), _f(module.out_features), | |
# module.bias is not None, module.weight.device) | |
# rand_indexes = l1_max_indexes(module.weight.data) | |
# new_layer.weight.data.copy_(module.weight.data[rand_indexes][:, l1_max_indexes(module.weight.data, 1)]) | |
# if new_layer.bias is not None: | |
# new_layer.bias.data.copy_(module.bias.data[rand_indexes]) | |
# elif isinstance(module, nn.LayerNorm) and ('block' in name or name == 'norm' or name == 'norm.0'): | |
# new_layer = nn.LayerNorm(_f(module.normalized_shape[0]), eps=module.eps, device=module.weight.device) | |
# rand_indexes = l1_max_indexes(module.weight.data) | |
# new_layer.weight.data.copy_(module.weight.data[rand_indexes]) | |
# new_layer.bias.data.copy_(module.bias.data[rand_indexes]) | |
# else: | |
# continue | |
# original_layer_str = str(module) | |
# set_module(fm_vit, name, new_layer) | |
# logger.debug(f'set_module, {name}, {new_layer}') | |
# logger.debug(f'slim {name} from {original_layer_str} to {new_layer}') | |
return fm_vit | |
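# Minimal usage sketch for FM_to_MD_GLIP_Util (reducing_width_ratio=4 and the samples dict
# are illustrative):
#
#   util = FM_to_MD_GLIP_Util()
#   samples = {'images': images, 'targets': targets}
#   master_dnn = util.init_md_from_fm_by_reducing_width_with_perf_test_2(glip, 4, samples)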
class FMLoRA_GLIP_Util(FMLoRA_Util): | |
def train_only_lora_and_conv(self, fm: nn.Module): | |
res = [] | |
for n, m in fm.named_modules(): | |
if isinstance(m, LoRA) or isinstance(m, nn.Conv2d): | |
for p in m.parameters(): | |
p.requires_grad = True | |
res += [p] | |
else: | |
for p in m.parameters(): | |
p.requires_grad = False | |
return res | |
def add_lora_ab_to_fm(self, fm: nn.Module, ab_r: int, samples): | |
fm.eval() | |
# samples = {'images' : samples[0], 'targets' : samples[1]} | |
for k, v in samples.items(): | |
if isinstance(v, torch.Tensor) or isinstance(v, BoxList): | |
samples[k] = v.to(get_model_device(fm)) | |
print(k) | |
_, o1_token_logits, o1_dot_product_logits = fm(**samples) | |
mo_list = {k:v for k, v in fm.named_modules()} | |
for name, module in fm.named_modules(): | |
if '.proj' in name or 'out' in name: | |
continue | |
if name.endswith(('k_proj', 'q_proj', 'v_proj', 'qkv', 'attn.proj', 'l_proj', 'query', 'key', 'value')): | |
set_module(fm, name, ToQKV_WrappedWithLoRA(module, ab_r)) | |
_, o2_token_logits, o2_dot_product_logits = fm(**samples) | |
output_diff = 0. | |
for o1, o2 in list(zip(o1_dot_product_logits, o2_dot_product_logits)): | |
output_diff += ((o1 - o2) ** 2).sum() | |
if o1_token_logits is not None: | |
output_diff += ((o1_token_logits - o2_token_logits) ** 2).sum() | |
assert output_diff < 1e-5 | |
return fm | |
def absorb_lora_and_recover_net_structure(self, fm: nn.Module, samples: dict): | |
fm.eval() | |
# print('absorb lora before') | |
for k, v in samples.items(): | |
if isinstance(v, torch.Tensor): | |
samples[k] = v.to(get_model_device(fm)) | |
print(k) | |
_, o1_token_logits, o1_dot_product_logits = fm(**samples) | |
for name, module in fm.named_modules(): | |
if not isinstance(module, ToQKV_WrappedWithLoRA): | |
continue | |
fc = module.fc | |
ab = module.ab | |
fc.weight.add_(ab[1].weight @ ab[0].weight) | |
set_module(fm, name, fc) | |
# print('absorb lora after') | |
_, o2_token_logits, o2_dot_product_logits = fm(**samples) | |
output_diff = 0. | |
for o1, o2 in list(zip(o1_dot_product_logits, o2_dot_product_logits)): | |
output_diff += ((o1 - o2) ** 2).sum() | |
if o1_token_logits is not None: | |
output_diff += ((o1_token_logits - o2_token_logits) ** 2).sum() | |
assert output_diff < 1e-3, output_diff | |
return fm | |
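# Minimal usage sketch for FMLoRA_GLIP_Util (ab_r and the samples dict are illustrative):
#
#   lora_util = FMLoRA_GLIP_Util()
#   samples = {'images': images, 'targets': targets}
#   glip = lora_util.add_lora_ab_to_fm(glip, ab_r=8, samples=samples)
#   trainable_params = lora_util.train_only_lora_and_conv(glip)   # fine-tune only LoRA + conv params
#   glip = lora_util.absorb_lora_and_recover_net_structure(glip, samples)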
class ElasticDNN_OfflineMMDetFMModel(ElasticDNN_OfflineFMModel): | |
def __init__(self, name: str, models_dict_path: str, device: str, num_classes=10, collate_fn=None): | |
super().__init__(name, models_dict_path, device) | |
self.num_classes = num_classes | |
self.collate_fn = collate_fn | |
def get_accuracy(self, test_loader, *args, **kwargs): | |
# print('DeeplabV3: start test acc') | |
_d = test_loader.dataset | |
from data import build_dataloader | |
if _d.__class__.__name__ == 'MergedDataset': | |
# print('\neval on merged datasets') | |
datasets = _d.datasets | |
if self.collate_fn is None: | |
test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None, collate_fn=None) for d in datasets] | |
else: | |
test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None, collate_fn=self.collate_fn) for d in datasets] | |
accs = [self.get_accuracy(loader) for loader in test_loaders] | |
# print(accs) | |
return sum(accs) / len(accs) | |
# print('dataset len', len(test_loader.dataset)) | |
model = self.models_dict['main'] | |
device = self.device | |
model.eval() | |
# print('# classes', model.num_classes) | |
model = model.to(device) | |
from evaluator import COCOEvaluator, MMCOCODecoder | |
from utils.common.others import HiddenPrints | |
with torch.no_grad(): | |
with HiddenPrints(): | |
evaluator = COCOEvaluator( | |
dataloader=test_loader, | |
img_size=(416, 416), | |
confthre=0.01, | |
nmsthre=0.65, | |
num_classes=len(test_loader.dataset.classes), | |
testdev=True | |
) | |
res = evaluator.evaluate(model, False, False, decoder=MMCOCODecoder) | |
map50 = res[1] | |
# print('eval info', res[-1]) | |
return map50 | |
def infer(self, x, *args, **kwargs): | |
if len(args) > 0: | |
print(args, len(args)) | |
return self.models_dict['main'](x, *args) # forward(x, label) | |
return self.models_dict['main'](**x) | |
class ElasticDNN_OfflineMMDetMDModel(ElasticDNN_OfflineMDModel): | |
def __init__(self, name: str, models_dict_path: str, device: str, num_classes=10, collate_fn=None): | |
super().__init__(name, models_dict_path, device) | |
self.num_classes = num_classes | |
self.collate_fn = collate_fn | |
def get_accuracy(self, test_loader, *args, **kwargs): | |
# print('DeeplabV3: start test acc') | |
_d = test_loader.dataset | |
from data import build_dataloader | |
if _d.__class__.__name__ == 'MergedDataset': | |
# print('\neval on merged datasets') | |
datasets = _d.datasets | |
if self.collate_fn is None: | |
test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None, collate_fn=None) for d in datasets] | |
else: | |
test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None, collate_fn=self.collate_fn) for d in datasets] | |
accs = [self.get_accuracy(loader) for loader in test_loaders] | |
# print(accs) | |
return sum(accs) / len(accs) | |
# print('dataset len', len(test_loader.dataset)) | |
model = self.models_dict['main'] | |
device = self.device | |
model.eval() | |
# print('# classes', model.num_classes) | |
model = model.to(device) | |
from evaluator import COCOEvaluator, MMCOCODecoder | |
from utils.common.others import HiddenPrints | |
with torch.no_grad(): | |
with HiddenPrints(): | |
evaluator = COCOEvaluator( | |
dataloader=test_loader, | |
img_size=(416, 416), | |
confthre=0.01, | |
nmsthre=0.65, | |
num_classes=len(test_loader.dataset.classes), | |
testdev=True | |
) | |
res = evaluator.evaluate(model, False, False, decoder=MMCOCODecoder) | |
map50 = res[1] | |
# print('eval info', res[-1]) | |
return map50 | |
def infer(self, x, *args, **kwargs): | |
if len(args) > 0: | |
return self.models_dict['main'](x, *args) # forward(x, label) | |
return self.models_dict['main'](**x) | |
class SqueezeLast(nn.Module): | |
def __init__(self): | |
super(SqueezeLast, self).__init__() | |
def forward(self, x): | |
return x.squeeze(-1) | |
class ProjConv_WrappedWithFBS(Layer_WrappedWithFBS): | |
def __init__(self, raw_conv2d: nn.Conv2d, r): | |
super(ProjConv_WrappedWithFBS, self).__init__() | |
self.fbs = nn.Sequential( | |
Abs(), | |
nn.AdaptiveAvgPool2d(1), | |
nn.Flatten(), | |
nn.Linear(raw_conv2d.in_channels, raw_conv2d.out_channels // r), | |
nn.ReLU(), | |
nn.Linear(raw_conv2d.out_channels // r, raw_conv2d.out_channels), | |
nn.ReLU() | |
) | |
self.raw_conv2d = raw_conv2d | |
# self.raw_bn = raw_bn # remember clear the original BNs in the network | |
nn.init.constant_(self.fbs[5].bias, 1.) | |
nn.init.kaiming_normal_(self.fbs[5].weight) | |
def forward(self, x): | |
raw_x = self.raw_conv2d(x) | |
if self.use_cached_channel_attention and self.cached_channel_attention is not None: | |
channel_attention = self.cached_channel_attention | |
else: | |
self.cached_raw_channel_attention = self.fbs(x) | |
self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention) | |
channel_attention = self.cached_channel_attention | |
return raw_x * channel_attention.unsqueeze(2).unsqueeze(3) | |
class Linear_WrappedWithFBS(Layer_WrappedWithFBS): | |
def __init__(self, linear: nn.Linear, r): | |
super(Linear_WrappedWithFBS, self).__init__() | |
self.linear = linear | |
# for conv: (B, C_in, H, W) -> (B, C_in) -> (B, C_out) | |
# for mlp in ViT: (B, #patches, D: dim of patches embedding) -> (B, D) -> (B, C_out) | |
self.fbs = nn.Sequential( | |
Rearrange('b n d -> b d n'), | |
Abs(), | |
nn.AdaptiveAvgPool1d(1), | |
SqueezeLast(), | |
nn.Linear(linear.in_features, max(linear.out_features // r, 36)), | |
nn.ReLU(), | |
nn.Linear(max(linear.out_features // r, 36), linear.out_features), | |
nn.ReLU() | |
) | |
nn.init.constant_(self.fbs[6].bias, 1.) | |
nn.init.kaiming_normal_(self.fbs[6].weight) | |
def forward(self, x): | |
if self.use_cached_channel_attention and self.cached_channel_attention is not None: | |
channel_attention = self.cached_channel_attention | |
else: | |
self.cached_raw_channel_attention = self.fbs(x) | |
self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention) | |
channel_attention = self.cached_channel_attention | |
raw_res = self.linear(x) | |
return channel_attention.unsqueeze(1) * raw_res | |
class ToQKV_WrappedWithFBS(Layer_WrappedWithFBS): | |
""" | |
This regards to_q/to_k/to_v as a whole (in fact it consists of multiple heads) and prunes it. | |
It seems different channels of different heads are pruned according to the input. | |
This is different from "removing some head" or "removing the same channels in each head". | |
""" | |
def __init__(self, to_qkv: nn.Linear, r): | |
super(ToQKV_WrappedWithFBS, self).__init__() | |
# self.to_qkv = to_qkv | |
self.to_qk = nn.Linear(to_qkv.in_features, to_qkv.out_features // 3 * 2, bias=to_qkv.bias is not None) | |
self.to_v = nn.Linear(to_qkv.in_features, to_qkv.out_features // 3, bias=to_qkv.bias is not None) | |
self.to_qk.weight.data.copy_(to_qkv.weight.data[0: to_qkv.out_features // 3 * 2]) | |
if to_qkv.bias is not None: | |
self.to_qk.bias.data.copy_(to_qkv.bias.data[0: to_qkv.out_features // 3 * 2]) | |
self.to_v.weight.data.copy_(to_qkv.weight.data[to_qkv.out_features // 3 * 2: ]) | |
if to_qkv.bias is not None: | |
self.to_v.bias.data.copy_(to_qkv.bias.data[to_qkv.out_features // 3 * 2: ]) | |
self.fbs = nn.Sequential( | |
Rearrange('b n d -> b d n'), | |
Abs(), | |
nn.AdaptiveAvgPool1d(1), | |
SqueezeLast(), | |
nn.Linear(to_qkv.in_features, to_qkv.out_features // 3 // r), | |
nn.ReLU(), | |
# nn.Linear(to_qkv.out_features // 3 // r, to_qkv.out_features // 3), | |
nn.Linear(to_qkv.out_features // 3 // r, self.to_v.out_features), | |
nn.ReLU() | |
) | |
nn.init.constant_(self.fbs[6].bias, 1.) | |
nn.init.kaiming_normal_(self.fbs[6].weight) | |
def forward(self, x): | |
if self.use_cached_channel_attention and self.cached_channel_attention is not None: | |
channel_attention = self.cached_channel_attention | |
else: | |
self.cached_raw_channel_attention = self.fbs(x) | |
# print() | |
# for attn in self.cached_raw_channel_attention.chunk(3, dim=1)[0: 1]: | |
# print(self.cached_raw_channel_attention.size(), attn.size()) | |
# print(self.k_takes_all.k) | |
# print(attn[0].nonzero(as_tuple=True)[0].size(), attn[0]) | |
self.cached_channel_attention = self.k_takes_all(self.cached_raw_channel_attention) | |
# for attn in self.cached_channel_attention.chunk(3, dim=1)[0: 1]: | |
# print(self.cached_channel_attention.size(), attn.size()) | |
# print(self.k_takes_all.k) | |
# print(attn[0].nonzero(as_tuple=True)[0].size(), attn[0]) | |
# print() | |
channel_attention = self.cached_channel_attention | |
qk = self.to_qk(x) | |
v = channel_attention.unsqueeze(1) * self.to_v(x) | |
return torch.cat([qk, v], dim=-1) | |
# qkv = raw_res.chunk(3, dim = -1) | |
# raw_v = qkv[2] | |
# print('raw_k, raw_v', qkv[0].sum((0, 1))[0: 10], qkv[0].sum((0, 1)).nonzero(as_tuple=True)[0].size(), | |
# qkv[1].sum((0, 1))[0: 10], qkv[1].sum((0, 1)).nonzero(as_tuple=True)[0].size(),) | |
# print('raw_v', raw_v.size(), raw_v.sum((0, 1))[0: 10], raw_v.sum((0, 1)).nonzero(as_tuple=True)[0].size()) | |
# qkv_attn = channel_attention.chunk(3, dim=-1) | |
# print('attn', [attn[0][0: 10] for attn in qkv_attn]) | |
# print(channel_attention.unsqueeze(1).size(), raw_res.size()) | |
# print('fbs', channel_attention.size(), raw_res.size()) | |
# return channel_attention.unsqueeze(1) * raw_res | |
class StaticFBS(nn.Module): | |
def __init__(self, static_channel_attention): | |
super(StaticFBS, self).__init__() | |
assert static_channel_attention.dim() == 2 and static_channel_attention.size(0) == 1 | |
self.static_channel_attention = nn.Parameter(static_channel_attention, requires_grad=False) # (1, dim) | |
def forward(self, x): | |
# print('staticfbs', x, self.static_channel_attention.unsqueeze(1)) | |
return x * self.static_channel_attention.unsqueeze(1) | |
class ElasticGLIPUtil(ElasticDNNUtil): | |
def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]): | |
assert len(ignore_layers) == 0, 'not supported yet' | |
raw_vit = deepcopy(raw_dnn) | |
for name, module in raw_vit.named_modules(): | |
# if name.endswith('patch_embed'): | |
# set_module(module, 'proj', ProjConv_WrappedWithFBS(module.proj, r)) | |
# if name.endswith('attn') and not name.endswith('b_attn.attn') and not name.endswith('b_attn'): | |
# set_module(module, 'qkv', ToQKV_WrappedWithFBS(module.qkv, r)) | |
if name.endswith('intermediate'): | |
set_module(module, 'dense', Linear_WrappedWithFBS(module.dense, r)) | |
elif name.endswith('mlp'): | |
set_module(module, 'fc1', Linear_WrappedWithFBS(module.fc1, r)) | |
return raw_vit | |
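    # Minimal usage sketch for ElasticGLIPUtil (r=16 and the sparsity value are illustrative):
    #
    #   util = ElasticGLIPUtil()
    #   master_dnn = util.convert_raw_dnn_to_master_dnn(glip, r=16)
    #   util.set_master_dnn_sparsity(master_dnn, 0.5)
    #   surrogate_dnn = util.extract_surrogate_dnn_via_samples(master_dnn, samples)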
def set_master_dnn_sparsity(self, master_dnn: nn.Module, sparsity: float): | |
# for name, module in master_dnn.named_modules(): | |
# if not name.endswith('attn'): | |
# continue | |
# q_features = module.qkv.to_qk.out_features // 2 | |
# if (q_features - int(q_features * sparsity)) % module.num_heads != 0: | |
# # tune sparsity to ensure #unpruned channel % num_heads == 0 | |
# # so that the pruning seems to reduce the dim_head of each head | |
# tuned_sparsity = 1. - int((q_features - int(q_features * sparsity)) / module.num_heads) * module.num_heads / q_features | |
# logger.debug(f'tune sparsity from {sparsity:.2f} to {tuned_sparsity}') | |
# sparsity = tuned_sparsity | |
# break | |
return super().set_master_dnn_sparsity(master_dnn, sparsity) | |
def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor): | |
# print(samples) | |
sample={} | |
sample['images'] = [samples['images'][0]] | |
sample['targets'] = [samples['targets'][0]] | |
# return samples[0].unsqueeze(0) | |
# res = {k: v[0: 1] for k, v in samples.items()} | |
return sample | |
    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False):  # steps to derive the surrogate (small) DNN
sample = self.select_most_rep_sample(master_dnn, samples) | |
# assert sample.dim() == 4 and sample.size(0) == 1 | |
# print('before') | |
master_dnn.eval() | |
self.clear_cached_channel_attention_in_master_dnn(master_dnn) | |
with torch.no_grad(): | |
_, o1_token_logits, o1_dot_product_logits = master_dnn(**sample) | |
# print('after') | |
boosted_vit = deepcopy(master_dnn) | |
def get_unpruned_indexes_from_channel_attn(channel_attn: torch.Tensor, k): | |
assert channel_attn.size(0) == 1, 'use A representative sample to generate channel attentions' | |
# print('attn_in_unpruned', channel_attn[0][0: 10]) | |
res = channel_attn[0].nonzero(as_tuple=True)[0] # should be one-dim | |
# res = channel_attn[0].argsort(descending=True)[0: -int(channel_attn.size(1) * k)].sort()[0] | |
# g = channel_attn | |
# k = g.size(1) - int(g.size(1) * k) | |
# res = g.topk(k, 1)[1][0].sort()[0] | |
return res | |
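        # e.g. for channel_attn = tensor([[0.7, 0.0, 1.2, 0.0]]) the helper above
        # returns tensor([0, 2]), i.e. the indexes of the channels that FBS kept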
unpruned_indexes_of_layers = {} | |
# for attn, ff in boosted_vit.transformer.layers: | |
# for block_i, block in enumerate(boosted_vit.blocks): | |
        # prune the Swin backbone MLPs: keep only the fc1 output channels that FBS left
        # unpruned, freeze their attention into a StaticFBS, and drop the matching fc2 columns
        for layer_i, layer in enumerate(boosted_vit.model.backbone.body.layers):
            for block_i, block in enumerate(layer.blocks):
                # attn = block.attn
                # ff = block.mlp
                ff_0 = get_module(block, 'mlp.fc1')
                # ff_0_unpruned_indexes = get_unpruned_indexes_from_channel_attn(ff_0.cached_channel_attention, k)
                ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0]
                ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes])
                new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None)
                new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes])
                if ff_0.linear.bias is not None:
                    new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes])
                set_module(block, 'mlp.fc1', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes])))
                ff_1 = get_module(block, 'mlp.fc2')
                new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None)
                new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes])
                if ff_1.bias is not None:
                    new_ff_1.bias.data.copy_(ff_1.bias.data)
                set_module(block, 'mlp.fc2', new_ff_1)
                unpruned_indexes_of_layers[f'model.backbone.body.layers.{layer_i}.blocks.{block_i}.mlp.fc1.0.weight'] = ff_0_unpruned_indexes
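        # Shape sketch with illustrative sizes (not taken from the real config): if fc1 was
        # Linear(96, 384) and FBS kept 192 channels, the surrogate block becomes
        #   mlp.fc1 = nn.Sequential(Linear(96, 192), StaticFBS with a (1, 192) attention)
        #   mlp.fc2 = Linear(192, 96)
        # i.e. fc1 loses output rows and fc2 loses the matching input columns.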
# for block_i,block in enumerate(boosted_vit.vision_model.encoder.layers): | |
# attn = block.self_attn | |
# ff = block.mlp | |
# ff_0 = ff.fc1 | |
# # ff_0_unpruned_indexes = get_unpruned_indexes_from_channel_attn(ff_0.cached_channel_attention, k) | |
# ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0] | |
# ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes]) | |
# new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None) | |
# new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes]) | |
# if ff_0.linear.bias is not None: | |
# new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes]) | |
# set_module(ff, 'fc1', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes]))) | |
# ff_1 = ff.fc2 | |
# new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None) | |
# new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes]) | |
# if ff_1.bias is not None: | |
# new_ff_1.bias.data.copy_(ff_1.bias.data) | |
# set_module(ff, 'fc2', new_ff_1) | |
# unpruned_indexes_of_layers[f'vision_model.encoder.layers.{block_i}.mlp.fc1.0.weight'] = ff_0_unpruned_indexes | |
# for block_i, block in enumerate(boosted_vit.text_decoder.bert.encoder.layer): | |
# # attn = block.attn | |
# # ff = block.mlp | |
# ff_0 = get_module(block, f'intermediate.dense') | |
# # ff_0_unpruned_indexes = get_unpruned_indexes_from_channel_attn(ff_0.cached_channel_attention, k) | |
# ff_0_pruned_indexes = ff_0.k_takes_all.cached_i[0].sort()[0] | |
# ff_0_unpruned_indexes = torch.LongTensor([ii for ii in range(ff_0.cached_channel_attention.size(1)) if ii not in ff_0_pruned_indexes]) | |
# new_ff_0 = nn.Linear(ff_0.linear.in_features, ff_0_unpruned_indexes.size(0), ff_0.linear.bias is not None) | |
# new_ff_0.weight.data.copy_(ff_0.linear.weight.data[ff_0_unpruned_indexes]) | |
# if ff_0.linear.bias is not None: | |
# new_ff_0.bias.data.copy_(ff_0.linear.bias.data[ff_0_unpruned_indexes]) | |
# set_module(block, 'intermediate.dense', nn.Sequential(new_ff_0, StaticFBS(ff_0.cached_channel_attention[:, ff_0_unpruned_indexes]))) | |
# ff_1 = get_module(block, f'output.dense') | |
# new_ff_1 = nn.Linear(ff_0_unpruned_indexes.size(0), ff_1.out_features, ff_1.bias is not None) | |
# new_ff_1.weight.data.copy_(ff_1.weight.data[:, ff_0_unpruned_indexes]) | |
# if ff_1.bias is not None: | |
# new_ff_1.bias.data.copy_(ff_1.bias.data) | |
# set_module(block, 'output.dense', new_ff_1) | |
# unpruned_indexes_of_layers[f'text_decoder.bert.encoder.layer.{block_i}.intermediate.dense.0.weight'] = ff_0_unpruned_indexes | |
surrogate_dnn = boosted_vit | |
surrogate_dnn.eval() | |
surrogate_dnn = surrogate_dnn.to(get_model_device(master_dnn)) | |
# logger.debug(surrogate_dnn) | |
with torch.no_grad(): | |
_, o2_token_logits, o2_dot_product_logits = surrogate_dnn(**sample) | |
output_diff = 0. | |
for o1, o2 in list(zip(o1_dot_product_logits, o2_dot_product_logits)): | |
output_diff += ((o1 - o2) ** 2).sum() | |
if o1_token_logits is not None: | |
output_diff += ((o1_token_logits - o2_token_logits) ** 2).sum() | |
# assert output_diff < 1e-4, output_diff | |
logger.info(f'output diff of master and surrogate DNN: {output_diff}') | |
# logger.debug(f'example output of master/surrogate: {master_dnn_output.sum(0)[0: 10]}, {surrogate_dnn_output.sum(0)[0: 10]}') | |
        # logger.info(f'\nonly prune mlp!!!!\n')
if return_detail: | |
return boosted_vit, unpruned_indexes_of_layers | |
return boosted_vit | |
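    # Usage sketch (`samples` is the dict built by collect_mm_fn above,
    # i.e. {'images': [...], 'targets': [...], ...}):
    #   util = ElasticGLIPUtil()
    #   util.set_master_dnn_sparsity(master_dnn, 0.5)
    #   surrogate_dnn = util.extract_surrogate_dnn_via_samples(master_dnn, samples)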
def extract_surrogate_dnn_via_samples_with_perf_test(self, master_dnn: nn.Module, samples, return_detail=False): | |
master_dnn_size = get_model_size(master_dnn, True) | |
sample = {} | |
sample['images'] = [samples['images'][0]] | |
sample['targets'] = [samples['targets'][0]] | |
master_dnn_latency = self._get_model_latency(master_dnn, sample, 50, | |
get_model_device(master_dnn), 50, False) | |
res = self.extract_surrogate_dnn_via_samples(master_dnn, samples, return_detail) | |
if not return_detail: | |
surrogate_dnn = res | |
else: | |
surrogate_dnn, unpruned_indexes_of_layers = res | |
surrogate_dnn_size = get_model_size(surrogate_dnn, True) | |
        # measure the surrogate (not the master) on the same representative sample
        surrogate_dnn_latency = self._get_model_latency(surrogate_dnn, sample, 50,
                                               get_model_device(master_dnn), 50, False)
logger.info(f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample) -> ' | |
f'surrogate DNN ({surrogate_dnn_size:.3f}MB, {surrogate_dnn_latency:.4f}s/sample)\n' | |
f'(model size: ↓ {(master_dnn_size / surrogate_dnn_size):.2f}x, ' | |
f'latency: ↓ {(master_dnn_latency / surrogate_dnn_latency):.2f}x)') | |
return res | |
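    # Example of the log line emitted above (numbers are purely illustrative):
    #   master DNN (430.000MB, 0.1200s/sample) -> surrogate DNN (300.000MB, 0.0800s/sample)
    #   (model size: ↓ 1.43x, latency: ↓ 1.50x)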
def _get_model_latency(self, model: torch.nn.Module, sample, sample_num: int, | |
device: str, warmup_sample_num: int, return_detail=False): | |
import time | |
model = model.to(device) | |
model.eval() | |
sample['images'] = [sample['images'][0]] | |
sample['targets'] = [sample['targets'][0]] | |
# warm up | |
with torch.no_grad(): | |
for _ in range(warmup_sample_num): | |
model(**sample) | |
infer_time_list = [] | |
        if 'cuda' in str(device):
with torch.no_grad(): | |
for _ in range(sample_num): | |
s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) | |
s.record() | |
model(**sample) | |
e.record() | |
torch.cuda.synchronize() | |
cur_model_infer_time = s.elapsed_time(e) / 1000. | |
infer_time_list += [cur_model_infer_time] | |
else: | |
with torch.no_grad(): | |
for _ in range(sample_num): | |
start = time.time() | |
model(**sample) | |
cur_model_infer_time = time.time() - start | |
infer_time_list += [cur_model_infer_time] | |
avg_infer_time = sum(infer_time_list) / sample_num | |
if return_detail: | |
return avg_infer_time, infer_time_list | |
return avg_infer_time | |
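    # Note: in the CUDA branch of _get_model_latency, Event.elapsed_time() returns milliseconds,
    # hence the division by 1000, so that both branches report seconds per sample.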
# from typing import Any, Dict | |
# from schema import Schema, Or | |
# import schema | |
# from data import Scenario, MergedDataset | |
# from methods.base.alg import BaseAlg | |
# from data import build_dataloader | |
# from ..model import ElasticDNN_OfflineFMModel, ElasticDNN_OfflineMDModel | |
# from ...model.base import ElasticDNNUtil | |
# import torch.optim | |
# import tqdm | |
# import torch.nn.functional as F | |
# from torch import nn | |
# from utils.dl.common.env import create_tbwriter | |
# import os | |
# import random | |
# import numpy as np | |
# from copy import deepcopy | |
# from utils.dl.common.model import LayerActivation2, get_module | |
# from utils.common.log import logger | |
# class ElasticDNN_Det_MDPretrainingWoFBSAlg(BaseAlg): | |
# """ | |
# TODO: fine-tuned FM -> init MD -> trained MD -> construct indexes (only between similar weights) and fine-tune | |
# """ | |
# def get_required_models_schema(self) -> Schema: | |
# return Schema({ | |
# 'fm': ElasticDNN_OfflineFMModel, | |
# 'md': ElasticDNN_OfflineMDModel | |
# }) | |
# def get_required_hyp_schema(self) -> Schema: | |
# return Schema({ | |
# 'launch_tbboard': bool, | |
# 'samples_size': any, | |
# 'generate_md_width_ratio': int, | |
# 'train_batch_size': int, | |
# 'val_batch_size': int, | |
# 'num_workers': int, | |
# 'optimizer': str, | |
# 'optimizer_args': dict, | |
# 'scheduler': str, | |
# 'scheduler_args': dict, | |
# 'num_iters': int, | |
# 'val_freq': int, | |
# 'distill_loss_weight': float | |
# }) | |
# def run(self, scenario: Scenario, hyps: Dict) -> Dict[str, Any]: | |
# super().run(scenario, hyps) | |
# assert isinstance(self.models['md'], ElasticDNN_OfflineMDModel) # for auto completion | |
# assert isinstance(self.models['fm'], ElasticDNN_OfflineFMModel) # for auto completion | |
# # 1. add FBS | |
# device = self.models['md'].device | |
# if self.models['md'].models_dict['main'] == -1: | |
# logger.info(f'init master DNN by reducing width of an adapted foundation model (already tuned by LoRA)...') | |
# before_fm_model = deepcopy(self.models['fm'].models_dict['main']) | |
# lora_util = self.models['fm'].get_lora_util() | |
# sample = hyps['samples_size'] | |
# if isinstance(sample, (tuple, list)) and isinstance(sample[0], int): | |
# sample = torch.rand(hyps['samples_size']).to(device) | |
# lora_absorbed_fm_model = lora_util.absorb_lora_and_recover_net_structure(self.models['fm'].models_dict['main'], | |
# sample) | |
# self.models['fm'].models_dict['main'] = lora_absorbed_fm_model | |
# master_dnn = self.models['fm'].generate_md_by_reducing_width(hyps['generate_md_width_ratio'], | |
# sample) | |
# self.models['fm'].models_dict['main'] = before_fm_model | |
# self.models['md'].models_dict['main'] = master_dnn | |
# self.models['md'].to(device) | |
# # 2. train (knowledge distillation, index relationship) | |
# offline_datasets = scenario.get_offline_datasets() | |
# train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()]) | |
# val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()]) | |
# train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'], hyps['num_workers'], | |
# True, None)) | |
# val_loader = build_dataloader(val_dataset, hyps['val_batch_size'], hyps['num_workers'], | |
# False, False) | |
# # logger.info(f'FM acc: {self.models["fm"].get_accuracy(val_loader):.4f}') | |
# # 2.1 train whole master DNN (knowledge distillation) | |
# for p in master_dnn.parameters(): | |
# p.requires_grad = True | |
# self.models['md'].to_train_mode() | |
# optimizer = torch.optim.__dict__[hyps['optimizer']]([ | |
# {'params': self.models['md'].models_dict['main'].parameters(), **hyps['optimizer_args']} | |
# ]) | |
# scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args']) | |
# tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard']) | |
# pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True) | |
# best_avg_val_acc = 0. | |
# md_output_hook = None | |
# for iter_index in pbar: | |
# self.models['md'].to_train_mode() | |
# self.models['fm'].to_eval_mode() | |
# # rand_sparsity = random.random() * (hyps['max_sparsity'] - hyps['min_sparsity']) + hyps['min_sparsity'] | |
# # elastic_dnn_util.set_master_dnn_sparsity(self.models['md'].models_dict['main'], rand_sparsity) | |
# if md_output_hook is None: | |
# md_output_hook = self.models['md'].get_feature_hook() | |
# fm_output_hook = self.models['fm'].get_feature_hook() | |
# x, y = next(train_loader) | |
# if isinstance(x, dict): | |
# for k, v in x.items(): | |
# if isinstance(v, torch.Tensor): | |
# x[k] = v.to(device) | |
# y = y.to(device) | |
# else: | |
# x, y = x.to(device), y.to(device) | |
# with torch.no_grad(): | |
# fm_output = self.models['fm'].infer(x) | |
# task_loss = self.models['md'].forward_to_get_task_loss(x, y) | |
# md_output = md_output_hook.output | |
# fm_output = fm_output_hook.output | |
# distill_loss = hyps['distill_loss_weight'] * self.models['md'].get_distill_loss(md_output, fm_output) | |
# total_loss = task_loss + distill_loss | |
# optimizer.zero_grad() | |
# total_loss.backward() | |
# optimizer.step() | |
# scheduler.step() | |
# if (iter_index + 1) % hyps['val_freq'] == 0: | |
# # elastic_dnn_util.clear_cached_channel_attention_in_master_dnn(self.models['md'].models_dict['main']) | |
# md_output_hook.remove() | |
# md_output_hook = None | |
# fm_output_hook.remove() | |
# fm_output_hook = None | |
# cur_md = self.models['md'].models_dict['main'] | |
# md_for_test = deepcopy(self.models['md'].models_dict['main']) | |
# val_acc = 0. | |
# self.models['md'].models_dict['main'] = md_for_test | |
# self.models['md'].to_eval_mode() | |
# val_acc = self.models['md'].get_accuracy(val_loader) | |
# self.models['md'].models_dict['main'] = cur_md | |
# self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_last.pt')) | |
# self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt')) | |
# if val_acc > best_avg_val_acc: | |
# best_avg_val_acc = val_acc | |
# self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_best.pt')) | |
# self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt')) | |
# tb_writer.add_scalars(f'losses', dict(task=task_loss, distill=distill_loss, total=total_loss), iter_index) | |
# pbar.set_description(f'loss: {total_loss:.6f}') | |
# if (iter_index + 1) >= hyps['val_freq']: | |
# tb_writer.add_scalar(f'accs/val_acc', val_acc, iter_index) | |
# pbar.set_description(f'loss: {total_loss:.6f}, val_acc: {val_acc:.4f}') | |
# if __name__ == '__main__': | |
# model = glip_model('new_impl/cv/glip/object_detection/pretrained_model/glip_Swin_T_O365_GoldG.yaml','new_impl/cv/glip/object_detection/pretrained_model/glip_tiny_model_o365_goldg_cc_sbu.pth').cuda() | |
# model.eval() | |
# # print(model) | |
# # exit() | |
# # config = CLIPConfig.from_pretrained('openai/clip-vit-base-patch16') | |
# # print(config) | |
# # # test 1: single image inference | |
# from PIL import Image, ImageDraw | |
# import requests | |
# import numpy as np | |
# ori_image = Image.open('new_impl/cv/glip/object_detection/9472793441_b7822c00de_z.jpg').convert("RGB") | |
# image = [np.asarray(ori_image)[:, :, [2, 1, 0]]] | |
# text = 'sofa . remote . dog . person . car . sky . plane .' | |
# target = torch.Tensor() | |
# o = model(image, text) | |
# o = model._post_process(o[0]) | |
# print(o) | |
# bboxes = o.bbox.cpu() | |
# a = ImageDraw.ImageDraw(ori_image) | |
# for box in bboxes: | |
# box = box.int() | |
# a.rectangle(((box[0], box[1]), (box[2], box[3])), fill=None, outline='red', width=2) | |
# ori_image.save('test.jpg') | |
# # print(o.logits_per_image.softmax(dim=1)) | |
# # o = model(image, torch.load('dnns/clip/test_input_embed.pth'), False) | |
# # # print(o) | |
# # print(o.logits_per_image.softmax(dim=1)) | |
# # exit() | |
# # test 2: normal training using clip loss (batch) | |
# from data import get_dataset, build_dataloader | |
# from torchvision.transforms import Compose, ToTensor, Resize | |
# dataset = get_dataset('Caltech256', '/data/zql/datasets/Caltech-256/data/caltech256/256_ObjectCategories/', 'train', transform=Compose([ | |
# Resize((32, 32)), ToTensor() | |
# ])) | |
# dataloader = build_dataloader(dataset, 8, 0, True, None) | |
# from PIL import Image | |
# import requests | |
# images, labels = next(iter(dataloader)) | |
# # torch.save(images, 'dnns/clip/test_image.pth') | |
# classes = dataset.classes | |
# text = [f"a photo of a {classes[i]}" for i in labels] # should be ground truth | |
# print(text) | |
# print(images.size()) | |
# o = model(images, text, True) | |
# print(o) | |
# print(o.logits_per_image.softmax(dim=1)) | |
# # o = model(image, torch.load('dnns/clip/test_input_embed.pth'), False) | |
# # # print(o) | |
# # print(o.logits_per_image.softmax(dim=1)) |