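# FSFM-3C — commit d4e7f2f ("Add V1.0"), 4.8 kB.
# NOTE: this header was web-viewer residue (raw / history / blame links);
# it is kept as a comment so the module parses as Python.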
from torch import nn
import torch
import torch.utils.checkpoint as checkpoint
from .model import VisualTransformer
class VITClassificationHeadV0(nn.Module):
    """Classification head that fuses pooled features from several ViT layers.

    For each of the ``num_features`` layers it builds three ``channel``-dim
    vectors: the spatial mean of the feature map, the spatial max of the
    feature map, and the layer's cls token. The ``3 * num_features`` vectors
    are optionally LayerNorm-ed (one LayerNorm per vector slot). They are
    then combined by a softmax-weighted sum with learned weights, passed
    through dropout, and projected to ``num_labels`` logits.

    Args:
        num_features: number of layers whose outputs are fused.
        channel: feature dimension ``c`` of every layer.
        num_labels: number of output logits.
        norm: if true, register ``norm_0 .. norm_{3N-1}`` LayerNorms.
        dropout: dropout probability applied to the fused feature.
        ret_feat: if true, ``forward`` returns ``(logits, fused_feature)``.
    """

    def __init__(
        self,
        num_features: int,
        channel: int,
        num_labels: int,
        norm=False,
        dropout=0.0,
        ret_feat=False,
    ):
        super().__init__()
        n_slots = num_features * 3
        # One mixing logit per (layer, pooling) slot; initialised uniform.
        self.weights = nn.Parameter(torch.ones(1, n_slots, 1, dtype=torch.float32))
        self.final_fc = nn.Linear(channel, num_labels)
        self.norm = norm
        if norm:
            # Registered as individual attributes (not a ModuleList) so the
            # state_dict keys stay "norm_<i>.*".
            for i in range(n_slots):
                setattr(self, f"norm_{i}", nn.LayerNorm(channel))
        self.dropout = nn.Dropout(p=dropout)
        self.ret_feat = ret_feat

    def forward(self, features, cls_tokens):
        """Fuse per-layer features and return logits (and optionally the feature).

        Args:
            features: iterable of per-layer feature maps, each ``b x c x s x s``.
            cls_tokens: iterable of per-layer cls tokens, each ``b x c``.

        Returns:
            ``b x num_labels`` logits, or ``(logits, fused)`` when ``ret_feat``.
        """
        pooled = []
        for fmap, token in zip(features, cls_tokens):
            pooled.extend(
                (
                    fmap.mean([2, 3]),  # spatial average, b x c
                    fmap.max(-1).values.max(-1).values,  # spatial max, b x c
                    token,
                )
            )
        if self.norm:
            pooled = [getattr(self, f"norm_{i}")(vec) for i, vec in enumerate(pooled)]
        stacked = torch.stack(pooled, dim=1)  # b x 3N x c
        mix = self.weights.softmax(dim=1)
        fused = (stacked * mix).sum(1)  # b x c
        logits = self.final_fc(self.dropout(fused))  # b x num_labels
        return (logits, fused) if self.ret_feat else logits
class FACTransformer(nn.Module):
    """A face attribute classification transformer leveraging multiple cls_tokens.

    Composes a ``backbone`` and a ``head``: ``forward(image)`` returns
    ``head(*backbone(image))``, so the backbone must return a tuple (or other
    iterable) of positional arguments for the head.

    Args:
        backbone: module called as ``backbone(image)``.
        head: module receiving the unpacked backbone outputs.
        device: device to move the whole module to at construction time.
            Defaults to ``"cuda"`` (the previous hard-coded behaviour); pass
            ``None`` to leave parameters where they are, e.g. for CPU-only use.
            Parameters are always cast to float32 afterwards.

    Forward args:
        image (torch.Tensor): Float32 tensor with shape [b, 3, h, w],
            normalized to [0, 1].

    Returns:
        Whatever ``head`` returns. For ``VITClassificationHeadV0`` this is a
        float32 logits tensor of shape [b, n_classes], or
        ``(logits, fused_feature)`` when the head was built with
        ``ret_feat=True``.
    """

    def __init__(self, backbone: nn.Module, head: nn.Module, device="cuda"):
        super().__init__()
        self.backbone = backbone
        self.head = head
        # Device placement used to be unconditional (.cuda()), which made the
        # model impossible to construct on machines without a GPU.
        if device is not None:
            self.to(device)
        self.float()

    def forward(self, image):
        logits = self.head(*self.backbone(image))
        return logits
def add_method(obj, name, method):
    """Bind ``method`` to the single instance ``obj`` as attribute ``name``.

    ``method`` is wrapped in a bound method (so it receives ``obj`` as its
    first argument) and stored on the instance, shadowing any class-level
    attribute of the same name for this object only.
    """
    from types import MethodType

    bound = MethodType(method, obj)
    setattr(obj, name, bound)
def get_clip_encode_func(layers):
    """Build a replacement ``forward(self, x)`` that exposes multi-layer ViT outputs.

    The returned function is meant to be bound onto a CLIP-style visual
    transformer instance (see ``add_method``). It reads these attributes of
    ``self``: ``conv1``, ``class_embedding``, ``positional_embedding``,
    ``ln_pre``, ``transformer.resblocks``, ``transformer.use_checkpoint``.
    Optionally it also reads ``extra_tokens`` (a list of attribute names of
    additional learned tokens, inserted after the class token) and
    ``ret_special``.

    Transformer blocks after ``max(layers)`` are not run. When training with
    ``use_checkpoint`` set, each block runs under activation checkpointing.

    Args:
        layers: indices of the blocks whose outputs are returned, in order.

    Returns:
        ``func(self, x)`` returning ``(features, cls_tokens)``, or
        ``(features, cls_tokens, extra_token_feats)`` when ``self.ret_special``
        is truthy, where
          * ``features``: ``[len(layers), B, width, W, W]`` patch-token maps,
          * ``cls_tokens``: ``[len(layers), B, width]``,
          * ``extra_token_feats``: dict name -> ``[len(layers), B, width]``.
        The patch grid is assumed square: ``W = int(sqrt(num_patch_tokens))``.
    """

    def func(self, x):
        import math

        # Patchify: conv1 output is flattened over the spatial grid.
        patches = self.conv1(x)  # [B, width, grid, grid]
        patches = patches.reshape(patches.shape[0], patches.shape[1], -1)
        token_names = getattr(self, "extra_tokens", [])
        patches = patches.permute(0, 2, 1)  # [B, grid ** 2, width]

        batch, width = patches.shape[0], patches.shape[-1]
        dtype, device = patches.dtype, patches.device

        def tile(embedding):
            # Broadcast a learned token over the batch by adding zeros.
            return embedding.to(dtype) + torch.zeros(
                batch, 1, width, dtype=dtype, device=device
            )

        prefix = [tile(self.class_embedding)]
        prefix.extend(tile(getattr(self, n)) for n in token_names)
        # Token order: [cls, *extra_tokens, *patches]
        seq = torch.cat(prefix + [patches], dim=1)
        seq = seq + self.positional_embedding.to(seq.dtype)
        seq = self.ln_pre(seq)
        seq = seq.permute(1, 0, 2)  # NLD -> LND

        last = max(layers)
        ckpt_enabled = self.transformer.use_checkpoint
        run_ckpt = self.training and ckpt_enabled

        hidden = []
        for _, block in zip(range(last + 1), self.transformer.resblocks):
            seq = checkpoint.checkpoint(block, seq) if run_ckpt else block(seq)
            hidden.append(seq)

        # [n_run_blocks, B, tokens, width]
        stacked = torch.stack(hidden).permute(0, 2, 1, 3)
        cls_tokens = stacked[layers, :, 0, :]
        extra_token_feats = {
            n: stacked[layers, :, pos, :]
            for pos, n in enumerate(token_names, start=1)
        }

        n_extra = len(token_names)
        _, batch_sz, n_tokens, chans = stacked.shape
        side = int(math.sqrt(n_tokens - 1 - n_extra))
        features = (
            stacked[layers, :, 1 + n_extra :, :]
            .reshape(len(layers), batch_sz, side, side, chans)
            .permute(0, 1, 4, 2, 3)
        )

        if getattr(self, "ret_special", False):
            return features, cls_tokens, extra_token_feats
        return features, cls_tokens

    return func
def farl_classification(num_classes=2, layers=None):
    """Build a FaRL ViT-B/16 (224px) face classification model.

    Constructs a 12-block, width-768 ``VisualTransformer``. It moves it to
    CUDA, removes ``proj`` and ``ln_post``, and replaces its ``forward`` so
    it returns per-layer patch features and cls tokens. Those feed a
    ``VITClassificationHeadV0`` with per-slot LayerNorm.

    Args:
        num_classes: number of output logits.
        layers: indices of transformer blocks whose outputs feed the head.
            Defaults to all 12 blocks. The sequence is copied, so mutating
            it after the call does not affect the built model.

    Returns:
        A ``FACTransformer`` wrapping the backbone and head.

    Raises:
        ValueError: if ``layers`` is empty or its largest index is not in
            ``[0, 11]``. Such inputs always failed at the first forward pass.
    """
    num_blocks = 12
    width = 768
    # Use a None sentinel instead of a shared mutable default list, and copy
    # the caller's sequence because the forward closure keeps a reference.
    layers = list(range(num_blocks)) if layers is None else list(layers)
    if not layers or not 0 <= max(layers) < num_blocks:
        raise ValueError(
            f"layers must be non-empty with max index in [0, {num_blocks - 1}], "
            f"got {layers}"
        )

    model = VisualTransformer(
        input_resolution=224,
        patch_size=16,
        width=width,
        layers=num_blocks,
        heads=12,
        output_dim=512,
    )
    channel = width  # head consumes raw block outputs, so channel == width
    model = model.cuda()
    # The projection / final norm are unused: features are taken from
    # intermediate blocks by the replacement forward below.
    del model.proj
    del model.ln_post
    add_method(model, "forward", get_clip_encode_func(layers))
    head = VITClassificationHeadV0(
        num_features=len(layers), channel=channel, num_labels=num_classes, norm=True
    )
    model = FACTransformer(model, head)
    return model