import time
from functools import partial, reduce

import torch
import torch.nn as nn
from torch.nn.functional import adaptive_avg_pool3d

from .conv_backbone import convnext_3d_small, convnext_3d_tiny, convnextv2_3d_pico, convnextv2_3d_femto, clip_vitL14
from .head import IQAHead, VARHead, VQAHead
from .swin_backbone import SwinTransformer2D as ImageBackbone
from .swin_backbone import SwinTransformer3D as VideoBackbone
from .swin_backbone import swin_3d_small, swin_3d_tiny


class BaseEvaluator(nn.Module):
    def __init__(
        self, backbone=dict(), vqa_head=dict(),
    ):
        super().__init__()
        self.backbone = VideoBackbone(**backbone)
        self.vqa_head = VQAHead(**vqa_head)

    def forward(self, vclip, inference=True, **kwargs):
        if inference:
            self.eval()
            with torch.no_grad():
                feat = self.backbone(vclip)
                score = self.vqa_head(feat)
            self.train()
            return score
        else:
            feat = self.backbone(vclip)
            score = self.vqa_head(feat)
            return score

    def forward_with_attention(self, vclip):
        self.eval()
        with torch.no_grad():
            feat, avg_attns = self.backbone(vclip, require_attn=True)
            score = self.vqa_head(feat)
            return score, avg_attns


class COVER(nn.Module):
    def __init__(
        self,
        backbone_size="divided",
        backbone_preserve_keys="fragments,resize",
        multi=False,
        layer=-1,
        backbone=dict(
            resize={"window_size": (4, 4, 4)}, fragments={"window_size": (4, 4, 4)}
        ),
        divide_head=False,
        vqa_head=dict(in_channels=768),
        var=False,
    ):
        self.backbone_preserve_keys = backbone_preserve_keys.split(",")
        self.multi = multi
        self.layer = layer
        super().__init__()
        for key, hypers in backbone.items():
            if key not in self.backbone_preserve_keys:
                continue
            if backbone_size == "divided":
                t_backbone_size = hypers["type"]
            else:
                t_backbone_size = backbone_size
            if t_backbone_size == "swin_tiny":
                b = swin_3d_tiny(**backbone[key])
            elif t_backbone_size == "swin_tiny_grpb":
                # to reproduce fast-vqa
                b = VideoBackbone()
            elif t_backbone_size == "swin_tiny_grpb_m":
                # to reproduce fast-vqa-m
                b = VideoBackbone(window_size=(4, 4, 4), frag_biases=[0, 0, 0, 0])
            elif t_backbone_size == "swin_small":
                b = swin_3d_small(**backbone[key])
            elif t_backbone_size == "conv_tiny":
                b = convnext_3d_tiny(pretrained=True)
            elif t_backbone_size == "conv_small":
                b = convnext_3d_small(pretrained=True)
            elif t_backbone_size == "conv_femto":
                b = convnextv2_3d_femto(pretrained=True)
            elif t_backbone_size == "conv_pico":
                b = convnextv2_3d_pico(pretrained=True)
            elif t_backbone_size == "xclip":
                raise NotImplementedError
            elif t_backbone_size == "clip_iqa+":
                b = clip_vitL14(pretrained=True)
            else:
                raise NotImplementedError
            print("Setting backbone:", key + "_backbone")
            setattr(self, key + "_backbone", b)
        if divide_head:
            for key in backbone:
                pre_pool = False  # if key == "technical" else True
                if key not in self.backbone_preserve_keys:
                    continue
                b = VQAHead(pre_pool=pre_pool, **vqa_head)
                print("Setting head:", key + "_head")
                setattr(self, key + "_head", b)
        else:
            if var:
                self.vqa_head = VARHead(**vqa_head)
            else:
                self.vqa_head = VQAHead(**vqa_head)
        # Cross-gating blocks that modulate the technical and aesthetic
        # features with the semantic (CLIP) features.
        self.smtc_gate_tech = CrossGatingBlock(x_features=768, num_channels=768, block_size=1,
                                               grid_size=1, upsample_y=False, dropout_rate=0.1,
                                               use_bias=True, use_global_mlp=False)
        self.smtc_gate_aesc = CrossGatingBlock(x_features=768, num_channels=768, block_size=1,
                                               grid_size=1, upsample_y=False, dropout_rate=0.1,
                                               use_bias=True, use_global_mlp=False)

    def forward(
        self,
        vclips,
        inference=True,
        return_pooled_feats=False,
        return_raw_feats=False,
        reduce_scores=False,
        pooled=False,
        **kwargs
    ):
        assert not (return_pooled_feats and return_raw_feats), "Please only choose one kind of features to return"
        if inference:
            self.eval()
            with torch.no_grad():
                scores = []
                feats = {}
                # NOTE: 'semantic' must be visited before 'technical' and
                # 'aesthetic', since their features are cross-gated with
                # feats['semantic'] below.
                for key in vclips:
                    if key == 'technical' or key == 'aesthetic':
                        feat = getattr(self, key.split("_")[0] + "_backbone")(
                            vclips[key], multi=self.multi, layer=self.layer, **kwargs
                        )
                        if key == 'technical':
                            feat_gated = self.smtc_gate_tech(feats['semantic'], feat)
                        elif key == 'aesthetic':
                            feat_gated = self.smtc_gate_aesc(feats['semantic'], feat)
                        if hasattr(self, key.split("_")[0] + "_head"):
                            scores += [getattr(self, key.split("_")[0] + "_head")(feat_gated)]
                        else:
                            scores += [getattr(self, "vqa_head")(feat_gated)]
                    elif key == 'semantic':
                        x = vclips[key].squeeze()  # assumes batch size 1 at inference
                        x = x.permute(1, 0, 2, 3)  # (c, t, h, w) -> (t, c, h, w)
                        feat, _ = getattr(self, key.split("_")[0] + "_backbone")(
                            x, multi=self.multi, layer=self.layer, **kwargs
                        )
                        # per-frame image features from the clipiqa+ ViT-L/14
                        # backbone, shape (t, c) -> (16, 768)
                        feat = feat.permute(1, 0).contiguous()  # (t, c) -> (c, t), i.e. (768, 16)
                        feat = feat.unsqueeze(-1).unsqueeze(-1)  # (c, t, 1, 1) -> (768, 16, 1, 1)
                        feat_expand = feat.expand(-1, -1, 7, 7)  # (c, t, h, w) -> (768, 16, 7, 7)
                        feat_expand = feat_expand.unsqueeze(0)  # (b, c, t, h, w) -> (1, 768, 16, 7, 7)
                        if hasattr(self, key.split("_")[0] + "_head"):
                            score = getattr(self, key.split("_")[0] + "_head")(feat_expand)
                        else:
                            score = getattr(self, "vqa_head")(feat_expand)
                        scores += [score]
                        feats[key] = feat_expand
                if reduce_scores:
                    if len(scores) > 1:
                        scores = reduce(lambda x, y: x + y, scores)
                    else:
                        scores = scores[0]
                    if pooled:
                        scores = torch.mean(scores, (1, 2, 3, 4))
            self.train()
            if return_pooled_feats or return_raw_feats:
                return scores, feats
            return scores
        else:
            self.train()
            scores = []
            feats = {}
            # Same ordering requirement as above: 'semantic' must come first.
            for key in vclips:
                if key == 'technical' or key == 'aesthetic':
                    feat = getattr(self, key.split("_")[0] + "_backbone")(
                        vclips[key], multi=self.multi, layer=self.layer, **kwargs
                    )
                    if key == 'technical':
                        feat_gated = self.smtc_gate_tech(feats['semantic'], feat)
                    elif key == 'aesthetic':
                        feat_gated = self.smtc_gate_aesc(feats['semantic'], feat)
                    if hasattr(self, key.split("_")[0] + "_head"):
                        scores += [getattr(self, key.split("_")[0] + "_head")(feat_gated)]
                    else:
                        scores += [getattr(self, "vqa_head")(feat_gated)]
                    feats[key] = feat
                elif key == 'semantic':
                    # The CLIP backbone consumes one clip at a time, so the
                    # batch is processed sample by sample and restacked.
                    scores_semantic_list = []
                    feats_semantic_list = []
                    for batch_idx in range(vclips[key].shape[0]):
                        x = vclips[key][batch_idx].squeeze()
                        x = x.permute(1, 0, 2, 3)  # (c, t, h, w) -> (t, c, h, w)
                        feat, _ = getattr(self, key.split("_")[0] + "_backbone")(
                            x, multi=self.multi, layer=self.layer, **kwargs
                        )
                        # per-frame image features from the clipiqa+ ViT-L/14
                        # backbone, shape (t, c) -> (16, 768)
                        feat = feat.permute(1, 0).contiguous()  # (t, c) -> (c, t), i.e. (768, 16)
                        feat = feat.unsqueeze(-1).unsqueeze(-1)  # (c, t, 1, 1) -> (768, 16, 1, 1)
                        feat_expand = feat.expand(-1, -1, 7, 7)  # (c, t, h, w) -> (768, 16, 7, 7)
                        feats_semantic_list.append(feat_expand)
                        feat_expand = feat_expand.unsqueeze(0)  # (b, c, t, h, w) -> (1, 768, 16, 7, 7)
                        if hasattr(self, key.split("_")[0] + "_head"):
                            score = getattr(self, key.split("_")[0] + "_head")(feat_expand)
                        else:
                            score = getattr(self, "vqa_head")(feat_expand)
                        scores_semantic_list.append(score.squeeze(0))
                    scores_semantic_tensor = torch.stack(scores_semantic_list)
                    feats[key] = torch.stack(feats_semantic_list)
                    scores += [scores_semantic_tensor]
                if return_pooled_feats:
                    feats[key] = feat.mean((-3, -2, -1))
            if reduce_scores:
                if len(scores) > 1:
                    scores = reduce(lambda x, y: x + y, scores)
                else:
                    scores = scores[0]
                if pooled:
                    scores = torch.mean(scores, (1, 2, 3, 4))
            if return_pooled_feats:
                return scores, feats
            return scores

    def forward_head(
        self,
        feats,
        inference=True,
        reduce_scores=False,
        pooled=False,
        **kwargs
    ):
        # Scores precomputed per-branch features without re-running the backbones.
        if inference:
            self.eval()
            with torch.no_grad():
                scores = []
                for key in feats:
                    feat = feats[key]
                    if hasattr(self, key.split("_")[0] + "_head"):
                        scores += [getattr(self, key.split("_")[0] + "_head")(feat)]
                    else:
                        scores += [getattr(self, "vqa_head")(feat)]
                if reduce_scores:
                    if len(scores) > 1:
                        scores = reduce(lambda x, y: x + y, scores)
                    else:
                        scores = scores[0]
                    if pooled:
                        scores = torch.mean(scores, (1, 2, 3, 4))
            self.train()
            return scores
        else:
            self.train()
            scores = []
            for key in feats:
                feat = feats[key]
                if hasattr(self, key.split("_")[0] + "_head"):
                    scores += [getattr(self, key.split("_")[0] + "_head")(feat)]
                else:
                    scores += [getattr(self, "vqa_head")(feat)]
            if reduce_scores:
                if len(scores) > 1:
                    scores = reduce(lambda x, y: x + y, scores)
                else:
                    scores = scores[0]
                if pooled:
                    scores = torch.mean(scores, (1, 2, 3, 4))
            return scores
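

# Hedged usage sketch for COVER (illustrative, not part of the model code).
# The view shapes and frame counts below are assumptions: each backbone must
# emit a feature grid matching the expanded semantic features (t, 7, 7), so
# the real numbers come from the dataset/sampler configuration.
def _demo_cover():
    model = COVER(
        backbone=dict(
            semantic={"type": "clip_iqa+"},
            technical={"type": "swin_tiny_grpb"},
            aesthetic={"type": "conv_tiny"},
        ),
        backbone_preserve_keys="semantic,technical,aesthetic",
        divide_head=True,
        vqa_head=dict(in_channels=768),
    )
    # dict order matters: 'semantic' is computed first and gated into the
    # other two branches
    vclips = {
        "semantic": torch.randn(1, 3, 16, 224, 224),   # CLIP branch (assumed 16 frames)
        "technical": torch.randn(1, 3, 32, 224, 224),  # fragment view (assumed 32 frames)
        "aesthetic": torch.randn(1, 3, 32, 224, 224),  # resized view (assumed 32 frames)
    }
    return model(vclips, inference=True)  # list of per-branch score maps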


class MinimumCOVER(nn.Module):
    def __init__(self):
        super().__init__()
        self.technical_backbone = VideoBackbone()
        self.aesthetic_backbone = convnext_3d_tiny(pretrained=True)
        self.technical_head = VQAHead(pre_pool=False, in_channels=768)
        self.aesthetic_head = VQAHead(pre_pool=False, in_channels=768)

    def forward(self, aesthetic_view, technical_view):
        self.eval()
        with torch.no_grad():
            aesthetic_score = self.aesthetic_head(self.aesthetic_backbone(aesthetic_view))
            technical_score = self.technical_head(self.technical_backbone(technical_view))
            aesthetic_score_pooled = torch.mean(aesthetic_score, (1, 2, 3, 4))
            technical_score_pooled = torch.mean(technical_score, (1, 2, 3, 4))
        return [aesthetic_score_pooled, technical_score_pooled]
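

# Hedged usage sketch for MinimumCOVER (illustrative): it takes the two views
# directly and returns pooled per-video scores. Input shapes are assumptions.
def _demo_minimum_cover():
    model = MinimumCOVER()
    aesthetic_view = torch.randn(1, 3, 32, 224, 224)
    technical_view = torch.randn(1, 3, 32, 224, 224)
    aesthetic_score, technical_score = model(aesthetic_view, technical_view)
    return aesthetic_score, technical_score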


class BaseImageEvaluator(nn.Module):
    def __init__(
        self, backbone=dict(), iqa_head=dict(),
    ):
        super().__init__()
        self.backbone = ImageBackbone(**backbone)
        self.iqa_head = IQAHead(**iqa_head)

    def forward(self, image, inference=True, **kwargs):
        if inference:
            self.eval()
            with torch.no_grad():
                feat = self.backbone(image)
                score = self.iqa_head(feat)
            self.train()
            return score
        else:
            feat = self.backbone(image)
            score = self.iqa_head(feat)
            return score

    def forward_with_attention(self, image):
        self.eval()
        with torch.no_grad():
            feat, avg_attns = self.backbone(image, require_attn=True)
            score = self.iqa_head(feat)
            return score, avg_attns
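

# Hedged usage sketch for BaseImageEvaluator (illustrative): same flow as
# BaseEvaluator, but over a 4D image batch. The 224x224 input is an assumption.
def _demo_base_image_evaluator():
    model = BaseImageEvaluator(backbone=dict(), iqa_head=dict())
    image = torch.randn(1, 3, 224, 224)  # (batch, channels, height, width)
    return model(image, inference=True)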


class CrossGatingBlock(nn.Module):  # input shape: n, c, t, h, w
    """Cross-gating MLP block: the x stream produces a gate that modulates the
    y stream. Several constructor arguments (block_size, grid_size, cin_y,
    upsample_y, use_global_mlp) are kept for interface compatibility but are
    unused in this simplified variant."""

    def __init__(self, x_features, num_channels, block_size, grid_size, cin_y=0, upsample_y=True,
                 use_bias=True, use_global_mlp=True, dropout_rate=0):
        super().__init__()
        self.cin_y = cin_y
        self.x_features = x_features
        self.num_channels = num_channels
        self.block_size = block_size
        self.grid_size = grid_size
        self.upsample_y = upsample_y
        self.use_bias = use_bias
        self.use_global_mlp = use_global_mlp
        self.drop = dropout_rate
        self.Conv_0 = nn.Linear(self.x_features, self.num_channels)
        self.Conv_1 = nn.Linear(self.num_channels, self.num_channels)
        self.in_project_x = nn.Linear(self.num_channels, self.num_channels, bias=self.use_bias)
        self.gelu1 = nn.GELU(approximate='tanh')
        self.out_project_y = nn.Linear(self.num_channels, self.num_channels, bias=self.use_bias)
        self.dropout1 = nn.Dropout(self.drop)

    def forward(self, x, y):  # n, c, t, h, w
        # x drives the gate; y is the signal being gated.
        assert y.shape == x.shape
        x = x.permute(0, 2, 3, 4, 1).contiguous()  # n, t, h, w, c
        y = y.permute(0, 2, 3, 4, 1).contiguous()  # n, t, h, w, c
        x = self.Conv_0(x)
        y = self.Conv_1(y)
        shortcut_y = y
        x = self.in_project_x(x)
        gx = self.gelu1(x)
        # Apply cross gating: modulate y by the gate computed from x, then
        # project and add the residual.
        y = y * gx
        y = self.out_project_y(y)
        y = self.dropout1(y)
        y = y + shortcut_y  # y = gate(x) * y + y
        return y.permute(0, 4, 1, 2, 3).contiguous()  # n, c, t, h, w
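

# Hedged shape check for CrossGatingBlock (illustrative). Both inputs share
# the (n, c, t, h, w) layout and the output keeps y's shape; the 16x7x7 grid
# below mirrors the expanded semantic features and is an assumption.
def _demo_cross_gating_block():
    gate = CrossGatingBlock(x_features=768, num_channels=768, block_size=1,
                            grid_size=1, upsample_y=False, dropout_rate=0.1,
                            use_bias=True, use_global_mlp=False)
    x = torch.randn(1, 768, 16, 7, 7)  # semantic features (drive the gate)
    y = torch.randn(1, 768, 16, 7, 7)  # technical/aesthetic features to gate
    out = gate(x, y)
    assert out.shape == y.shape
    return out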