import math
import types

import torch
import torch.utils.checkpoint as checkpoint
from torch import nn

from .model import VisualTransformer

class VITClassificationHeadV0(nn.Module):
    def __init__(
        self,
        num_features: int,
        channel: int,
        num_labels: int,
        norm=False,
        dropout=0.0,
        ret_feat=False,
    ):
        super().__init__()
        # One learnable scalar weight per aggregated vector: each backbone
        # layer contributes 3 vectors (mean pool, max pool, cls token).
        self.weights = nn.Parameter(
            torch.ones(1, num_features * 3, 1, dtype=torch.float32)
        )
        self.final_fc = nn.Linear(channel, num_labels)
        self.norm = norm
        if self.norm:
            for i in range(num_features * 3):
                setattr(self, f"norm_{i}", nn.LayerNorm(channel))
        self.dropout = nn.Dropout(p=dropout)
        self.ret_feat = ret_feat
    def forward(self, features, cls_tokens):
        xs = []
        for feature, cls_token in zip(features, cls_tokens):
            # feature: b x c x s x s
            # cls_token: b x c
            xs.append(feature.mean([2, 3]))  # global average pool -> b x c
            xs.append(feature.max(-1).values.max(-1).values)  # global max pool -> b x c
            xs.append(cls_token)
        if self.norm:
            xs = [getattr(self, f"norm_{i}")(x) for i, x in enumerate(xs)]
        xs = torch.stack(xs, dim=1)  # b x 3N x c
        # Softmax over the 3N weights yields a learned convex combination.
        feat = (xs * self.weights.softmax(dim=1)).sum(1)  # b x c
        x = self.dropout(feat)
        x = self.final_fc(x)  # b x num_labels
        if self.ret_feat:
            return x, feat
        return x
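
# Shape walkthrough for the head (illustrative sketch, not part of the original
# pipeline): with the ViT-B/16 backbone built below, channel is 768 and the
# patch grid is 14x14, so for batch size 2 and all 12 layers:
#
#   features   = torch.randn(12, 2, 768, 14, 14)  # per-layer patch-token maps
#   cls_tokens = torch.randn(12, 2, 768)          # per-layer class tokens
#   head = VITClassificationHeadV0(num_features=12, channel=768, num_labels=2)
#   logits = head(features, cls_tokens)           # -> [2, 2]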

class FACTransformer(nn.Module):
    """A face attribute classification transformer leveraging multiple cls_tokens.

    Args:
        image (torch.Tensor): Float32 tensor with shape [b, 3, h, w],
            normalized to [0, 1].

    Returns:
        logits (torch.Tensor): Float32 tensor with shape [b, n_classes].
    """
    def __init__(self, backbone: nn.Module, head: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.head = head
        # Force GPU + fp32 at construction time.
        self.cuda().float()

    def forward(self, image):
        # backbone(image) -> (features, cls_tokens); the head consumes both.
        logits = self.head(*self.backbone(image))
        return logits

def add_method(obj, name, method):
    # Bind `method` to `obj` as an instance method (monkey patching).
    setattr(obj, name, types.MethodType(method, obj))
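
# For an nn.Module instance, binding "forward" this way is enough for
# module(x) to dispatch through the new function, because nn.Module.__call__
# looks up self.forward at call time and the instance attribute shadows the
# class method. Hypothetical example:
#
#   add_method(model, "forward", get_clip_encode_func(layers=[3, 7, 11]))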

def get_clip_encode_func(layers):
    def func(self, x):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        extra_tokens = getattr(self, "extra_tokens", [])
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Broadcast the learned class embedding (and any extra special tokens)
        # across the batch.
        class_token = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        special_tokens = [
            getattr(self, name).to(x.dtype)
            + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
            for name in extra_tokens
        ]
        x = torch.cat(
            [class_token, *special_tokens, x], dim=1
        )  # shape = [*, grid ** 2 + 1 + len(extra_tokens), width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND

        outs = []
        max_layer = max(layers)
        use_checkpoint = self.transformer.use_checkpoint
        for layer_i, blk in enumerate(self.transformer.resblocks):
            if layer_i > max_layer:
                break
            if self.training and use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
            outs.append(x)
        outs = torch.stack(outs).permute(0, 2, 1, 3)  # layers x B x N x C
        cls_tokens = outs[layers, :, 0, :]
        extra_token_feats = {}
        for i, name in enumerate(extra_tokens):
            extra_token_feats[name] = outs[layers, :, i + 1, :]
        L, B, N, C = outs.shape
        # The patch tokens form a W x W grid; recover the spatial layout.
        W = int(math.sqrt(N - 1 - len(extra_tokens)))
        features = (
            outs[layers, :, 1 + len(extra_tokens) :, :]
            .reshape(len(layers), B, W, W, C)
            .permute(0, 1, 4, 2, 3)
        )
        if getattr(self, "ret_special", False):
            return features, cls_tokens, extra_token_feats
        return features, cls_tokens

    return func
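
# Sketch of the patched encoder's contract (mirrors farl_classification below;
# shapes assume its ViT-B/16 configuration):
#
#   add_method(model, "forward", get_clip_encode_func(layers=list(range(12))))
#   features, cls_tokens = model(images)
#   # features:   [len(layers), B, C, W, W]  per-layer patch-token feature maps
#   # cls_tokens: [len(layers), B, C]        per-layer class tokens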

def farl_classification(num_classes=2, layers=list(range(12))):
    # Build a CLIP-style ViT-B/16 visual encoder (the FaRL backbone geometry).
    model = VisualTransformer(
        input_resolution=224,
        patch_size=16,
        width=768,
        layers=12,
        heads=12,
        output_dim=512,
    )
    channel = 768
    model = model.cuda()
    # The output projection and final LayerNorm are unused by the
    # classification head, so drop them.
    del model.proj
    del model.ln_post
    add_method(model, "forward", get_clip_encode_func(layers))
    head = VITClassificationHeadV0(
        num_features=len(layers), channel=channel, num_labels=num_classes, norm=True
    )
    return FACTransformer(model, head)
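
if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes a CUDA device, since
    # farl_classification moves everything to the GPU). Because of the
    # relative import at the top, run it as a module, e.g.
    # `python -m <package>.<this_module>`.
    model = farl_classification(num_classes=2).eval()
    dummy = torch.rand(2, 3, 224, 224, device="cuda")  # values in [0, 1]
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 2])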