import math
import types

import torch
import torch.utils.checkpoint as checkpoint
from torch import nn

from .model import VisualTransformer


class VITClassificationHeadV0(nn.Module):
    """Classification head that fuses multi-layer ViT features.

    For each of the `num_features` backbone layers it collects three vectors
    (average-pooled patch features, max-pooled patch features, and the cls
    token), optionally layer-normalizes each of them, and combines them with a
    learned softmax-weighted sum before the final linear classifier.
    """

    def __init__(
        self,
        num_features: int,
        channel: int,
        num_labels: int,
        norm=False,
        dropout=0.0,
        ret_feat=False,
    ):
        super().__init__()
        # One mixing weight per collected vector: three vectors per layer.
        self.weights = nn.Parameter(
            torch.ones(1, num_features * 3, 1, dtype=torch.float32)
        )
        self.final_fc = nn.Linear(channel, num_labels)
        self.norm = norm
        if self.norm:
            for i in range(num_features * 3):
                setattr(self, f"norm_{i}", nn.LayerNorm(channel))
        self.dropout = nn.Dropout(p=dropout)
        self.ret_feat = ret_feat

    def forward(self, features, cls_tokens):
        # features:   [num_features, b, c, h, w]
        # cls_tokens: [num_features, b, c]
        xs = []
        for feature, cls_token in zip(features, cls_tokens):
            xs.append(feature.mean([2, 3]))  # average pooling over h, w
            xs.append(feature.max(-1).values.max(-1).values)  # max pooling over h, w
            xs.append(cls_token)
        if self.norm:
            xs = [getattr(self, f"norm_{i}")(x) for i, x in enumerate(xs)]
        xs = torch.stack(xs, dim=1)  # [b, num_features * 3, c]
        # Softmax-weighted sum over all collected feature vectors.
        feat = (xs * self.weights.softmax(dim=1)).sum(1)
        x = self.dropout(feat)
        x = self.final_fc(x)
        if self.ret_feat:
            return x, feat
        return x


class FACTransformer(nn.Module):
    """A face attribute classification transformer leveraging multiple cls_tokens.

    Forward args:
        image (torch.Tensor): Float32 tensor with shape [b, 3, h, w], normalized to [0, 1].

    Forward returns:
        logits (torch.Tensor): Float32 tensor with shape [b, n_classes].
        aux_outputs: Only returned when the head is configured to also return
            its fused feature vector (e.g. ret_feat=True).
    """

    def __init__(self, backbone: nn.Module, head: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.head = head
        # Keep the whole model on GPU in float32.
        self.cuda().float()

    def forward(self, image):
        logits = self.head(*self.backbone(image))
        return logits


def add_method(obj, name, method):
    """Bind `method` to `obj` as an instance method called `name`."""
    setattr(obj, name, types.MethodType(method, obj))


def get_clip_encode_func(layers):
    """Return a forward function for a CLIP-style VisualTransformer that exposes
    the patch features and cls token of each layer index in `layers`.
    """

    def func(self, x):
        # x: [b, 3, h, w] -> patch embeddings [b, c, grid, grid]
        x = self.conv1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        extra_tokens = getattr(self, "extra_tokens", [])
        x = x.permute(0, 2, 1)  # [b, grid*grid, c]
        class_token = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        special_tokens = [
            getattr(self, name).to(x.dtype)
            + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
            for name in extra_tokens
        ]
        # Prepend the class token and any extra learnable tokens to the patch tokens.
        x = torch.cat([class_token, *special_tokens, x], dim=1)
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # [seq, b, c]

        # Run the transformer blocks, keeping the output of every layer up to
        # the deepest requested one.
        outs = []
        max_layer = max(layers)
        use_checkpoint = self.transformer.use_checkpoint
        for layer_i, blk in enumerate(self.transformer.resblocks):
            if layer_i > max_layer:
                break
            if self.training and use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
            outs.append(x)

        # [n_layers, seq, b, c] -> [n_layers, b, seq, c]
        outs = torch.stack(outs).permute(0, 2, 1, 3)
        cls_tokens = outs[layers, :, 0, :]

        extra_token_feats = {}
        for i, name in enumerate(extra_tokens):
            extra_token_feats[name] = outs[layers, :, i + 1, :]

        _, B, N, C = outs.shape
        # Reshape the patch tokens back into square spatial feature maps.
        W = int(math.sqrt(N - 1 - len(extra_tokens)))
        features = (
            outs[layers, :, 1 + len(extra_tokens) :, :]
            .reshape(len(layers), B, W, W, C)
            .permute(0, 1, 4, 2, 3)
        )
        if getattr(self, "ret_special", False):
            return features, cls_tokens, extra_token_feats
        return features, cls_tokens

    return func


def farl_classification(num_classes=2, layers=list(range(12))):
    """Build a FaRL/CLIP ViT-B/16 classifier that pools features from `layers`."""
    model = VisualTransformer(
        input_resolution=224,
        patch_size=16,
        width=768,
        layers=12,
        heads=12,
        output_dim=512,
    )
    channel = 768
    model = model.cuda()
    # The CLIP projection and final layer norm are not used by the patched
    # encode function below, so drop them.
    del model.proj
    del model.ln_post
    add_method(model, "forward", get_clip_encode_func(layers))
    head = VITClassificationHeadV0(
        num_features=len(layers), channel=channel, num_labels=num_classes, norm=True
    )
    model = FACTransformer(model, head)
    return model


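# Minimal usage sketch (illustrative only, not part of the original module):
# builds the default 2-class model and runs a dummy batch through it. Assumes a
# CUDA device is available, since the constructors above move the model to GPU,
# and that this file is executed as part of its package (e.g. `python -m ...`)
# so the relative `.model` import resolves.
if __name__ == "__main__":
    model = farl_classification(num_classes=2)
    model.eval()
    dummy = torch.rand(1, 3, 224, 224, device="cuda")  # [b, 3, h, w] in [0, 1]
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([1, 2])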