import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
# from util import util
from lib.models.networks.audio_network import ResNetSE, SEBasicBlock
from lib.models.networks.FAN_feature_extractor import FAN_use
from lib.models.networks.vision_network import ResNeXt50
from torchvision.models.vgg import vgg19_bn
from lib.models.networks.swin_transformer import SwinTransformer
from lib.models.networks.wavlm.wavlm import WavLM, WavLMConfig


class ResSEAudioEncoder(nn.Module):
    """SE-ResNet audio encoder.

    Wraps a ResNetSE backbone over mel-spectrogram clips and optionally
    projects the backbone feature into a "mouth" sub-embedding whose width
    leaves room for the pose (and eye) dimensions, so that the concatenated
    non-identity vector totals 512 dims.
    """

    def __init__(self, opt, nOut=2048, n_mel_T=None):
        """
        Args:
            opt: config object; reads `model.net_nonidentity.{pose_dim,eye_dim}`,
                `model.net_audio.n_mel_T` and `audio_only`.
            nOut: output dimensionality of the ResNetSE backbone.
            n_mel_T: mel time-span parameter; when None it is taken from the
                config (the audio-identity setting).
        """
        super(ResSEAudioEncoder, self).__init__()
        self.nOut = nOut
        self.opt = opt
        pose_dim = self.opt.model.net_nonidentity.pose_dim
        eye_dim = self.opt.model.net_nonidentity.eye_dim
        # Number of filters per ResNetSE stage.
        num_filters = [32, 64, 128, 256]
        if n_mel_T is None:  # used when the encoder serves as audio identity
            n_mel_T = opt.model.net_audio.n_mel_T
        self.model = ResNetSE(
            SEBasicBlock, [3, 4, 6, 3], num_filters, self.nOut, n_mel_T=n_mel_T
        )
        # Mouth embedding width depends on whether pose/eye dims must be
        # reserved in the final 512-d non-identity vector.
        if opt.audio_only:
            self.mouth_embed = nn.Sequential(
                nn.ReLU(), nn.Linear(512, 512 - pose_dim)
            )
        else:
            self.mouth_embed = nn.Sequential(
                nn.ReLU(), nn.Linear(512, 512 - pose_dim - eye_dim)
            )

    def forward(self, x, _type=None):
        """Encode audio clips.

        Args:
            x: 4-D tensor (N, c, f, t) or 5-D tensor (bz, clip_len, c, f, t);
                5-D input is flattened to (bz*clip_len, c, f, t).
            _type: when "to_mouth_embed", also returns the mouth projection.

        Returns:
            Backbone feature, or (feature, mouth_embed) when
            _type == "to_mouth_embed".
        """
        input_size = x.size()
        if len(input_size) == 5:
            bz, clip_len, c, f, t = input_size
            x = x.view(bz * clip_len, c, f, t)
        out = self.model(x)
        if _type == "to_mouth_embed":
            out = out.view(-1, out.shape[-1])
            mouth_embed = self.mouth_embed(out)
            return out, mouth_embed
        return out


class ResSESyncEncoder(ResSEAudioEncoder):
    """ResSEAudioEncoder variant used for sync: 512-d output, n_mel_T=1."""

    def __init__(self, opt):
        super(ResSESyncEncoder, self).__init__(opt, nOut=512, n_mel_T=1)


class ResNeXtEncoder(ResNeXt50):
    """Thin alias over the project's ResNeXt50 vision backbone."""

    def __init__(self, opt):
        super(ResNeXtEncoder, self).__init__(opt)


class VGGEncoder(nn.Module):
    """VGG19-BN classifier head sized by `opt.data.num_classes`."""

    def __init__(self, opt):
        super(VGGEncoder, self).__init__()
        self.model = vgg19_bn(num_classes=opt.data.num_classes)

    def forward(self, x):
        return self.model(x)


