import numpy as np
import cv2
import os
import math
import torch
from torch import nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import torch.utils.checkpoint as checkpoint
from functools import partial
from einops import rearrange

# Optional fused kernels. If they are missing the names stay undefined, and only the
# code paths that use them (use_fused_mlp / use_fused_rmsnorm) will fail.
try:
    from flash_attn.modules.mlp import FusedMLP
except ImportError:
    print(f'FusedMLP of flash_attn is not installed!!!')

try:
    from flash_attn.ops.rms_norm import DropoutAddRMSNorm
except ImportError:
    print(f'DropoutAddRMSNorm of flash_attn is not installed!!!')

from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                (B, S, 3, H, D) if key_padding_mask is None;
                if unpadded (cu_seqlens given): (nnz, 3, h, d)
            key_padding_mask: a bool tensor of shape (B, S)

        Returns
        -------
            (output, None): output is (B, S, H, D) for padded input, or the raw
            varlen output for the pre-unpadded path; attention weights are never
            returned (need_weights must be False).
        """
        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                # Dense batch: treat every sequence as full length S.
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
                output = flash_attn_varlen_qkvpacked_func(
                    qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                # Padded batch: drop masked tokens, run varlen attention, then re-pad.
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                # Some flash_attn releases return an extra trailing value from
                # unpad_input; only the first four are used here.
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)[:4]
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_attn_varlen_qkvpacked_func(
                    x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            # Caller already unpadded the input and supplied cu_seqlens/max_s.
            assert max_s is not None
            output = flash_attn_varlen_qkvpacked_func(
                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                softmax_scale=self.softmax_scale, causal=causal
            )

        return output, None


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2D sine-cosine position embedding for a square grid.

    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] without cls_token, or
               [1+grid_size*grid_size, embed_dim] with cls_token (a leading all-zero row)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate(
            [np.zeros([1, embed_dim]), pos_embed], axis=0
        )
    return pos_embed


def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
    """Build a fixed 1D sine-cosine position embedding for positions 0..t_size-1.

    t_size: int of the temporal size
    return:
    pos_embed: [t_size, embed_dim] without cls_token, or
               [1+t_size, embed_dim] with cls_token (a leading all-zero row)
    """
    grid_t = np.arange(t_size, dtype=np.float32)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)
    if cls_token:
        pos_embed = np.concatenate(
            [np.zeros([1, embed_dim]), pos_embed], axis=0
        )
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a (2, ...) coordinate grid: half the channels for grid[0], half for grid[1].

    With the callers in this file, grid comes from meshgrid(grid_w, grid_h), so
    grid[0] holds the w coordinate and grid[1] the h coordinate.
    """
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[0]
    )  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[1]
    )  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sine-cosine encoding of arbitrary positions.

    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D) -- first D/2 columns are sin, last D/2 are cos, with
         frequencies 1 / 10000**(2i/D) for i in [0, D/2).
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def _interpolate_pos_embed_entry(checkpoint_model, pos_name, num_patches, num_extra_tokens,
                                 orig_t_size, new_t_size):
    """Resize ``checkpoint_model[pos_name]`` in place to the target token grid.

    The stored tensor is treated as ``[extra tokens] + [T x H x W position tokens]``
    in T-major order with a batch dim of 1. If the temporal size differs, the T axis
    is linearly resampled from ``orig_t_size`` to ``new_t_size``. Then, if the spatial
    side differs, each frame's square grid is bicubically resampled to
    ``sqrt(num_patches // new_t_size)``. Extra (e.g. cls) tokens are copied unchanged.
    Prints a message for every resize performed.
    """
    pos_embed_checkpoint = checkpoint_model[pos_name]
    embedding_size = pos_embed_checkpoint.shape[-1]  # channel dim
    # height (== width) for the checkpoint position embedding
    orig_size = int(((pos_embed_checkpoint.shape[-2] - num_extra_tokens) // (orig_t_size)) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int((num_patches // (new_t_size)) ** 0.5)

    # class_token and dist_token are kept unchanged
    if orig_t_size != new_t_size:
        print(f"Temporal interpolate from {orig_t_size} to {new_t_size} ({pos_name})")
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        # B, L, C -> B, T, HW, C -> BHW, C, T  (B = 1)
        pos_tokens = pos_tokens.view(1, orig_t_size, -1, embedding_size)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size, orig_t_size)
        pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_t_size, mode='linear')
        pos_tokens = pos_tokens.view(1, -1, embedding_size, new_t_size)
        pos_tokens = pos_tokens.permute(0, 3, 1, 2).reshape(1, -1, embedding_size)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model[pos_name] = new_pos_embed
        # The spatial step below must work on the temporally resized tensor.
        pos_embed_checkpoint = new_pos_embed

    # class_token and dist_token are kept unchanged
    if orig_size != new_size:
        print(f"Position interpolate from {orig_size}x{orig_size} to {new_size}x{new_size} ({pos_name})")
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        # B, L, C -> BT, H, W, C -> BT, C, H, W
        pos_tokens = pos_tokens.reshape(-1, new_t_size, orig_size, orig_size, embedding_size)
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        # BT, C, H, W -> BT, H, W, C ->  B, T, H, W, C
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, new_t_size, new_size, new_size, embedding_size)
        pos_tokens = pos_tokens.flatten(1, 3)  # B, L, C
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model[pos_name] = new_pos_embed


def interpolate_pos_embed(checkpoint_model, model, orig_t_size=4, pos_name='vision_encoder.pos_embed'):
    """Resize ``checkpoint_model[pos_name]`` (if present) in place to fit ``model``.

    Reads ``model.patch_embed.num_patches``, ``model.pos_embed`` (to count extra
    tokens) and ``model.T`` (target temporal size).
    """
    if pos_name in checkpoint_model:
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # 0/1
        # we use 4 frames for pretraining
        new_t_size = model.T
        _interpolate_pos_embed_entry(checkpoint_model, pos_name, num_patches, num_extra_tokens,
                                     orig_t_size, new_t_size)


def interpolate_pos_embed_internvideo2(checkpoint_model, model, orig_t_size=8):
    """Resize 'pos_embed' and 'clip_pos_embed' (when present) in place to fit ``model``.

    The target temporal size is ``model.num_frames // model.tubelet_size``.

    Raises:
        NotImplementedError: if the checkpoint has separate spatial/temporal embeddings.
    """
    # interpolate position embedding
    for pos_name in ['pos_embed', 'clip_pos_embed']:
        if pos_name in checkpoint_model:
            num_patches = model.patch_embed.num_patches
            num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # 0/1
            # we use 8 frames for pretraining
            # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
            new_t_size = model.num_frames // model.tubelet_size
            _interpolate_pos_embed_entry(checkpoint_model, pos_name, num_patches, num_extra_tokens,
                                         orig_t_size, new_t_size)

    if 'pos_embed_spatial' in checkpoint_model or 'pos_embed_temporal' in checkpoint_model:
        raise NotImplementedError


def interpolate_pos_embed_internvideo2_new(checkpoint_model, model, orig_t_size=8):
    """Resize every checkpoint key containing 'pos_embed' (except '*img_pos_embed*') in place.

    The target temporal size is ``model.num_frames // model.tubelet_size``.

    Raises:
        AssertionError: if no matching key exists.
        NotImplementedError: if the checkpoint has separate spatial/temporal embeddings.
    """
    # 'clip_pos_embed' already contains 'pos_embed', so a single substring test suffices.
    pos_names = [k for k in checkpoint_model.keys()
                 if 'pos_embed' in k and 'img_pos_embed' not in k]

    print(f"pos names list for interpolating: {pos_names}")

    assert len(pos_names) > 0, checkpoint_model.keys()

    if 'pos_embed_spatial' in checkpoint_model.keys() or 'pos_embed_temporal' in checkpoint_model.keys():
        raise NotImplementedError

    # interpolate position embedding
    for pos_name in pos_names:
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # 0/1
        # we use 8 frames for pretraining
        # new_t_size = args.num_frames * args.num_segments // model.patch_embed.tubelet_size
        new_t_size = model.num_frames // model.tubelet_size
        _interpolate_pos_embed_entry(checkpoint_model, pos_name, num_patches, num_extra_tokens,
                                     orig_t_size, new_t_size)


def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size, cls_token=False):
    """Build a fixed 3D sine-cosine embedding: D/4 channels temporal, 3D/4 spatial.

    grid_size: int of the grid height and width
    t_size: int of the temporal size
    return:
    pos_embed: [t_size*grid_size*grid_size, embed_dim] without cls_token, or
               [1+t_size*grid_size*grid_size, embed_dim] with cls_token (leading zero row).
               Tokens are ordered T-major: all H*W positions of frame 0, then frame 1, ...
    """
    assert embed_dim % 4 == 0
    embed_dim_spatial = embed_dim // 4 * 3
    embed_dim_temporal = embed_dim // 4

    # spatial
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(
        embed_dim_spatial, grid
    )

    # temporal
    grid_t = np.arange(t_size, dtype=np.float32)
    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(
        embed_dim_temporal, grid_t
    )

    # concat: [T, H, W] order
    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
    pos_embed_temporal = np.repeat(
        pos_embed_temporal, grid_size**2, axis=1
    )  # [T, H*W, D // 4]
    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
    pos_embed_spatial = np.repeat(
        pos_embed_spatial, t_size, axis=0
    )  # [T, H*W, D // 4 * 3]

    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
    pos_embed = pos_embed.reshape([-1, embed_dim])  # [T*H*W, D]

    if cls_token:
        pos_embed = np.concatenate(
            [np.zeros([1, embed_dim]), pos_embed], axis=0
        )
    return pos_embed


