Spaces:
Running
Running
import torch | |
from torch import nn | |
from copy import deepcopy | |
from .base import FM_to_MD_Util | |
from utils.common.log import logger | |
from utils.dl.common.model import set_module, get_module, get_super_module | |
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size | |
from utils.common.log import logger | |
from transformers.models.vilt.modeling_vilt import ViltSelfAttention | |
from transformers import ViltConfig | |
from typing import Optional, Tuple | |
import math | |
class ViltSelfAttentionPrunable(ViltSelfAttention):
    """ViLT self-attention that tolerates width-pruned Q/K/V projections.

    The stock implementation reshapes activations using the configured
    per-head / total width. Here ``transpose_for_scores`` and the context
    reshape in ``forward`` use ``-1`` instead, so ``query``/``key``/``value``
    can be replaced by narrower ``nn.Linear`` layers while
    ``num_attention_heads`` stays fixed.
    """

    def __init__(self, config: Optional[ViltConfig] = None):
        """Build the attention module.

        Args:
            config: ViLT config used to size the layers. When ``None`` (the
                previous and default behaviour) the config of
                ``dandelin/vilt-b32-mlm-itm`` is loaded with
                ``ViltConfig.from_pretrained``, which may read the HF cache or
                the network.
        """
        if config is None:
            config = ViltConfig.from_pretrained('dandelin/vilt-b32-mlm-itm')
        super(ViltSelfAttentionPrunable, self).__init__(config)

    def transpose_for_scores(self, x):
        """Split the last dim of ``x`` into heads and move heads before tokens.

        ``x`` is viewed as ``x.size()[:-1] + (num_attention_heads, -1)`` and
        then permuted with four axes, so ``x`` must be 3-D; the per-head width
        is inferred (``-1``) instead of taken from the config.
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        """Scaled dot-product self-attention over ``hidden_states``.

        Args:
            hidden_states: 3-D input (required by ``transpose_for_scores``).
            attention_mask: Optional additive mask added to the raw scores.
            head_mask: Optional multiplicative mask applied to the
                attention probabilities.
            output_attentions: When true, also return the probabilities.

        Returns:
            ``(context_layer,)`` or ``(context_layer, attention_probs)``. The
            last dim of ``context_layer`` is inferred (``-1``), so it follows
            the possibly pruned width of ``value``.
        """
        mixed_query_layer = self.query(hidden_states)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Raw attention scores: dot product between queries and keys.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # NOTE(review): scaling uses the configured ``attention_head_size``,
        # not the (possibly pruned) actual per-head width of ``query_layer``.
        # Kept as-is; confirm this is intended.
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Additive mask, precomputed once for all layers by the caller.
            attention_scores = attention_scores + attention_mask

        # Normalize the scores to probabilities (same as nn.Softmax(dim=-1)).
        attention_probs = torch.softmax(attention_scores, dim=-1)
        # Drops entire tokens to attend to, as in the original Transformer.
        attention_probs = self.dropout(attention_probs)
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # Merge heads back; -1 instead of all_head_size so pruned widths work.
        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs

    @staticmethod
    def init_from_exist_self_attn(attn: ViltSelfAttention,
                                  config: Optional[ViltConfig] = None) -> 'ViltSelfAttentionPrunable':
        """Create a prunable attention that reuses the submodules of ``attn``.

        Every ``nn.Module`` attribute of ``attn`` (e.g. ``query``, ``key``,
        ``value``, ``dropout``) is shared, not copied, with the new module. The
        head-geometry integers read by ``forward`` are copied as well, so the
        result reshapes and scales exactly like ``attn`` even if ``attn`` was
        built from a different config than the one used to construct it.

        Args:
            attn: Existing self-attention module to take weights from.
            config: Optional config forwarded to ``ViltSelfAttentionPrunable``;
                ``None`` keeps the previous default-config behaviour.

        Returns:
            The new ``ViltSelfAttentionPrunable``.
        """
        res = ViltSelfAttentionPrunable(config)

        # Previously only submodules were copied, so these silently came from
        # the default pretrained config rather than from ``attn``.
        for name in ('num_attention_heads', 'attention_head_size', 'all_head_size'):
            if hasattr(attn, name):
                setattr(res, name, getattr(attn, name))

        for attr in dir(attn):
            value = getattr(attn, attr)
            if not isinstance(value, nn.Module):
                continue
            try:
                setattr(res, attr, value)
            except Exception as e:
                # Best effort: keep going, but report through the module logger
                # instead of a bare print.
                logger.warning(f'failed to copy submodule {attr!r} from existing self-attention: {e}')
        return res
class FM_to_MD_Vilt_Util(FM_to_MD_Util):
    """Derives a width-reduced "master DNN" from a ViLT foundation model."""

    def init_md_from_fm_by_reducing_width(self, fm: nn.Module, reducing_width_ratio: int) -> nn.Module:
        """Return a deep copy of ``fm`` whose encoder layers are narrowed.

        In each layer of ``fm.vilt.encoder.layer``:

        * the self-attention is swapped for ``ViltSelfAttentionPrunable``;
        * ``query``/``key``/``value`` and ``intermediate.dense`` keep the
          ``out_features // reducing_width_ratio`` output rows with the largest
          L1 norm (bias sliced accordingly);
        * ``attention.output.dense`` and ``output.dense`` keep the
          ``in_features // reducing_width_ratio`` input columns with the
          largest L1 norm (bias copied unchanged).

        ``fm`` itself is not modified.

        Args:
            fm: ViLT model exposing ``vilt.encoder.layer``.
            reducing_width_ratio: Width divisor; must be >= 1.

        Returns:
            The pruned copy.

        Raises:
            ValueError: If ``reducing_width_ratio`` is < 1 (it would otherwise
                fail later with an opaque shape error or division by zero).
        """
        if reducing_width_ratio < 1:
            raise ValueError(f'reducing_width_ratio must be >= 1, got {reducing_width_ratio}')

        fm_vit = deepcopy(fm)
        for block in fm_vit.vilt.encoder.layer:
            set_module(block, 'attention.attention',
                       ViltSelfAttentionPrunable.init_from_exist_self_attn(block.attention.attention))

        def _f(n):
            # Reduced size of a dimension of length n.
            return int(n // reducing_width_ratio)

        def l1_max_indexes(p: torch.Tensor, dim=0):
            # Ascending indices of the _f(size) slices of ``p`` along ``dim``
            # (0 = rows, 1 = columns) with the largest L1 norm.
            assert dim in [0, 1]
            assert p.dim() in [1, 2, 4]
            if dim == 1:
                p = p.T
            p_norm = p.abs().contiguous().view(p.size(0), -1).sum(dim=1)
            n = p.size(0)
            return p_norm.argsort(descending=True)[0: int(n // reducing_width_ratio)].sort()[0]

        def _new_linear(src: nn.Linear, in_features: int, out_features: int) -> nn.Linear:
            # Same device AND dtype as ``src``: the dtype used to be dropped, so
            # fp16/bf16 models ended up with fp32 layers mixed in.
            return nn.Linear(in_features, out_features, src.bias is not None,
                             device=src.weight.device, dtype=src.weight.dtype)

        def _shrink_out(linear: nn.Linear) -> nn.Linear:
            # Keep the strongest output rows (and matching bias entries).
            new_linear = _new_linear(linear, linear.in_features, _f(linear.out_features))
            indexes = l1_max_indexes(linear.weight.data, 0)
            new_linear.weight.data.copy_(linear.weight.data[indexes])
            if linear.bias is not None:
                new_linear.bias.data.copy_(linear.bias.data[indexes])
            return new_linear

        def _shrink_in(linear: nn.Linear) -> nn.Linear:
            # Keep the strongest input columns; output size (and bias) unchanged.
            new_linear = _new_linear(linear, _f(linear.in_features), linear.out_features)
            new_linear.weight.data.copy_(linear.weight.data[:, l1_max_indexes(linear.weight.data, 1)])
            if linear.bias is not None:
                new_linear.bias.data.copy_(linear.bias.data)
            return new_linear

        # NOTE(review): kept indices are chosen independently per layer, so
        # the columns kept in ``output.dense`` are not the rows kept in
        # ``intermediate.dense`` (likewise ``value`` vs
        # ``attention.output.dense``). Preserved as-is; confirm intended.
        for block in fm_vit.vilt.encoder.layer:
            for k in ['query', 'key', 'value']:
                qkv_name = f'attention.attention.{k}'
                set_module(block, qkv_name, _shrink_out(get_module(block, qkv_name)))
            set_module(block, 'attention.output.dense', _shrink_in(get_module(block, 'attention.output.dense')))
            set_module(block, 'intermediate.dense', _shrink_out(get_module(block, 'intermediate.dense')))
            set_module(block, 'output.dense', _shrink_in(get_module(block, 'output.dense')))

        return fm_vit

    def init_md_from_fm_by_reducing_width_with_perf_test(self, fm: nn.Module, reducing_width_ratio: int,
                                                         samples: torch.Tensor) -> nn.Module:
        """Build the master DNN and log size/latency compared with ``fm``.

        Args:
            fm: Foundation model to shrink (not modified).
            reducing_width_ratio: Width divisor, see
                ``init_md_from_fm_by_reducing_width``.
            samples: Model inputs forwarded to ``_get_model_latency`` (a tensor,
                a mapping of keyword inputs, or a shape tuple).

        Returns:
            The width-reduced master DNN.
        """
        fm_size = get_model_size(fm, True)
        fm_latency = self._get_model_latency(fm, samples, 20,
                                             get_model_device(fm), 20, False)

        master_dnn = self.init_md_from_fm_by_reducing_width(fm, reducing_width_ratio)
        master_dnn_size = get_model_size(master_dnn, True)
        logger.debug(f'inited master DNN: {master_dnn}')
        master_dnn_latency = self._get_model_latency(master_dnn, samples, 20,
                                                     get_model_device(master_dnn), 20, False)

        logger.info(f'init master DNN (w/o FBS yet) by reducing foundation model\'s width (by {reducing_width_ratio:d}x)')
        logger.info(f'foundation model ({fm_size:.3f}MB, {fm_latency:.4f}s/sample) -> '
                    f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample)\n'
                    f'(model size: ↓ {(fm_size / master_dnn_size):.2f}x, '
                    f'latency: ↓ {(fm_latency / master_dnn_latency):.2f}x)')
        return master_dnn

    def _get_model_latency(self, model: torch.nn.Module, model_input_size, sample_num: int,
                           device: str, warmup_sample_num: int, return_detail=False):
        """Measure the mean per-forward latency of ``model`` in seconds.

        Args:
            model: Module to time; it is moved to ``device`` and set to eval
                mode (both persist after the call).
            model_input_size: A shape tuple (a random tensor of that shape is
                created on ``device``) or ready-made inputs. Tensors are passed
                positionally; anything else is unpacked as keyword arguments.
            sample_num: Number of timed forward passes; must be > 0.
            device: Target device; CUDA events are used when its string form
                contains ``'cuda'``.
            warmup_sample_num: Untimed forward passes run first.
            return_detail: Also return the per-pass timings.

        Returns:
            The mean latency, or ``(mean, per_pass_list)`` if ``return_detail``.

        Raises:
            ValueError: If ``sample_num`` is not positive.
        """
        import time

        if sample_num <= 0:
            raise ValueError(f'sample_num must be > 0, got {sample_num}')

        if isinstance(model_input_size, tuple):
            dummy_input = torch.rand(model_input_size).to(device)
        else:
            dummy_input = model_input_size

        model = model.to(device)
        model.eval()

        def _forward():
            # A tensor used to reach ``model(**tensor)`` (TypeError); feed it
            # positionally, keep keyword unpacking for dict-like inputs.
            if isinstance(dummy_input, torch.Tensor):
                return model(dummy_input)
            return model(**dummy_input)

        with torch.no_grad():
            for _ in range(warmup_sample_num):
                _forward()

        infer_time_list = []
        if 'cuda' in str(device):
            with torch.no_grad():
                for _ in range(sample_num):
                    s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
                    s.record()
                    _forward()
                    e.record()
                    # Events are async; wait before reading elapsed time (ms).
                    torch.cuda.synchronize()
                    infer_time_list.append(s.elapsed_time(e) / 1000.)
        else:
            with torch.no_grad():
                for _ in range(sample_num):
                    # perf_counter is monotonic and high-resolution, unlike time.time.
                    start = time.perf_counter()
                    _forward()
                    infer_time_list.append(time.perf_counter() - start)

        avg_infer_time = sum(infer_time_list) / sample_num
        if return_detail:
            return avg_infer_time, infer_time_list
        return avg_infer_time