class FanEncoder(nn.Module):
    """FAN-based face encoder with sub-space heads.

    Maps a 512-d FAN feature into head-pose, eye and emotion embeddings via
    per-subspace MLPs (Linear -> ReLU -> BatchNorm1d -> Linear) followed by
    small projection heads.
    """

    def __init__(self, opt):
        super(FanEncoder, self).__init__()
        self.opt = opt
        pose_dim = self.opt.model.net_nonidentity.pose_dim
        eye_dim = self.opt.model.net_nonidentity.eye_dim
        self.model = FAN_use()
        # Head-pose subspace.
        self.to_headpose = nn.Sequential(
            nn.Linear(512, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Linear(512, 512)
        )
        self.headpose_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, pose_dim))
        # Eye subspace.
        self.to_eye = nn.Sequential(
            nn.Linear(512, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Linear(512, 512)
        )
        self.eye_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, eye_dim))
        # Emotion subspace (30-d embedding).
        self.to_emo = nn.Sequential(
            nn.Linear(512, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Linear(512, 512)
        )
        self.emo_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, 30))

    def forward_feature(self, x):
        """Raw FAN backbone feature."""
        net = self.model(x)
        return net

    def forward(self, x, _type="feature"):
        """Dispatch on `_type`:

        - "feature": raw backbone feature.
        - "feature_embed": (headpose_emb, eye_embed, emo_embed) tuple.
        - "to_headpose": head-pose embedding only.
        """
        if _type == "feature":
            return self.forward_feature(x)
        elif _type == "feature_embed":
            x = self.model(x)
            headpose_feat = self.to_headpose(x)
            headpose_emb = self.headpose_embed(headpose_feat)
            eye_feat = self.to_eye(x)
            eye_embed = self.eye_embed(eye_feat)
            emo_feat = self.to_emo(x)
            emo_embed = self.emo_embed(emo_feat)
            return headpose_emb, eye_embed, emo_embed
        elif _type == "to_headpose":
            x = self.model(x)
            headpose_feat = self.to_headpose(x)
            headpose_emb = self.headpose_embed(headpose_feat)
            return headpose_emb


class WavlmEncoder(nn.Module):
    """WavLM-based audio encoder.

    Extracts frame-level WavLM features, averages a centered temporal window,
    and projects into a mouth feature plus a mouth embedding sized to leave
    room for the pose dims.
    """

    def __init__(self, opt):
        super(WavlmEncoder, self).__init__()
        self.input_wins = opt.audio.num_frames_per_clip
        # Centered window over the WavLM frame axis; selects frames
        # [s, e) around the clip middle (5 video frames * 2 audio frames
        # per video frame, presumably — TODO confirm against data pipeline).
        self.s = (self.input_wins - 5) // 2 * 2
        self.e = self.s + 5 * 2 - 1
        wavlm_checkpoint = torch.load(opt.model.net_audio.official_pretrain)
        wavlm_cfg = WavLMConfig(wavlm_checkpoint['cfg'])
        pose_dim = opt.model.net_nonidentity.pose_dim
        self.model = WavLM(wavlm_cfg)
        self.mouth_feat = nn.Sequential(
            nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 512)
        )
        self.mouth_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, 512 - pose_dim))

    def forward(self, x):
        """Return (mouth feature, mouth embedding) for a waveform batch."""
        feature = self.model.extract_features(x)[0]
        feature = self.mouth_feat(feature[:, self.s:self.e].mean(1))
        audio_feat = self.mouth_embed(feature)
        return feature, audio_feat


class SwinEncoder(nn.Module):
    """Swin Transformer image encoder configured from
    `cfg.model.net_nonidentity`; returns the pooled backbone feature
    (num_classes=0 disables the classifier head)."""

    def __init__(self, cfg):
        super(SwinEncoder, self).__init__()
        self.encoder = SwinTransformer(
            num_classes=0,
            img_size=cfg.model.net_nonidentity.img_size,
            patch_size=cfg.model.net_nonidentity.patch_size,
            in_chans=cfg.model.net_nonidentity.in_chans,
            embed_dim=cfg.model.net_nonidentity.embed_dim,
            depths=cfg.model.net_nonidentity.depths,
            num_heads=cfg.model.net_nonidentity.num_heads,
            window_size=cfg.model.net_nonidentity.window_size,
            mlp_ratio=cfg.model.net_nonidentity.mlp_ratio,
            qkv_bias=cfg.model.net_nonidentity.qkv_bias,
            # Falsy config value -> default scaling (None); otherwise 0.1.
            qk_scale=None if not cfg.model.net_nonidentity.qk_scale else 0.1,
            drop_rate=cfg.model.net_nonidentity.drop_rate,
            drop_path_rate=cfg.model.net_nonidentity.drop_path_rate,
            ape=cfg.model.net_nonidentity.ape,
            patch_norm=cfg.model.net_nonidentity.patch_norm,
            use_checkpoint=False,
        )

    def forward(self, img):
        feature = self.encoder(img)
        return feature


class ResEncoder(nn.Module):
    """ResNet-50 image encoder with a 512-d output head.

    NOTE(review): `resnet50` is not imported anywhere in this file, so
    instantiating this class raises NameError as written. The
    `include_top` kwarg does not match torchvision's resnet50 signature,
    so this presumably refers to a project-local implementation — confirm
    and add the missing import.
    """

    def __init__(self, opt):
        super(ResEncoder, self).__init__()
        self.opt = opt
        self.model = resnet50(num_classes=512, include_top=True)

    def forward(self, x):
        feature = self.model(x)
        return feature


class FansEncoder(nn.Module):
    """FAN image encoder emitting a 768-d feature."""

    def __init__(self, cfg):
        super(FansEncoder, self).__init__()
        self.encoder = FAN_use(out_dim=768)

    def forward(self, img):
        feature = self.encoder(img)
        return feature