class RMSNorm(nn.Module):
    """Root-mean-square norm over the last dim, computed in float32, scaled by a learned weight."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class PatchEmbed(nn.Module):
    """ 3D Image to Patch Embedding

    A Conv3d whose kernel == stride == (tubelet_size, patch_size, patch_size), so
    each (tubelet x patch x patch) cube becomes one token.
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
            num_frames=8, tubelet_size=1, norm_layer=None
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (
            num_frames // tubelet_size,
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1]
        )  # (T, H, W)
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.num_img_patches = self.grid_size[1] * self.grid_size[2]

        self.proj = nn.Conv3d(
            in_channels=in_chans, out_channels=embed_dim,
            kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
            stride=(tubelet_size, patch_size[0], patch_size[1])
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(3).permute(0, 2, 3, 1)  # B x C x T x HW => B x T x HW x C
        x = self.norm(x)
        return x


class CrossAttention(nn.Module):
    """Multi-head attention with separate q/k/v projections (queries attend to k/v).

    ``k`` and ``v`` are required at call time despite their ``None`` defaults.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None, out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        assert all_head_dim == dim

        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        # Biases are held separately and passed to F.linear in forward().
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, k=None, v=None):
        B, N, C = x.shape
        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = self.k_bias
            v_bias = self.v_bias

        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
        q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)  # (B, N_head, N_q, dim)

        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
        k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)

        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
        v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, N_head, N_q, N_k)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class AttentiveBlock(nn.Module):
    """Normalize queries/keys/values separately, then apply CrossAttention.

    ``drop_path``, ``bool_masked_pos`` and ``rel_pos_bias`` are accepted but unused here.
    """

    def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
        super().__init__()

        self.norm1_q = norm_layer(dim)
        self.norm1_k = norm_layer(dim)
        self.norm1_v = norm_layer(dim)
        self.cross_attn = CrossAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
        x_q = self.norm1_q(x_q + pos_q)
        x_k = self.norm1_k(x_kv + pos_k)
        x_v = self.norm1_v(x_kv)
        x = self.cross_attn(x_q, k=x_k, v=x_v)

        return x


class AttentionPoolingBlock(AttentiveBlock):
    """Pool a token sequence to one vector: the token mean is the single query."""

    def forward(self, x):
        x_q = x.mean(1, keepdim=True)
        x_kv, pos_q, pos_k = x, 0, 0
        x = super().forward(x_q, x_kv, pos_q, pos_k,
                            bool_masked_pos=None,
                            rel_pos_bias=None)
        x = x.squeeze(1)
        return x


class LayerScale(nn.Module):
    """Multiply the input by a learned per-channel weight (initialized to init_values).

    Autocast is disabled here; with force_fp32 the product is computed in float32
    and cast back to the input dtype.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
        super().__init__()
        self.inplace = inplace
        self.weight = nn.Parameter(init_values * torch.ones(dim))
        self.force_fp32 = force_fp32

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, x):
        if self.force_fp32:
            output_type = x.dtype
            out = x.float().mul_(self.weight.float()) if self.inplace else x.float() * self.weight.float()
            return out.to(dtype=output_type)
        else:
            out = x.mul_(self.weight) if self.inplace else x * self.weight
            return out


class Attention(nn.Module):
    """Multi-head self-attention with optional q/k normalization and a flash-attn path."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
                 causal=False, norm_layer=nn.LayerNorm, qk_normalization=False, use_fused_rmsnorm=False):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.use_flash_attn = use_flash_attn
        if use_flash_attn:
            self.causal = causal
            self.inner_attn = FlashAttention(attention_dropout=attn_drop)

        self.qk_normalization = qk_normalization
        self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
        self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
        self.use_fused_rmsnorm = use_fused_rmsnorm

    def _naive_attn(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if self.qk_normalization:
            B_, H_, N_, D_ = q.shape
            # Normalize over all heads jointly: (B, H, N, D) -> (B, N, H*D).
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1))
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1))
            if self.use_fused_rmsnorm:
                # Same as _flash_attn: the fused norm's output is indexed with [0].
                q, k = q[0], k[0]
            q = q.view(B_, N_, H_, D_).transpose(1, 2)
            k = k.view(B_, N_, H_, D_).transpose(1, 2)

        attn = ((q * self.scale) @ k.transpose(-2, -1))
        # attn = attn - attn.max(-1)[0].unsqueeze(-1)  # in case of overflow for fp16
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):

        qkv = self.qkv(x)
        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)

        if self.qk_normalization:
            q, k, v = qkv.unbind(2)
            if self.use_fused_rmsnorm:
                q = self.q_norm(q.flatten(-2, -1))[0].view(q.shape)
                k = self.k_norm(k.flatten(-2, -1))[0].view(k.shape)
            else:
                q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
                k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            qkv = torch.stack([q, k, v], dim=2)

        context, _ = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
        )
        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
        outs = self.proj_drop(outs)
        return outs

    def forward(self, x):
        x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
        return x


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class Block(nn.Module):
    """Pre-norm transformer block (attention + MLP, each with LayerScale and DropPath).

    With use_fused_rmsnorm the norms take and return ``(x, residual)``. In that case
    forward() returns the tuple ``(x, residual)`` instead of adding the residual itself.
    """

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, use_fused_mlp=False,
            fused_mlp_heuristic=1, with_cp=False, qk_normalization=False, layerscale_no_force_fp32=False,
            use_fused_rmsnorm=False):
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
                              use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
                              qk_normalization=qk_normalization,
                              use_fused_rmsnorm=use_fused_rmsnorm)
        self.ls1 = LayerScale(dim, init_values=init_values,
                              force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        if use_fused_mlp:
            self.mlp = FusedMLP(in_features=dim, hidden_features=mlp_hidden_dim, heuristic=fused_mlp_heuristic)
        else:
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values,
                              force_fp32=(not layerscale_no_force_fp32)) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.with_cp = with_cp
        self.use_fused_rmsnorm = use_fused_rmsnorm

    def forward(self, x, residual=None):

        def _inner_forward(x, residual=None):
            if self.use_fused_rmsnorm:
                x, residual = self.norm1(x, residual)
                x = self.drop_path1(self.ls1(self.attn(x)))
                x, residual = self.norm2(x, residual)
                x = self.drop_path2(self.ls2(self.mlp(x)))
                return x, residual
            else:
                assert residual is None
                x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
                x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
                return x

        if self.with_cp:
            return checkpoint.checkpoint(_inner_forward, x, residual)
        else:
            return _inner_forward(x, residual=residual)


class Linear_Decoder(nn.Module):
    """Linear projection followed by a norm and (for clip_norm_type='l2') L2 normalization."""

    def __init__(self, in_channels=1408, out_channels=3200,
                 norm_layer=nn.LayerNorm, clip_norm_type='l2'):
        super().__init__()
        self.clip_norm_type = clip_norm_type

        self.head = nn.Linear(in_channels, out_channels)
        self.norm = norm_layer(out_channels)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.norm(self.head(x))

        if self.clip_norm_type == 'l2':
            x = x / x.norm(dim=-1, keepdim=True)
        elif self.clip_norm_type == 'none':
            pass
        else:
            raise NotImplementedError

        return x


class PretrainInternVideo2(nn.Module):
    def __init__(
            self,
            in_chans: int = 3,
            patch_size: int = 14,
            img_size: int = 224,
            qkv_bias: bool = False,
            drop_path_rate: float = 0.25,
            embed_dim: int = 1408,
            num_heads: int = 16,
            mlp_ratio: float = 48/11,
            init_values: float = 1e-5,
            qk_normalization: bool = True,
            depth: int = 40,
            use_flash_attn: bool = True,
            use_fused_rmsnorm: bool = True,
            use_fused_mlp: bool = True,
            fused_mlp_heuristic: int = 1,
            attn_pool_num_heads: int = 16,
            clip_embed_dim: int = 768,
            layerscale_no_force_fp32: bool = False,
            num_frames: int = 8,
            tubelet_size: int = 1,
            sep_pos_embed: bool = False,
            sep_image_video_pos_embed: bool = False,
            use_checkpoint: bool = False,
            checkpoint_num: int = 0,
            # for unmasked teacher
            clip_teacher_embed_dim: int = 3200,
            clip_teacher_final_dim: int = 768,  # if 0, not distill final features
            clip_norm_type: str = 'l2',
            clip_return_layer: int = 1,
            clip_student_return_interval: int = 1,
    ):
        super().__init__()

        self.num_frames = num_frames
        self.tubelet_size = tubelet_size
        assert use_flash_attn == use_fused_rmsnorm == use_fused_mlp, 'use_flash_attn, use_fused_rmsnorm and use_fused_mlp should be consistent'

        self.use_flash_attn = use_flash_attn
        self.embed_dim = embed_dim

        self.depth = depth
        self.clip_norm_type = clip_norm_type
        # Indices of blocks whose outputs are collected for the CLIP decoders,
        # counted back from the last block.
        self.return_index = []
        for i in range(clip_return_layer):
            self.return_index.append(depth - int(i * clip_student_return_interval) - 1)

        if use_fused_rmsnorm:
            norm_layer_for_blocks = partial(DropoutAddRMSNorm, eps=1e-6, prenorm=True)
        else:
            norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
        self.norm_layer_for_blocks = norm_layer_for_blocks
        self.patch_embed = PatchEmbed(
            img_size, patch_size, in_chans, embed_dim,
            num_frames=num_frames, tubelet_size=tubelet_size,
        )
        num_patches = self.patch_embed.num_patches
        num_img_patches = self.patch_embed.num_img_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # stolen from https://github.com/facebookresearch/mae_st/blob/dc072aaaf640d06892e23a33b42223a994efe272/models_vit.py#L65-L73C17
        self.sep_pos_embed = sep_pos_embed
        self.sep_image_video_pos_embed = sep_image_video_pos_embed
        if sep_pos_embed:
            raise NotImplementedError
        else:
            if sep_image_video_pos_embed:
                self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
                self.img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
                # for CLIP decoder
                self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
                self.clip_img_pos_embed = nn.Parameter(torch.zeros(1, num_img_patches + 1, embed_dim))
            else:
                self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
                self.clip_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        # choose which layer to use checkpoint
        with_cp_list = [False] * depth
        if use_checkpoint:
            for idx in range(depth):
                if idx < checkpoint_num:
                    with_cp_list[idx] = True

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
                  norm_layer=norm_layer_for_blocks,
                  drop_path=dpr[i], init_values=init_values, attn_drop=0.,
                  use_flash_attn=use_flash_attn, use_fused_mlp=use_fused_mlp,
                  fused_mlp_heuristic=fused_mlp_heuristic,
                  with_cp=with_cp_list[i],
                  qk_normalization=qk_normalization,
                  layerscale_no_force_fp32=layerscale_no_force_fp32,
                  use_fused_rmsnorm=use_fused_rmsnorm)
            for i in range(depth)])
        self.clip_projector = AttentionPoolingBlock(
            dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
            drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)

        # CLIP decoder
        self.clip_decoder = nn.ModuleList([
            Linear_Decoder(
                in_channels=embed_dim,
                out_channels=clip_teacher_embed_dim,
                norm_layer=partial(nn.LayerNorm, eps=1e-5),
                clip_norm_type=clip_norm_type
            ) for _ in range(clip_return_layer)
        ])
        self.final_clip_decoder = nn.Identity()
        if clip_teacher_final_dim > 0:
            self.final_clip_decoder = Linear_Decoder(
                in_channels=clip_embed_dim,
                out_channels=clip_teacher_final_dim,
                norm_layer=partial(nn.LayerNorm, eps=1e-5),
                clip_norm_type=clip_norm_type
            )

        self.init_pos_embed()
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def init_pos_embed(self):
        if self.sep_pos_embed:
            raise NotImplementedError
        else:
            # trunc_normal_(self.pos_embed, std=.02)
            # trunc_normal_(self.clip_pos_embed, std=.02)
            pos_embed = get_3d_sincos_pos_embed(
                self.pos_embed.shape[-1],
                self.patch_embed.grid_size[1],  # height & width
                self.patch_embed.grid_size[0],  # t_size
                cls_token=True
            )
            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
            self.clip_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

            if self.sep_image_video_pos_embed:
                img_pos_embed = get_3d_sincos_pos_embed(
                    self.pos_embed.shape[-1],
                    self.patch_embed.grid_size[1],  # height & width
                    1,
                    cls_token=True
                )
                self.img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0))
                self.clip_img_pos_embed.data.copy_(torch.from_numpy(img_pos_embed).float().unsqueeze(0))

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def fix_init_weight(self):
        # Divide each block's output-projection weights by sqrt(2 * layer_id), layer_id starting at 1.
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    @property
    def dtype(self):
        return self.patch_embed.proj.weight.dtype

    def get_num_layers(self):
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {
            'pos_embed',
            'pos_embed_spatial',
            'pos_embed_temporal',
            'pos_embed_cls',
            'img_pos_embed',
            'cls_token',
            'clip_pos_embed',
            'clip_pos_embed_spatial',
            'clip_pos_embed_temporal',
            'clip_pos_embed_cls',
            'clip_img_pos_embed'
        }

    # @torch.cuda.amp.autocast(enabled=False)
    def forward(self, x, mask=None, use_image=False, x_vis_return_idx=-1, x_vis_only=False):
        x = self.patch_embed(x.type(self.dtype))
        # print(f"x.shape: {x.shape} x.dtype: {x.dtype}, model.dtype: {self.dtype}")
        B, T, L, C = x.shape  # T: temporal; L: spatial
        x = x.view([B, T * L, C])

        # append cls token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add pos_embed
        if self.sep_pos_embed:
            raise NotImplementedError
        else:
            if use_image:
                if self.sep_image_video_pos_embed:
                    pos_embed = self.img_pos_embed
                else:
                    # (1, num_img_patches + 1, embed_dim)
                    # print('origin pos_embed.shape:', self.pos_embed.shape)
                    cls_pos_embed = self.pos_embed[:, 0:1, :]
                    # print('cls_pos_embed.shape:', cls_pos_embed.shape)

                    # NOTE(review): the token grid's temporal size is patch_embed.grid_size[0]
                    # (= num_frames // tubelet_size); self.num_frames only matches it when
                    # tubelet_size == 1 -- confirm.
                    img_pos_embed = self.pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1)
                    # print('img_pos_embed.shape:', img_pos_embed.shape)

                    pos_embed = torch.cat([cls_pos_embed, img_pos_embed], dim=1)
                    # print('final img_pos_embed.shape:', pos_embed.shape)
            else:
                pos_embed = self.pos_embed

        x = x + pos_embed

        # mask tokens, ~mask means visible
        if mask is not None:
            x = x[~mask].reshape(B, -1, C)
        else:
            x = x.reshape(B, -1, C)

        residual = None
        x_clip = []
        for idx, blk in enumerate(self.blocks):
            if isinstance(x, tuple) and len(x) == 2:
                x, residual = x
            # print(f"\033[31mThis is {idx}, {x.shape}\033[0m")
            x = blk(x, residual=residual)
            # return intermediate features
            if idx in self.return_index:
                if isinstance(x, tuple) and len(x) == 2:
                    tmp_x, tmp_residual = x
                    # NOTE(review): this tests the outer `residual`, which is still None
                    # at idx 0, rather than `tmp_residual`; a feature requested for block 0
                    # would be dropped -- confirm intended.
                    if residual is not None:
                        x_clip.append(tmp_x + tmp_residual)
                else:
                    x_clip.append(x)
            if idx == (self.depth + x_vis_return_idx):
                # print(f'idx = {idx} len(self.blocks)={len(self.blocks)}')
                break

        if isinstance(x, tuple) and len(x) == 2:
            x, residual = x
            if residual is not None:
                x = x + residual

        x_vis = x
        if x_vis_only:
            return x_vis

        x_pool_vis = self.clip_projector(x_vis)
        x_align = self.final_clip_decoder(x_pool_vis)

        # align CLIP
        x_clip = torch.stack(x_clip)
        K, B, _, C_CLIP = x_clip.shape
        # add pos_embed
        if self.sep_pos_embed:
            raise NotImplementedError
        else:
            if use_image:
                if self.sep_image_video_pos_embed:
                    clip_pos_embed = self.clip_img_pos_embed
                else:
                    # (1, num_img_patches + 1, embed_dim)
                    # print('origin pos_embed.shape:', self.pos_embed.shape)
                    clip_cls_pos_embed = self.clip_pos_embed[:, 0:1, :]
                    # print('cls_pos_embed.shape:', cls_pos_embed.shape)

                    clip_img_pos_embed = self.clip_pos_embed[:, 1:, :].view(1, self.num_frames, self.patch_embed.num_patches // self.num_frames, self.embed_dim).mean(dim=1)
# print('img_pos_embed.shape:', img_pos_embed.shape) clip_pos_embed = torch.cat([clip_cls_pos_embed, clip_img_pos_embed], dim=1) # print('final img_pos_embed.shape:', pos_embed.shape) else: clip_pos_embed = self.clip_pos_embed clip_pos_embed = clip_pos_embed.repeat(B, 1, 1) if mask is not None: x_clip = x_clip + clip_pos_embed[~mask].view(B, -1, C_CLIP).unsqueeze(0).repeat(K, 1, 1, 1) else: x_clip = x_clip + clip_pos_embed.view(B, -1, C_CLIP).unsqueeze(0).repeat(K, 1, 1, 1) # CLIP decoder x_clip_align = [] for idx, clip_decoder in enumerate(self.clip_decoder): x_clip_align.append(clip_decoder(x_clip[idx])) x_clip_align = torch.stack(x_clip_align) return x_vis, x_pool_vis, x_clip_align, x_align def pretrain_internvideo2_1b_patch14_224(config): model = PretrainInternVideo2( in_chans=3, img_size=224, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, clip_embed_dim=config.vision_encoder.clip_embed_dim, attn_pool_num_heads=16, qkv_bias=False, drop_path_rate=0.25, init_values=0.00001, qk_normalization=True, use_flash_attn=config.vision_encoder.use_flash_attn, use_fused_rmsnorm=config.vision_encoder.use_fused_rmsnorm, use_fused_mlp=config.vision_encoder.use_fused_mlp, fused_mlp_heuristic=1, layerscale_no_force_fp32=False, num_frames=config.vision_encoder.num_frames, tubelet_size=config.vision_encoder.tubelet_size, sep_pos_embed=False, sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed, use_checkpoint=config.vision_encoder.use_checkpoint, checkpoint_num=config.vision_encoder.checkpoint_num, clip_teacher_embed_dim=config.vision_encoder.clip_teacher_embed_dim, clip_teacher_final_dim=config.vision_encoder.clip_teacher_final_dim, clip_norm_type=config.vision_encoder.clip_norm_type, clip_return_layer=config.vision_encoder.clip_return_layer, clip_student_return_interval=config.vision_encoder.clip_student_return_interval, ) return model def pretrain_internvideo2_6b_patch14_224(config): model = PretrainInternVideo2( in_chans=3, 
img_size=224, patch_size=14, embed_dim=3200, depth=48, num_heads=25, mlp_ratio=4, clip_embed_dim=config.vision_encoder.clip_embed_dim, attn_pool_num_heads=16, qkv_bias=False, drop_path_rate=0.3, init_values=0.00001, qk_normalization=True, use_flash_attn=config.vision_encoder.use_flash_attn, use_fused_rmsnorm=config.vision_encoder.use_fused_rmsnorm, use_fused_mlp=config.vision_encoder.use_fused_mlp, fused_mlp_heuristic=1, layerscale_no_force_fp32=False, num_frames=config.vision_encoder.num_frames, tubelet_size=config.vision_encoder.tubelet_size, sep_pos_embed=False, sep_image_video_pos_embed=config.vision_encoder.sep_image_video_pos_embed, use_checkpoint=config.vision_encoder.use_checkpoint, checkpoint_num=config.vision_encoder.checkpoint_num, clip_teacher_embed_dim=config.vision_encoder.clip_teacher_embed_dim, clip_teacher_final_dim=config.vision_encoder.clip_teacher_final_dim, clip_norm_type=config.vision_encoder.clip_norm_type, clip_return_layer=config.vision_encoder.clip_return_layer, clip_student_return_interval=config.vision_encoder.clip_student_return_interval, ) return model from dataclasses import dataclass from typing import Tuple, Optional, List from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import (PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer) from transformers.activations import ACT2FN from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput, ) from torch import Tensor, device from torch.nn import CrossEntropyLoss class BertConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to instantiate a BERT model according to the specified arguments, defining the model architecture. 
Instantiating a configuration with the defaults will yield a similar configuration to that of the BERT [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Args: vocab_size (`int`, *optional*, defaults to 30522): Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`]. hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. num_hidden_layers (`int`, *optional*, defaults to 12): Number of hidden layers in the Transformer encoder. num_attention_heads (`int`, *optional*, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. intermediate_size (`int`, *optional*, defaults to 3072): Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, `"relu"`, `"silu"` and `"gelu_new"` are supported. hidden_dropout_prob (`float`, *optional*, defaults to 0.1): The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): The dropout ratio for the attention probabilities. max_position_embeddings (`int`, *optional*, defaults to 512): The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). type_vocab_size (`int`, *optional*, defaults to 2): The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`]. 
initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (`float`, *optional*, defaults to 1e-12): The epsilon used by the layer normalization layers. position_embedding_type (`str`, *optional*, defaults to `"absolute"`): Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. classifier_dropout (`float`, *optional*): The dropout ratio for the classification head. 
Examples: ```python >>> from transformers import BertModel, BertConfig >>> # Initializing a BERT bert-base-uncased style configuration >>> configuration = BertConfig() >>> # Initializing a model from the bert-base-uncased style configuration >>> model = BertModel(configuration) >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "bert" def __init__( self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None, cross_module="ca", **kwargs, ): super().__init__(pad_token_id=pad_token_id, **kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.hidden_act = hidden_act self.intermediate_size = intermediate_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout self.cross_module = cross_module def load_tf_weights_in_bert(model, config, tf_checkpoint_path): """Load tf checkpoints in a pytorch model.""" try: import re import numpy as np import tensorflow as tf except ImportError: print( "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " "https://www.tensorflow.org/install/ for installation instructions." 
) raise tf_path = os.path.abspath(tf_checkpoint_path) print("Converting TensorFlow checkpoint from {}".format(tf_path)) # Load weights from TF model init_vars = tf.train.list_variables(tf_path) names = [] arrays = [] for name, shape in init_vars: print("Loading TF weight {} with shape {}".format(name, shape)) array = tf.train.load_variable(tf_path, name) names.append(name) arrays.append(array) for name, array in zip(names, arrays): name = name.split("/") # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model if any( n in [ "adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step", ] for n in name ): print("Skipping {}".format("/".join(name))) continue pointer = model for m_name in name: if re.fullmatch(r"[A-Za-z]+_\d+", m_name): scope_names = re.split(r"_(\d+)", m_name) else: scope_names = [m_name] if scope_names[0] == "kernel" or scope_names[0] == "gamma": pointer = getattr(pointer, "weight") elif scope_names[0] == "output_bias" or scope_names[0] == "beta": pointer = getattr(pointer, "bias") elif scope_names[0] == "output_weights": pointer = getattr(pointer, "weight") elif scope_names[0] == "squad": pointer = getattr(pointer, "classifier") else: try: pointer = getattr(pointer, scope_names[0]) except AttributeError: print("Skipping {}".format("/".join(name))) continue if len(scope_names) >= 2: num = int(scope_names[1]) pointer = pointer[num] if m_name[-11:] == "_embeddings": pointer = getattr(pointer, "weight") elif m_name == "kernel": array = np.transpose(array) try: assert ( pointer.shape == array.shape ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" except AssertionError as e: e.args += (pointer.shape, array.shape) raise print("Initialize PyTorch weight {}".format(name)) pointer.data = torch.from_numpy(array) return model class BertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type 
embeddings.""" def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding( config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id ) self.position_embeddings = nn.Embedding( config.max_position_embeddings, config.hidden_size ) self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) ) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.config = config def forward( self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0, ): if input_ids is not None: input_shape = input_ids.size() else: input_shape = inputs_embeds.size()[:-1] seq_length = input_shape[1] if position_ids is None: position_ids = self.position_ids[ :, past_key_values_length : seq_length + past_key_values_length ] if token_type_ids is None: token_type_ids = torch.zeros( input_shape, dtype=torch.long, device=self.position_ids.device ) if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings if self.position_embedding_type == "absolute": position_embeddings = self.position_embeddings(position_ids) embeddings += position_embeddings embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class BertSelfAttention(nn.Module): def __init__(self, config, is_cross_attention): super().__init__() self.config = config if config.hidden_size % 
config.num_attention_heads != 0 and not hasattr( config, "embedding_size" ): raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads) ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) if is_cross_attention: self.key = nn.Linear(config.encoder_width, self.all_head_size) self.value = nn.Linear(config.encoder_width, self.all_head_size) else: self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") if ( self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query" ): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding( 2 * config.max_position_embeddings - 1, self.attention_head_size ) self.save_attention = False def save_attn_gradients(self, attn_gradients): self.attn_gradients = attn_gradients def get_attn_gradients(self): return self.attn_gradients def save_attention_map(self, attention_map): self.attention_map = attention_map def get_attention_map(self): return self.attention_map def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, ): mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys # and values come from an 
encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None if is_cross_attention: key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = torch.cat([past_key_value[0], key_layer], dim=2) value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if ( self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query" ): seq_length = hidden_states.size()[1] position_ids_l = torch.arange( seq_length, dtype=torch.long, device=hidden_states.device ).view(-1, 1) position_ids_r = torch.arange( seq_length, dtype=torch.long, device=hidden_states.device ).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding( distance + self.max_position_embeddings - 1 ) positional_embedding = positional_embedding.to( dtype=query_layer.dtype ) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum( "bhld,lrd->bhlr", query_layer, positional_embedding ) attention_scores = attention_scores + relative_position_scores elif self.position_embedding_type == "relative_key_query": relative_position_scores_query = torch.einsum( "bhld,lrd->bhlr", query_layer, 
positional_embedding ) relative_position_scores_key = torch.einsum( "bhrd,lrd->bhlr", key_layer, positional_embedding ) attention_scores = ( attention_scores + relative_position_scores_query + relative_position_scores_key ) attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in BertModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.Softmax(dim=-1)(attention_scores) if is_cross_attention and self.save_attention: self.save_attention_map(attention_probs) attention_probs.register_hook(self.save_attn_gradients) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs_dropped = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs_dropped = attention_probs_dropped * head_mask context_layer = torch.matmul(attention_probs_dropped, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) # added `attention_scores` to return tuple outputs = ( (context_layer, attention_probs, attention_scores) if output_attentions else (context_layer,) ) outputs = outputs + (past_key_value,) return outputs class BertSelfOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states 
class BertAttention(nn.Module): def __init__(self, config, is_cross_attention=False): super().__init__() self.self = BertSelfAttention(config, is_cross_attention) self.output = BertSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads, ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) self.self.key = prune_linear_layer(self.self.key, index) self.self.value = prune_linear_layer(self.self.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, ): self_outputs = self.self( hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] return outputs # (context_layer, attention_probs, attention_scores, past_key_value,) class BertIntermediate(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class BertOutput(nn.Module): def 
__init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertLayer(nn.Module): def __init__(self, config, layer_num): super().__init__() self.config = config self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = BertAttention(config) self.has_cross_attention = layer_num >= config.fusion_layer if self.has_cross_attention: self.crossattention = BertAttention(config, is_cross_attention=True) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, past_key_value=self_attn_past_key_value, ) # (context_layer, attention_probs, attention_scores, past_key_value,) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:-1] present_key_value = self_attention_outputs[-1] if self.has_cross_attention: assert ( encoder_hidden_states is not None ), "encoder_hidden_states must be given for cross-attention layers" if type(encoder_hidden_states) == list: cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states[ (self.layer_num - self.config.fusion_layer) % 
len(encoder_hidden_states) ], encoder_attention_mask[ (self.layer_num - self.config.fusion_layer) % len(encoder_hidden_states) ], output_attentions=output_attentions, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:-1] else: cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions=output_attentions, ) # (context_layer, attention_probs, attention_scores, past_key_value,) attention_output = cross_attention_outputs[0] # add cross attentions if we output attention weights outputs = outputs + cross_attention_outputs[1:-1] layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output, ) outputs = (layer_output,) + outputs outputs = outputs + (present_key_value,) return outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output class BertEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList( [BertLayer(config, i) for i in range(config.num_hidden_layers)] ) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, mode="multi_modal", normalize_attention=True, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None # all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_cross_attentions = () if output_attentions else None next_decoder_cache = () if use_cache else None if ( mode == "text" or mode == "temporal" ): # temporal is added and used for temporal att module. 
start_layer = 0 output_layer = self.config.fusion_layer elif mode == "fusion": start_layer = self.config.fusion_layer output_layer = self.config.num_hidden_layers elif mode == "multi_modal": start_layer = 0 output_layer = self.config.num_hidden_layers for i in range(start_layer, output_layer): layer_module = self.layer[i] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False) and self.training: if use_cache: print( "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " "`use_cache=False`..." ) use_cache = False def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, use_reentrant=False, ) else: layer_outputs = layer_module( hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, ) # (context_layer, attention_probs, attention_scores, past_key_value,) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[-1],) if output_attentions: # whether to output normalized attention, # note for unnormalized attention, there is a mask added offset = int(normalize_attention) # all_self_attentions = all_self_attentions + (layer_outputs[1], ) all_self_attentions = all_self_attentions + (layer_outputs[2 - offset],) if hasattr(layer_module, "crossattention"): # all_cross_attentions = all_cross_attentions + (layer_outputs[3], ) all_cross_attentions = all_cross_attentions + (layer_outputs[4 - offset],) if output_hidden_states: all_hidden_states = all_hidden_states 
+ (hidden_states,) if not return_dict: return tuple( v for v in [ hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, ] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, ) class BertPooler(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output class BertPredictionHeadTransform(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) if isinstance(config.hidden_act, str): self.transform_act_fn = ACT2FN[config.hidden_act] else: self.transform_act_fn = config.hidden_act self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states class BertLMPredictionHead(nn.Module): def __init__(self, config): super().__init__() self.transform = BertPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. 
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) return hidden_states class BertOnlyMLMHead(nn.Module): def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) def forward(self, sequence_output): prediction_scores = self.predictions(sequence_output) return prediction_scores class BertOnlyNSPHead(nn.Module): def __init__(self, config): super().__init__() self.seq_relationship = nn.Linear(config.hidden_size, 2) def forward(self, pooled_output): seq_relationship_score = self.seq_relationship(pooled_output) return seq_relationship_score class BertPreTrainingHeads(nn.Module): def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) self.seq_relationship = nn.Linear(config.hidden_size, 2) def forward(self, sequence_output, pooled_output): prediction_scores = self.predictions(sequence_output) seq_relationship_score = self.seq_relationship(pooled_output) return prediction_scores, seq_relationship_score class BertPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. 
""" config_class = BertConfig load_tf_weights = load_tf_weights_in_bert base_model_prefix = "bert" _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() class BertModel(BertPreTrainedModel): """ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of cross-attention is added between the self-attention layers, following the architecture described in `Attention is all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an input to the forward pass. """ def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config) self.pooler = BertPooler(config) if add_pooling_layer else None self.init_weights() def get_input_embeddings(self): return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def get_extended_attention_mask( self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool ) -> Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. Arguments: attention_mask (:obj:`torch.Tensor`): Mask with ones indicating tokens to attend to, zeros for tokens to ignore. input_shape (:obj:`Tuple[int]`): The shape of the input to the model. device: (:obj:`torch.device`): The device of the input to the model. Returns: :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. """ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. if attention_mask.dim() == 3: extended_attention_mask = attention_mask[:, None, :, :] elif attention_mask.dim() == 2: # Provided a padding mask of dimensions [batch_size, seq_length] # - if the model is a decoder, apply a causal mask in addition to the padding mask # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] if is_decoder: batch_size, seq_length = input_shape seq_ids = torch.arange(seq_length, device=device) causal_mask = ( seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] ) # in case past_key_values are used we need to add a prefix ones mask to the causal mask # causal and attention masks must have same type with pytorch version < 1.3 causal_mask = causal_mask.to(attention_mask.dtype) if causal_mask.shape[1] < attention_mask.shape[1]: prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] causal_mask = torch.cat( [ torch.ones( (batch_size, seq_length, prefix_seq_len), device=device, 
dtype=causal_mask.dtype, ), causal_mask, ], axis=-1, ) extended_attention_mask = ( causal_mask[:, None, :, :] * attention_mask[:, None, None, :] ) else: extended_attention_mask = attention_mask[:, None, None, :] else: raise ValueError( "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( input_shape, attention_mask.shape ) ) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. extended_attention_mask = extended_attention_mask.to( dtype=self.dtype ) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, encoder_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, is_decoder=False, mode="multi_modal", normalize_attention=True, ): r""" encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. 
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). """ output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if is_decoder: use_cache = use_cache if use_cache is not None else self.config.use_cache else: use_cache = False if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time" ) elif input_ids is not None: input_shape = input_ids.size() batch_size, seq_length = input_shape device = input_ids.device elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] batch_size, seq_length = input_shape device = inputs_embeds.device elif encoder_embeds is not None: input_shape = encoder_embeds.size()[:-1] batch_size, seq_length = input_shape device = encoder_embeds.device else: raise ValueError( "You have to specify either input_ids or inputs_embeds or encoder_embeds" ) # past_key_values_length past_key_values_length = ( past_key_values[0][0].shape[2] if past_key_values is not None else 0 ) 
if attention_mask is None: attention_mask = torch.ones( ((batch_size, seq_length + past_key_values_length)), device=device ) if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( attention_mask, input_shape, device, is_decoder ) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if encoder_hidden_states is not None: if type(encoder_hidden_states) == list: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ 0 ].size() else: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if type(encoder_attention_mask) == list: encoder_extended_attention_mask = [ self.invert_attention_mask(mask) for mask in encoder_attention_mask ] elif encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask( encoder_attention_mask ) else: encoder_extended_attention_mask = self.invert_attention_mask( encoder_attention_mask ) else: encoder_extended_attention_mask = None # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) if encoder_embeds is None: embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, 
inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) else: embedding_output = encoder_embeds encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, mode=mode, normalize_attention=normalize_attention, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, ) @dataclass class MaskedLMOutputWithDistill(MaskedLMOutput): loss_aux: Optional[torch.FloatTensor] = None loss_distill: Optional[torch.FloatTensor] = None class BertForMaskedLM(BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] def __init__(self, config): super().__init__(config) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) self.init_weights() def tie_aux_decoder_weights(self, module, aux_modules): """Tie decoder weights of all `aux_modules` to `module`, (not bias)""" for m in aux_modules: m.predictions.decoder.weight = module.predictions.decoder.weight def get_output_embeddings(self): return self.cls.predictions.decoder def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, 
position_ids=None, head_mask=None, inputs_embeds=None, encoder_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, is_decoder=False, mode="multi_modal", normalize_attention=True, soft_labels=None, alpha=0, return_logits=False, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, encoder_embeds=encoder_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, mode=mode, normalize_attention=normalize_attention, ) sequence_output = outputs[0] prediction_scores = self.cls(sequence_output) if return_logits: return prediction_scores masked_lm_loss = None masked_lm_loss_aux = 0.0 if labels is not None: loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct( prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) ) if soft_labels is not None: loss_distill = -torch.sum( F.log_softmax(prediction_scores, dim=1) * soft_labels, dim=-1 ) loss_distill = loss_distill[labels != -100].mean() masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill if not return_dict: output = (prediction_scores,) + outputs[2:] return ((masked_lm_loss,) + output) if masked_lm_loss is not None 
else output # changed from MaskedLMOutput to MaskedLMOutputWithDistill return MaskedLMOutputWithDistill( loss=masked_lm_loss, loss_aux=masked_lm_loss_aux, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): input_shape = input_ids.shape effective_batch_size = input_shape[0] # add a dummy token assert ( self.config.pad_token_id is not None ), "The PAD token should be defined for generation" attention_mask = torch.cat( [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1 ) dummy_token = torch.full( (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device, ) input_ids = torch.cat([input_ids, dummy_token], dim=1) return {"input_ids": input_ids, "attention_mask": attention_mask} def build_bert(model_config, pretrain, checkpoint, encoder_width=None): """build text encoder. Args: model_config (dict): model config. pretrain (bool): Whether to do pretrain or finetuning. checkpoint (bool): whether to do gradient_checkpointing. 
Returns: TODO """ bert_config = BertConfig.from_json_file(model_config.text_encoder.config) if encoder_width is None: bert_config.encoder_width = model_config.vision_encoder.d_model else: bert_config.encoder_width = encoder_width bert_config.gradient_checkpointing = checkpoint bert_config.fusion_layer = model_config.text_encoder.fusion_layer if not model_config.multimodal.enable: bert_config.fusion_layer = bert_config.num_hidden_layers if pretrain: try: text_encoder, loading_info = BertForMaskedLM.from_pretrained( model_config.text_encoder.pretrained, config=bert_config, output_loading_info=True, local_files_only=True ) except: text_encoder, loading_info = BertForMaskedLM.from_pretrained( model_config.text_encoder.pretrained, config=bert_config, output_loading_info=True, local_files_only=False ) else: try: text_encoder, loading_info = BertModel.from_pretrained( model_config.text_encoder.pretrained, config=bert_config, add_pooling_layer=False, output_loading_info=True, local_files_only=True ) except: text_encoder, loading_info = BertModel.from_pretrained( model_config.text_encoder.pretrained, config=bert_config, add_pooling_layer=False, output_loading_info=True, local_files_only=False ) return text_encoder def get_sim( vision_proj: torch.Tensor, text_proj: torch.Tensor, temp=1.0, agg_method="mean", ): """calculate pair-wise video-text similarity. Args: vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. text_proj (torch.Tensor): The text representation. Shape: [B,C]. temp (torch.Tensor): The temperature. Shape: []. Returns: The similarity between video and text. Shape: [B,B]. 
""" vision_proj = F.normalize(vision_proj, dim=-1) text_proj = F.normalize(text_proj, dim=-1) if vision_proj.ndim == 3: sim_v2t = torch.einsum("mld,nd->mln", vision_proj, text_proj) / temp # [B, L, B] sim_t2v = torch.einsum("nd,mld->nlm", text_proj, vision_proj) / temp # [B, L, B] if agg_method == "mean": sim_v2t = sim_v2t.mean(1) sim_t2v = sim_t2v.mean(1) elif agg_method == "max": sim_v2t = sim_v2t.max(1)[0] sim_t2v = sim_t2v.max(1)[0] elif text_proj.ndim == 3: sim_v2t = torch.einsum("nd,mld->nlm", vision_proj, text_proj) / temp # [B, L, B] sim_t2v = torch.einsum("nld,md->nlm", text_proj, vision_proj) / temp # [B, L, B] if agg_method == "mean": sim_v2t = sim_v2t.mean(1) sim_t2v = sim_t2v.mean(1) elif agg_method == "max": sim_v2t = sim_v2t.max(1)[0] sim_t2v = sim_t2v.max(1)[0] else: sim_v2t = vision_proj @ text_proj.T / temp sim_t2v = sim_v2t.T return sim_v2t, sim_t2v VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt", "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt", "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt", "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt", "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt", "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt", "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt", "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt", "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt", "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt", 
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt", "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt", "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt", "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt", "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt", "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt", } } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { "bert-base-uncased": 512, "bert-large-uncased": 512, "bert-base-cased": 512, "bert-large-cased": 512, "bert-base-multilingual-uncased": 512, "bert-base-multilingual-cased": 512, "bert-base-chinese": 512, "bert-base-german-cased": 512, "bert-large-uncased-whole-word-masking": 512, "bert-large-cased-whole-word-masking": 512, "bert-large-uncased-whole-word-masking-finetuned-squad": 512, "bert-large-cased-whole-word-masking-finetuned-squad": 512, "bert-base-cased-finetuned-mrpc": 512, "bert-base-german-dbmdz-cased": 512, "bert-base-german-dbmdz-uncased": 512, "TurkuNLP/bert-base-finnish-cased-v1": 512, "TurkuNLP/bert-base-finnish-uncased-v1": 512, "wietsedv/bert-base-dutch-cased": 512, } PRETRAINED_INIT_CONFIGURATION = { "bert-base-uncased": {"do_lower_case": True}, "bert-large-uncased": {"do_lower_case": True}, "bert-base-cased": {"do_lower_case": False}, "bert-large-cased": {"do_lower_case": False}, 
"bert-base-multilingual-uncased": {"do_lower_case": True}, "bert-base-multilingual-cased": {"do_lower_case": False}, "bert-base-chinese": {"do_lower_case": False}, "bert-base-german-cased": {"do_lower_case": False}, "bert-large-uncased-whole-word-masking": {"do_lower_case": True}, "bert-large-cased-whole-word-masking": {"do_lower_case": False}, "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, "bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, "bert-base-german-dbmdz-cased": {"do_lower_case": False}, "bert-base-german-dbmdz-uncased": {"do_lower_case": True}, "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False}, "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True}, "wietsedv/bert-base-dutch-cased": {"do_lower_case": False}, } import collections import unicodedata from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace def load_vocab(vocab_file): """Loads a vocabulary file into a dictionary.""" vocab = collections.OrderedDict() with open(vocab_file, "r", encoding="utf-8") as reader: tokens = reader.readlines() for index, token in enumerate(tokens): token = token.rstrip("\n") vocab[token] = index return vocab def whitespace_tokenize(text): """Runs basic whitespace cleaning and splitting on a piece of text.""" text = text.strip() if not text: return [] tokens = text.split() return tokens class BasicTokenizer(object): """ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). Args: do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to lowercase the input when tokenizing. never_split (:obj:`Iterable`, `optional`): Collection of tokens which will never be split during tokenization. 
Only has an effect when :obj:`do_basic_tokenize=True` tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this `issue `__). strip_accents: (:obj:`bool`, `optional`): Whether or not to strip all accents. If this option is not specified, then it will be determined by the value for :obj:`lowercase` (as in the original BERT). """ def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None): if never_split is None: never_split = [] self.do_lower_case = do_lower_case self.never_split = set(never_split) self.tokenize_chinese_chars = tokenize_chinese_chars self.strip_accents = strip_accents def tokenize(self, text, never_split=None): """ Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer. Args: **never_split**: (`optional`) list of str Kept for backward compatibility purposes. Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) List of token not to split. """ # union() returns a new set by concatenating the two sets. never_split = self.never_split.union( set(never_split)) if never_split else self.never_split text = self._clean_text(text) # This was added on November 1st, 2018 for the multilingual and Chinese # models. This is also applied to the English models now, but it doesn't # matter since the English models were not trained on any Chinese data # and generally don't have any Chinese data in them (there are Chinese # characters in the vocabulary because Wikipedia does have some Chinese # words in the English Wikipedia.). 
if self.tokenize_chinese_chars: text = self._tokenize_chinese_chars(text) orig_tokens = whitespace_tokenize(text) split_tokens = [] for token in orig_tokens: if token not in never_split: if self.do_lower_case: token = token.lower() if self.strip_accents is not False: token = self._run_strip_accents(token) elif self.strip_accents: token = self._run_strip_accents(token) split_tokens.extend(self._run_split_on_punc(token, never_split)) output_tokens = whitespace_tokenize(" ".join(split_tokens)) return output_tokens def _run_strip_accents(self, text): """Strips accents from a piece of text.""" text = unicodedata.normalize("NFD", text) output = [] for char in text: cat = unicodedata.category(char) if cat == "Mn": continue output.append(char) return "".join(output) def _run_split_on_punc(self, text, never_split=None): """Splits punctuation on a piece of text.""" if never_split is not None and text in never_split: return [text] chars = list(text) i = 0 start_new_word = True output = [] while i < len(chars): char = chars[i] if _is_punctuation(char): output.append([char]) start_new_word = True else: if start_new_word: output.append([]) start_new_word = False output[-1].append(char) i += 1 return ["".join(x) for x in output] def _tokenize_chinese_chars(self, text): """Adds whitespace around any CJK character.""" output = [] for char in text: cp = ord(char) if self._is_chinese_char(cp): output.append(" ") output.append(char) output.append(" ") else: output.append(char) return "".join(output) def _is_chinese_char(self, cp): """Checks whether CP is the codepoint of a CJK character.""" # This defines a "chinese character" as anything in the CJK Unicode block: # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) # # Note that the CJK Unicode block is NOT all Japanese and Korean characters, # despite its name. The modern Korean Hangul alphabet is a different block, # as is Japanese Hiragana and Katakana. 
Those alphabets are used to write # space-separated words, so they are not treated specially and handled # like the all of the other languages. if ( (cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF) # or (cp >= 0x20000 and cp <= 0x2A6DF) # or (cp >= 0x2A700 and cp <= 0x2B73F) # or (cp >= 0x2B740 and cp <= 0x2B81F) # or (cp >= 0x2B820 and cp <= 0x2CEAF) # or (cp >= 0xF900 and cp <= 0xFAFF) or (cp >= 0x2F800 and cp <= 0x2FA1F) # ): # return True return False def _clean_text(self, text): """Performs invalid character removal and whitespace cleanup on text.""" output = [] for char in text: cp = ord(char) if cp == 0 or cp == 0xFFFD or _is_control(char): continue if _is_whitespace(char): output.append(" ") else: output.append(char) return "".join(output) class WordpieceTokenizer(object): """Runs WordPiece tokenization.""" def __init__(self, vocab, unk_token, max_input_chars_per_word=100): self.vocab = vocab self.unk_token = unk_token self.max_input_chars_per_word = max_input_chars_per_word def tokenize(self, text): """ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform tokenization using the given vocabulary. For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`. Args: text: A single token or whitespace separated tokens. This should have already been passed through `BasicTokenizer`. Returns: A list of wordpiece tokens. 
""" output_tokens = [] for token in whitespace_tokenize(text): chars = list(token) if len(chars) > self.max_input_chars_per_word: output_tokens.append(self.unk_token) continue is_bad = False start = 0 sub_tokens = [] while start < len(chars): end = len(chars) cur_substr = None while start < end: substr = "".join(chars[start:end]) if start > 0: substr = "##" + substr if substr in self.vocab: cur_substr = substr break end -= 1 if cur_substr is None: is_bad = True break sub_tokens.append(cur_substr) start = end if is_bad: output_tokens.append(self.unk_token) else: output_tokens.extend(sub_tokens) return output_tokens class BertTokenizer(PreTrainedTokenizer): r""" Construct a BERT tokenizer. Based on WordPiece. This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. Users should refer to this superclass for more information regarding those methods. Args: vocab_file (:obj:`str`): File containing the vocabulary. do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to lowercase the input when tokenizing. do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to do basic tokenization before WordPiece. never_split (:obj:`Iterable`, `optional`): Collection of tokens which will never be split during tokenization. Only has an effect when :obj:`do_basic_tokenize=True` unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead. sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for sequence classification or for a text and a question for question answering. It is also used as the last token of a sequence built with special tokens. 
pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): The token used for padding, for example when batching sequences of different lengths. cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): The classifier token which is used when doing sequence classification (classification of the whole sequence instead of per-token classification). It is the first token of the sequence when built with special tokens. mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): The token used for masking values. This is the token used when training this model with masked language modeling. This is the token which the model will try to predict. tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this `issue `__). strip_accents: (:obj:`bool`, `optional`): Whether or not to strip all accents. If this option is not specified, then it will be determined by the value for :obj:`lowercase` (as in the original BERT). """ vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES def __init__( self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", mask_token="[MASK]", tokenize_chinese_chars=True, strip_accents=None, **kwargs ): if not os.path.isfile(vocab_file): raise ValueError( "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format( vocab_file) ) self.vocab = load_vocab(vocab_file) super().__init__( do_lower_case=do_lower_case, do_basic_tokenize=do_basic_tokenize, never_split=never_split, unk_token=unk_token, sep_token=sep_token, pad_token=pad_token, cls_token=cls_token, mask_token=mask_token, tokenize_chinese_chars=tokenize_chinese_chars, strip_accents=strip_accents, **kwargs, ) self.ids_to_tokens = collections.OrderedDict( [(ids, tok) for tok, ids in self.vocab.items()]) self.do_basic_tokenize = do_basic_tokenize if do_basic_tokenize: self.basic_tokenizer = BasicTokenizer( do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars, strip_accents=strip_accents, ) self.wordpiece_tokenizer = WordpieceTokenizer( vocab=self.vocab, unk_token=self.unk_token) @property def do_lower_case(self): return self.basic_tokenizer.do_lower_case @property def vocab_size(self): return len(self.vocab) def get_vocab(self): return dict(self.vocab, **self.added_tokens_encoder) def _tokenize(self, text): split_tokens = [] if self.do_basic_tokenize: for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): # If the token is part of the never_split set if token in self.basic_tokenizer.never_split: split_tokens.append(token) else: split_tokens += self.wordpiece_tokenizer.tokenize(token) else: split_tokens = self.wordpiece_tokenizer.tokenize(text) return split_tokens def _convert_token_to_id(self, token): """ Converts a token (str) in an id using the vocab. """ return self.vocab.get(token, self.vocab.get(self.unk_token)) def _convert_id_to_token(self, index): """Converts an index (integer) in a token (str) using the vocab.""" return self.ids_to_tokens.get(index, self.unk_token) def convert_tokens_to_string(self, tokens): """ Converts a sequence of tokens (string) in a single string. 
""" out_string = " ".join(tokens).replace(" ##", "").strip() return out_string def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: """ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. A BERT sequence has the following format: - single sequence: ``[CLS] X `` - pair of sequences: ``[CLS] A [SEP] B [SEP]`` Args: token_ids_0 (:obj:`List[int]`): List of IDs to which the special tokens will be added. token_ids_1 (:obj:`List[int]`, `optional`): Optional second list of IDs for sequence pairs. Returns: :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. """ if token_ids_1 is None: return [self.cls_token_id] + token_ids_0 cls = [self.cls_token_id] sep = [self.sep_token_id] return cls + token_ids_0 + sep + token_ids_1 + sep def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False ) -> List[int]: """ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer ``prepare_for_model`` method. Args: token_ids_0 (:obj:`List[int]`): List of IDs. token_ids_1 (:obj:`List[int]`, `optional`): Optional second list of IDs for sequence pairs. already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not the token list is already formatted with special tokens for the model. Returns: :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. """ if already_has_special_tokens: if token_ids_1 is not None: raise ValueError( "You should not supply a second sequence if the provided sequence of " "ids is already formatted with special tokens for the model." 
) return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) if token_ids_1 is not None: return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] return [1] + ([0] * len(token_ids_0)) + [1] def create_token_type_ids_from_sequences( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: """ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence pair mask has the following format: :: 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 | first sequence | second sequence | If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). Args: token_ids_0 (:obj:`List[int]`): List of IDs. token_ids_1 (:obj:`List[int]`, `optional`): Optional second list of IDs for sequence pairs. Returns: :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given sequence(s). """ sep = [self.sep_token_id] cls = [self.cls_token_id] if token_ids_1 is None: return len(cls + token_ids_0 + sep) * [0] return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: index = 0 if os.path.isdir(save_directory): vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) else: vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory with open(vocab_file, "w", encoding="utf-8") as writer: for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): if index != token_index: print( "Saving vocabulary to {}: vocabulary indices are not consecutive." 
" Please check that the vocabulary is not corrupted!".format( vocab_file) ) index = token_index writer.write(token + "\n") index += 1 return (vocab_file,) from huggingface_hub import PyTorchModelHubMixin def _frame_from_video(video): while video.isOpened(): success, frame = video.read() if success: yield frame else: break v_mean = np.array([0.485, 0.456, 0.406]).reshape(1,1,3) v_std = np.array([0.229, 0.224, 0.225]).reshape(1,1,3) def normalize(data): return (data/255.0-v_mean)/v_std def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')): assert(len(vid_list) >= fnum) step = len(vid_list) // fnum vid_list = vid_list[::step][:fnum] vid_list = [cv2.resize(x[:,:,::-1], target_size) for x in vid_list] vid_tube = [np.expand_dims(normalize(x), axis=(0, 1)) for x in vid_list] vid_tube = np.concatenate(vid_tube, axis=1) vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3)) vid_tube = torch.from_numpy(vid_tube).to(device, non_blocking=True).float() return vid_tube def vid2tensor(path: str, fnum: int=8, target_size: tuple=(224, 224), device=torch.device('cuda')): video = cv2.VideoCapture(path) frames = [x for x in _frame_from_video(video)] return frames2tensor(frames, fnum, target_size, device) def get_text_feat_dict(texts, clip, text_feat_d={}): for t in texts: feat = clip.get_txt_feat(t) text_feat_d[t] = feat return text_feat_d def get_vid_feat(frames, vlm): return vlm.get_vid_features(frames) def retrieve_text(frames, texts, model, topk:int=5, device=torch.device('cuda')): vlm = model.to(device) config = vlm.config fn = config.num_frames size_t = config.size_t frames_tensor = frames2tensor(frames, fnum=fn, target_size=(size_t, size_t), device=device) vid_feat = vlm.get_vid_feat(frames_tensor) text_feat_d = {} text_feat_d = get_text_feat_dict(texts, vlm, text_feat_d) text_feats = [text_feat_d[t] for t in texts] text_feats_tensor = torch.cat(text_feats, 0) probs, idxs = vlm.predict_label(vid_feat, text_feats_tensor, top=topk) ret_texts = 
[texts[i] for i in idxs.long().numpy()[0].tolist()] return ret_texts, probs.float().numpy()[0] def setup_internvideo2(config): model = InternVideo2_Stage2(config=config, is_pretrain=True) torch.set_float32_matmul_precision('high') model = torch.compile(model) model = model.to(torch.device(config.device)) model_without_ddp = model if (config.pretrained_path.strip() and (os.path.isfile(config.pretrained_path)) or "s3://" in config.pretrained_path): checkpoint = torch.load(config.pretrained_path, map_location="cpu") try: if "model" in checkpoint.keys(): state_dict = checkpoint["model"] else: state_dict = checkpoint["module"] # This is a deepspeed stage 1 model except: state_dict = checkpoint # if config.get('origin_num_frames', None) is not None: a = len(state_dict) interpolate_pos_embed_internvideo2_new(state_dict, model_without_ddp.vision_encoder, orig_t_size=config.origin_num_frames) assert a == len(state_dict), state_dict.keys() msg = model_without_ddp.load_state_dict(state_dict, strict=False) model_without_ddp = model_without_ddp.to(torch.float32) return model_without_ddp.eval() class DictToClass: def __init__(self, data): for key, value in data.items(): key = str(key) if isinstance(value, dict): setattr(self, key, DictToClass(value)) elif isinstance(value, list): setattr(self, key, [ DictToClass(item) if isinstance(item, dict) else item for item in value ]) else: setattr(self, key, value) def __repr__(self): """方便调试的对象表示""" attrs = ', '.join(f"{k}={v!r}" for k, v in self.__dict__.items()) return f"{self.__class__.__name__}({attrs})" def instance2dict(obj): """将类实例及其嵌套属性转换为字典""" if isinstance(obj, (str, int, float, bool, type(None))): # 基本类型直接返回 return obj elif isinstance(obj, dict): # 字典类型递归处理值 return {k: instance2dict(v) for k, v in obj.items()} elif isinstance(obj, (list, tuple, set)): # 可迭代类型递归处理元素 return type(obj)(instance2dict(item) for item in obj) elif hasattr(obj, '__dict__'): # 类实例处理 result = {} for key, value in obj.__dict__.items(): # 过滤私有属性(可选) if 
not key.startswith('_'): result[key] = instance2dict(value) return result else: # 其他不可序列化类型直接返回 return str(obj) # 或者根据需求抛出异常 class InternVideo2_Stage2_Config(PretrainedConfig): _auto_class='AutoConfig' def __init__(self, **kwargs): super().__init__(**kwargs) class InternVideo2_Stage2( PreTrainedModel, ): """docstring for InternVideo2_Stage2""" _auto_class="AutoModel" config_class=InternVideo2_Stage2_Config def __init__(self, config: InternVideo2_Stage2_Config, is_pretrain: bool=True): super(InternVideo2_Stage2, self).__init__(config) config = config.to_dict() self._config = DictToClass(config) if isinstance(config, dict) else config self.tokenizer = BertTokenizer.from_pretrained(self._config.model.text_encoder.pretrained, local_files_only=True, use_safetensors=True) self.is_pretrain = is_pretrain self.vision_width = self._config.model.vision_encoder.clip_embed_dim self.text_width = self._config.model.text_encoder.d_model self.embed_dim = self._config.model.embed_dim # create modules. self.vision_encoder = self.build_vision_encoder() self.text_encoder = self.build_text_encoder() self.vision_proj = nn.Linear(self.vision_width, self.embed_dim) self.text_proj = nn.Linear(self.text_width, self.embed_dim) def freeze_vision(self): """freeze vision encoder""" for p in self.vision_encoder.parameters(): p.requires_grad = False def freeze_text(self): """freeze text encoder""" for p in self.text_encoder.parameters(): p.requires_grad = False @property def dtype(self): return self.vision_encoder.patch_embed.proj.weight.dtype def encode_vision(self, image: torch.Tensor, test: bool=False): """encode image / videos as features. Args: image (torch.Tensor): The input images. test (bool): Whether testing. Returns: tuple. - vision_embeds (torch.Tensor): The output features. Shape: [B,N,C]. - pooled_vision_embeds (torch.Tensor): The pooled output features. Shape: [B,1,C]. - student_output (torch.Tensor): The features of alignment. Shape: [K,B,N,C]. 
- clip_output (torch.Tensor): The features of clip. Shape: [K,B,N,C]. """ T = image.shape[1] use_image = True if T == 1 else False image = image.permute(0, 2, 1, 3, 4).to(self.dtype) # [B,T,C,H,W] -> [B,C,T,H,W] # whether save temporal dimension # keep_temporal=self._config.model.vision_encoder.keep_temporal if test: vision_embeds, pooled_vision_embeds, _, _ = self.vision_encoder( image, None, use_image) return vision_embeds, pooled_vision_embeds else: mask, targets_clip_middle_vis, targets_clip_final_vis = self.encode_teacher(image) # if mask is not None and (self.video_mask_type != 'tube' or self.image_mask_type != 'tube'): # keep_temporal = False # print(f"\033[31mmask is {type(mask)}\033[0m") vision_embeds, pooled_vision_embeds, student_output, student_output_final = self.vision_encoder( image, mask, use_image) return vision_embeds, pooled_vision_embeds, student_output, student_output_final, targets_clip_middle_vis, targets_clip_final_vis def encode_text(self, text: dict): """encode text. Args: text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys: - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L]. - attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token. - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__". Returns: tuple. - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C]. - pooled_text_embeds (torch.Tensor): The pooled features. Shape: [B,C]. """ text_output = self.get_text_encoder()( text.input_ids, attention_mask=text.attention_mask, return_dict=True, mode="text", ) text_embeds = text_output.last_hidden_state pooled_text_embeds = text_embeds[:, 0] return text_embeds, pooled_text_embeds def build_vision_encoder(self): """build vision encoder Returns: (vision_encoder, clip_teacher). Each is a `nn.Module`. 
""" encoder_name = self._config.model.vision_encoder.name if encoder_name == 'pretrain_internvideo2_1b_patch14_224': vision_encoder = pretrain_internvideo2_1b_patch14_224(self._config.model) elif encoder_name == 'pretrain_internvideo2_6b_patch14_224': vision_encoder = pretrain_internvideo2_6b_patch14_224(self._config.model) else: raise ValueError(f"Not implemented: {encoder_name}") # parameters for mask img_size = self._config.model.vision_encoder.img_size num_frames = self._config.model.vision_encoder.num_frames tublet_size = self._config.model.vision_encoder.tubelet_size patch_size = self._config.model.vision_encoder.patch_size self.clip_img_size = self._config.model.vision_encoder.clip_input_resolution self.video_mask_type = self._config.model.vision_encoder.video_mask_type self.video_window_size = (num_frames // tublet_size, img_size // patch_size, img_size // patch_size) self.video_mask_ratio = self._config.model.vision_encoder.video_mask_ratio self.image_mask_type = self._config.model.vision_encoder.image_mask_type self.image_window_size = (1, img_size // patch_size, img_size // patch_size) self.image_mask_ratio = self._config.model.vision_encoder.image_mask_ratio return vision_encoder def build_text_encoder(self): """build text_encoder and possiblly video-to-text multimodal fusion encoder. Returns: nn.Module. The text encoder """ encoder_name = self._config.model.text_encoder.name if "bert" in encoder_name: text_encoder = build_bert( self._config.model, self.is_pretrain, self._config.gradient_checkpointing, ) else: raise ValueError(f"Not implemented: {encoder_name}") return text_encoder def get_text_encoder(self): """get text encoder, used for text and cross-modal encoding""" encoder = self.text_encoder return encoder.bert if hasattr(encoder, "bert") else encoder def get_vid_feat(self, frames: torch.Tensor): """get the video features for the given frames. Args: frames (torch.Tensor): The input frames. Shape: [B,T,C,H,W]. Returns: tuple. 
- vision_embeds (torch.Tensor): The output features. Shape: [B,N,C]. - pooled_vision_embeds (torch.Tensor): The pooled output features. Shape: [B,1,C]. """ with torch.no_grad(): _, vfeat = self.encode_vision(frames, test=True) vfeat = self.vision_proj(vfeat) vfeat /= vfeat.norm(dim=-1, keepdim=True) return vfeat def get_txt_feat(self, text: str): """get the text features for the given text.""" with torch.no_grad(): text = self.tokenizer( text, padding="max_length", truncation=True, max_length=self._config.max_txt_l, return_tensors="pt",).to(self._config.device) _, tfeat = self.encode_text(text) tfeat = self.text_proj(tfeat) tfeat /= tfeat.norm(dim=-1, keepdim=True) return tfeat def predict_label(self, vid_feat: torch.Tensor, txt_feat: torch.Tensor, top: int=5): label_probs = (100.0 * vid_feat @ txt_feat.T).softmax(dim=-1) top_probs, top_labels = label_probs.float().cpu().topk(top, dim=-1) return top_probs, top_labels