import gradio as gr

import argparse
import cv2
import imageio
import math
from math import ceil

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision  # needed for the resnet101 backbone branch below
from torch.autograd import Function

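# Multi-scale Temporal Relation Network (TRN) head. The largest scale fuses all T frames with
# a small MLP; each smaller scale k = T-1, ..., 2 sums the fused features of up to
# `subsample_num` evenly-sampled k-frame subsets. The output stacks one feature per scale
# along dim 1, giving (batch, T-1, num_bottleneck).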
class RelationModuleMultiScale(torch.nn.Module):

    def __init__(self, img_feature_dim, num_bottleneck, num_frames):
        super(RelationModuleMultiScale, self).__init__()
        self.subsample_num = 3
        self.img_feature_dim = img_feature_dim
        self.scales = [i for i in range(num_frames, 1, -1)]

        self.relations_scales = []
        self.subsample_scales = []
        for scale in self.scales:
            relations_scale = self.return_relationset(num_frames, scale)
            self.relations_scales.append(relations_scale)
            self.subsample_scales.append(min(self.subsample_num, len(relations_scale)))

        self.num_frames = num_frames
        self.fc_fusion_scales = nn.ModuleList()
        for i in range(len(self.scales)):
            scale = self.scales[i]
            fc_fusion = nn.Sequential(nn.ReLU(), nn.Linear(scale * self.img_feature_dim, num_bottleneck), nn.ReLU())
            self.fc_fusion_scales += [fc_fusion]

    def forward(self, input):
        act_scale_1 = input[:, self.relations_scales[0][0], :]
        act_scale_1 = act_scale_1.view(act_scale_1.size(0), self.scales[0] * self.img_feature_dim)
        act_scale_1 = self.fc_fusion_scales[0](act_scale_1)
        act_scale_1 = act_scale_1.unsqueeze(1)
        act_all = act_scale_1.clone()

        for scaleID in range(1, len(self.scales)):
            act_relation_all = torch.zeros_like(act_scale_1)
            num_total_relations = len(self.relations_scales[scaleID])
            num_select_relations = self.subsample_scales[scaleID]
            idx_relations_evensample = [int(ceil(i * num_total_relations / num_select_relations)) for i in range(num_select_relations)]

            for idx in idx_relations_evensample:
                act_relation = input[:, self.relations_scales[scaleID][idx], :]
                act_relation = act_relation.view(act_relation.size(0), self.scales[scaleID] * self.img_feature_dim)
                act_relation = self.fc_fusion_scales[scaleID](act_relation)
                act_relation = act_relation.unsqueeze(1)
                act_relation_all += act_relation

            act_all = torch.cat((act_all, act_relation_all), 1)
        return act_all

    def return_relationset(self, num_frames, num_frames_relation):
        import itertools
        return list(itertools.combinations([i for i in range(num_frames)], num_frames_relation))

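# Demo configuration. parse_args(args=[]) keeps the defaults so this file can run inside
# the Gradio app without any command-line arguments.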
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='Sprite', help='dataset name')
parser.add_argument('--data_root', default='dataset', help='root directory for data')
parser.add_argument('--num_class', type=int, default=15, help='number of action classes')
parser.add_argument('--input_type', default='image', choices=['feature', 'image'], help='type of input')
parser.add_argument('--src', default='domain_1', help='source domain')
parser.add_argument('--tar', default='domain_2', help='target domain')
parser.add_argument('--num_segments', type=int, default=8, help='number of frame segments')
parser.add_argument('--backbone', type=str, default="dcgan", choices=['dcgan', 'resnet101', 'I3Dpretrain', 'I3Dfinetune'], help='backbone')
parser.add_argument('--channels', default=3, type=int, help='input channels for image inputs')
parser.add_argument('--add_fc', default=1, type=int, metavar='M', help='number of additional fc layers (excluding the last fc layer), e.g. 0, 1, 2')
parser.add_argument('--fc_dim', type=int, default=1024, help='dimension of the added fc layers')
parser.add_argument('--frame_aggregation', type=str, default='trn', choices=['rnn', 'trn'], help='aggregation of frame features')
parser.add_argument('--dropout_rate', default=0.5, type=float, help='dropout ratio for frame-level features (default: 0.5)')
parser.add_argument('--f_dim', type=int, default=512, help='dimensionality of the static latent f')
parser.add_argument('--z_dim', type=int, default=512, help='dimensionality of the dynamic latent z_t')
parser.add_argument('--f_rnn_layers', type=int, default=1, help='number of layers (content LSTM)')
parser.add_argument('--use_bn', type=str, default='none', choices=['none', 'shared', 'separated'], help='batch-normalization scheme for the fc layers')
parser.add_argument('--prior_sample', type=str, default='random', choices=['random', 'post'], help='how to sample the prior')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--use_attn', type=str, default='TransAttn', choices=['none', 'TransAttn', 'general'], help='attention mechanism')
parser.add_argument('--data_threads', type=int, default=5, help='number of data loading threads')
opt = parser.parse_args(args=[])

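# Gradient reversal layer (GRL): the identity in the forward pass, multiplies the gradient by
# -beta in the backward pass, so features are trained adversarially against the domain
# classifiers that use it.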
class GradReverse(Function):

    @staticmethod
    def forward(ctx, x, beta):
        ctx.beta = beta
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.neg() * ctx.beta
        return grad_input, None

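# TransferVAE for video domain adaptation: a sequential VAE that splits each video into a
# static latent f (appearance) and per-frame dynamic latents z_1..z_T (motion), with
# adversarial domain classifiers at the frame, temporal-relation, video and latent-f levels.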
class TransferVAE_Video(nn.Module):

    def __init__(self, opt):
        super(TransferVAE_Video, self).__init__()
        self.f_dim = opt.f_dim
        self.z_dim = opt.z_dim
        self.fc_dim = opt.fc_dim
        self.channels = opt.channels
        self.input_type = opt.input_type
        self.frames = opt.num_segments
        self.use_bn = opt.use_bn
        self.frame_aggregation = opt.frame_aggregation
        self.batch_size = opt.batch_size
        self.use_attn = opt.use_attn
        self.dropout_rate = opt.dropout_rate
        self.num_class = opt.num_class
        self.prior_sample = opt.prior_sample

        # frame encoder / decoder
        if self.input_type == 'image':
            import dcgan_64
            self.encoder = dcgan_64.encoder(self.fc_dim, self.channels)
            self.decoder = dcgan_64.decoder_woSkip(self.z_dim + self.f_dim, self.channels)
            self.fc_output_dim = self.fc_dim
        elif self.input_type == 'feature':
            if opt.backbone == 'resnet101':
                model_backbone = getattr(torchvision.models, opt.backbone)(True)
                self.input_dim = model_backbone.fc.in_features
            elif opt.backbone == 'I3Dpretrain':
                self.input_dim = 2048
            elif opt.backbone == 'I3Dfinetune':
                self.input_dim = 2048

            self.add_fc = opt.add_fc
            self.enc_fc_layer1 = nn.Linear(self.input_dim, self.fc_dim)
            self.dec_fc_layer1 = nn.Linear(self.fc_dim, self.input_dim)
            self.fc_output_dim = self.fc_dim

            if self.use_bn == 'shared':
                self.bn_enc_layer1 = nn.BatchNorm1d(self.fc_output_dim)
                self.bn_dec_layer1 = nn.BatchNorm1d(self.input_dim)
            elif self.use_bn == 'separated':
                self.bn_S_enc_layer1 = nn.BatchNorm1d(self.fc_output_dim)
                self.bn_T_enc_layer1 = nn.BatchNorm1d(self.fc_output_dim)
                self.bn_S_dec_layer1 = nn.BatchNorm1d(self.input_dim)
                self.bn_T_dec_layer1 = nn.BatchNorm1d(self.input_dim)

            if self.add_fc > 1:
                self.enc_fc_layer2 = nn.Linear(self.fc_dim, self.fc_dim)
                self.dec_fc_layer2 = nn.Linear(self.fc_dim, self.fc_dim)
                self.fc_output_dim = self.fc_dim

                if self.use_bn == 'shared':
                    self.bn_enc_layer2 = nn.BatchNorm1d(self.fc_output_dim)
                    self.bn_dec_layer2 = nn.BatchNorm1d(self.fc_dim)
                elif self.use_bn == 'separated':
                    self.bn_S_enc_layer2 = nn.BatchNorm1d(self.fc_output_dim)
                    self.bn_T_enc_layer2 = nn.BatchNorm1d(self.fc_output_dim)
                    self.bn_S_dec_layer2 = nn.BatchNorm1d(self.fc_dim)
                    self.bn_T_dec_layer2 = nn.BatchNorm1d(self.fc_dim)

            if self.add_fc > 2:
                self.enc_fc_layer3 = nn.Linear(self.fc_dim, self.fc_dim)
                self.dec_fc_layer3 = nn.Linear(self.fc_dim, self.fc_dim)
                self.fc_output_dim = self.fc_dim

                if self.use_bn == 'shared':
                    self.bn_enc_layer3 = nn.BatchNorm1d(self.fc_output_dim)
                    self.bn_dec_layer3 = nn.BatchNorm1d(self.fc_dim)
                elif self.use_bn == 'separated':
                    self.bn_S_enc_layer3 = nn.BatchNorm1d(self.fc_output_dim)
                    self.bn_T_enc_layer3 = nn.BatchNorm1d(self.fc_output_dim)
                    self.bn_S_dec_layer3 = nn.BatchNorm1d(self.fc_dim)
                    self.bn_T_dec_layer3 = nn.BatchNorm1d(self.fc_dim)

            self.z_2_out = nn.Linear(self.z_dim + self.f_dim, self.fc_output_dim)

        self.relu = nn.LeakyReLU(0.1)
        self.dropout_f = nn.Dropout(p=self.dropout_rate)
        self.dropout_v = nn.Dropout(p=self.dropout_rate)

        self.hidden_dim = opt.z_dim
        self.f_rnn_layers = opt.f_rnn_layers

        # learned prior over the dynamic latents, p(z_t | z_<t)
        self.z_prior_lstm_ly1 = nn.LSTMCell(self.z_dim, self.hidden_dim)
        self.z_prior_lstm_ly2 = nn.LSTMCell(self.hidden_dim, self.hidden_dim)
        self.z_prior_mean = nn.Linear(self.hidden_dim, self.z_dim)
        self.z_prior_logvar = nn.Linear(self.hidden_dim, self.z_dim)

        # posterior networks: bidirectional LSTM for f, an extra RNN for the per-frame z_t
        self.z_lstm = nn.LSTM(self.fc_output_dim, self.hidden_dim, self.f_rnn_layers, bidirectional=True, batch_first=True)
        self.f_mean = nn.Linear(self.hidden_dim * 2, self.f_dim)
        self.f_logvar = nn.Linear(self.hidden_dim * 2, self.f_dim)
        self.z_rnn = nn.RNN(self.hidden_dim * 2, self.hidden_dim, batch_first=True)
        self.z_mean = nn.Linear(self.hidden_dim, self.z_dim)
        self.z_logvar = nn.Linear(self.hidden_dim, self.z_dim)

        # frame-level domain classifier
        self.fc_feature_domain_frame = nn.Linear(self.z_dim, self.z_dim)
        self.fc_classifier_domain_frame = nn.Linear(self.z_dim, 2)

        # frame aggregation: bidirectional LSTM or multi-scale TRN
        if self.frame_aggregation == 'rnn':
            self.bilstm = nn.LSTM(self.z_dim, self.z_dim * 2, self.f_rnn_layers, bidirectional=True, batch_first=True)
            self.feat_aggregated_dim = self.z_dim * 2
        elif self.frame_aggregation == 'trn':
            self.num_bottleneck = 256
            self.TRN = RelationModuleMultiScale(self.z_dim, self.num_bottleneck, self.frames)
            self.bn_trn_S = nn.BatchNorm1d(self.num_bottleneck)
            self.bn_trn_T = nn.BatchNorm1d(self.num_bottleneck)
            self.feat_aggregated_dim = self.num_bottleneck

        # video-level domain classifier
        self.fc_feature_domain_video = nn.Linear(self.feat_aggregated_dim, self.feat_aggregated_dim)
        self.fc_classifier_domain_video = nn.Linear(self.feat_aggregated_dim, 2)

        # one relation-level domain classifier per TRN scale
        if self.frame_aggregation == 'trn':
            self.relation_domain_classifier_all = nn.ModuleList()
            for i in range(self.frames - 1):
                relation_domain_classifier = nn.Sequential(
                    nn.Linear(self.feat_aggregated_dim, self.feat_aggregated_dim),
                    nn.ReLU(),
                    nn.Linear(self.feat_aggregated_dim, 2)
                )
                self.relation_domain_classifier_all += [relation_domain_classifier]

        # action classifier on the aggregated video feature
        self.pred_classifier_video = nn.Linear(self.feat_aggregated_dim, self.num_class)

        # domain classifier on the static latent f
        # (the 'doamin' spelling is kept so the released checkpoint's parameter names still match)
        self.fc_feature_domain_latent = nn.Linear(self.f_dim, self.f_dim)
        self.fc_classifier_doamin_latent = nn.Linear(self.f_dim, 2)

        if self.use_attn == 'general':
            self.attn_layer = nn.Sequential(
                nn.Linear(self.feat_aggregated_dim, self.feat_aggregated_dim),
                nn.Tanh(),
                nn.Linear(self.feat_aggregated_dim, 1)
            )

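    # Adversarial heads: the frame, video and relation classifiers pass their input through
    # GradReverse (strength beta) before predicting the domain; the latent-f classifier
    # further below does not reverse gradients.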
    def domain_classifier_frame(self, feat, beta):
        feat_fc_domain_frame = GradReverse.apply(feat, beta)
        feat_fc_domain_frame = self.fc_feature_domain_frame(feat_fc_domain_frame)
        feat_fc_domain_frame = self.relu(feat_fc_domain_frame)
        pred_fc_domain_frame = self.fc_classifier_domain_frame(feat_fc_domain_frame)
        return pred_fc_domain_frame

    def domain_classifier_video(self, feat_video, beta):
        feat_fc_domain_video = GradReverse.apply(feat_video, beta)
        feat_fc_domain_video = self.fc_feature_domain_video(feat_fc_domain_video)
        feat_fc_domain_video = self.relu(feat_fc_domain_video)
        pred_fc_domain_video = self.fc_classifier_domain_video(feat_fc_domain_video)
        return pred_fc_domain_video

    def domain_classifier_latent(self, f):
        feat_fc_domain_latent = self.fc_feature_domain_latent(f)
        feat_fc_domain_latent = self.relu(feat_fc_domain_latent)
        pred_fc_domain_latent = self.fc_classifier_doamin_latent(feat_fc_domain_latent)
        return pred_fc_domain_latent

    def domain_classifier_relation(self, feat_relation, beta):
        pred_fc_domain_relation_video = None
        for i in range(len(self.relation_domain_classifier_all)):
            feat_relation_single = feat_relation[:, i, :].squeeze(1)
            feat_fc_domain_relation_single = GradReverse.apply(feat_relation_single, beta)
            pred_fc_domain_relation_single = self.relation_domain_classifier_all[i](feat_fc_domain_relation_single)

            if pred_fc_domain_relation_video is None:
                pred_fc_domain_relation_video = pred_fc_domain_relation_single.view(-1, 1, 2)
            else:
                pred_fc_domain_relation_video = torch.cat((pred_fc_domain_relation_video, pred_fc_domain_relation_single.view(-1, 1, 2)), 1)

        pred_fc_domain_relation_video = pred_fc_domain_relation_video.view(-1, 2)
        return pred_fc_domain_relation_video

    def get_trans_attn(self, pred_domain):
        # transferable attention: weight each relation by 1 - entropy of its domain prediction
        softmax = nn.Softmax(dim=1)
        logsoftmax = nn.LogSoftmax(dim=1)
        entropy = torch.sum(-softmax(pred_domain) * logsoftmax(pred_domain), 1)
        weights = 1 - entropy
        return weights

    def get_general_attn(self, feat):
        num_segments = feat.size()[1]
        feat = feat.view(-1, feat.size()[-1])
        weights = self.attn_layer(feat)
        weights = weights.view(-1, num_segments, weights.size()[-1])
        weights = F.softmax(weights, dim=1)
        return weights

    def get_attn_feat_relation(self, feat_fc, pred_domain, num_segments):
        if self.use_attn == 'TransAttn':
            weights_attn = self.get_trans_attn(pred_domain)
        elif self.use_attn == 'general':
            weights_attn = self.get_general_attn(feat_fc)

        weights_attn = weights_attn.view(-1, num_segments - 1, 1).repeat(1, 1, feat_fc.size()[-1])
        feat_fc_attn = (weights_attn + 1) * feat_fc
        return feat_fc_attn, weights_attn[:, :, 0]

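    # Posterior encoder: a bidirectional LSTM over per-frame features yields the static latent
    # f (from the two end states) and, through a second RNN, the per-frame dynamic latents
    # z_1..z_T. Sampling is deterministic here (random_sampling=False returns the means).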
    def encode_and_sample_post(self, x):
        if isinstance(x, list):
            conv_x = self.encoder_frame(x[0])
        else:
            conv_x = self.encoder_frame(x)

        lstm_out, _ = self.z_lstm(conv_x)

        backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
        frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
        lstm_out_f = torch.cat((frontal, backward), dim=1)
        f_mean = self.f_mean(lstm_out_f)
        f_logvar = self.f_logvar(lstm_out_f)
        f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)

        features, _ = self.z_rnn(lstm_out)
        z_mean = self.z_mean(features)
        z_logvar = self.z_logvar(features)
        z_post = self.reparameterize(z_mean, z_logvar, random_sampling=False)

        if isinstance(x, list):
            f_mean_list = [f_mean]
            f_post_list = [f_post]
            for t in range(1, 3, 1):
                conv_x = self.encoder_frame(x[t])
                lstm_out, _ = self.z_lstm(conv_x)

                backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
                frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
                lstm_out_f = torch.cat((frontal, backward), dim=1)
                f_mean = self.f_mean(lstm_out_f)
                f_logvar = self.f_logvar(lstm_out_f)
                f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
                f_mean_list.append(f_mean)
                f_post_list.append(f_post)
            f_mean = f_mean_list
            f_post = f_post_list

        return f_mean, f_logvar, f_post, z_mean, z_logvar, z_post

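    # decoder_frame maps the concatenated latent [z_t, f] back to a frame: through the DCGAN
    # decoder for image inputs, or through the mirrored fc stack for feature inputs.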
    def decoder_frame(self, zf):
        if self.input_type == 'image':
            recon_x = self.decoder(zf)
            return recon_x

        if self.input_type == 'feature':
            zf = self.z_2_out(zf)
            zf = self.relu(zf)

            if self.add_fc > 2:
                zf = self.dec_fc_layer3(zf)
                if self.use_bn == 'shared':
                    zf = self.bn_dec_layer3(zf)
                elif self.use_bn == 'separated':
                    zf_src = self.bn_S_dec_layer3(zf[:self.batch_size, :, :])
                    zf_tar = self.bn_T_dec_layer3(zf[self.batch_size:, :, :])
                    zf = torch.cat([zf_src, zf_tar], axis=0)
                zf = self.relu(zf)

            if self.add_fc > 1:
                zf = self.dec_fc_layer2(zf)
                if self.use_bn == 'shared':
                    zf = self.bn_dec_layer2(zf)
                elif self.use_bn == 'separated':
                    zf_src = self.bn_S_dec_layer2(zf[:self.batch_size, :, :])
                    zf_tar = self.bn_T_dec_layer2(zf[self.batch_size:, :, :])
                    zf = torch.cat([zf_src, zf_tar], axis=0)
                zf = self.relu(zf)

            zf = self.dec_fc_layer1(zf)
            if self.use_bn == 'shared':
                zf = self.bn_dec_layer1(zf)
            elif self.use_bn == 'separated':
                zf_src = self.bn_S_dec_layer1(zf[:self.batch_size, :, :])
                zf_tar = self.bn_T_dec_layer1(zf[self.batch_size:, :, :])
                zf = torch.cat([zf_src, zf_tar], axis=0)
            recon_x = self.relu(zf)
            return recon_x

    def encoder_frame(self, x):
        if self.input_type == 'image':
            # x: (batch, frames, channels, H, W); fold frames into the batch for the conv encoder
            x_shape = x.shape
            x = x.view(-1, x_shape[-3], x_shape[-2], x_shape[-1])
            x_embed = self.encoder(x)[0]
            return x_embed.view(x_shape[0], x_shape[1], -1)

        if self.input_type == 'feature':
            x_embed = self.enc_fc_layer1(x)
            if self.use_bn == 'shared':
                x_embed = self.bn_enc_layer1(x_embed)
            elif self.use_bn == 'separated':
                x_embed_src = self.bn_S_enc_layer1(x_embed[:self.batch_size, :, :])
                x_embed_tar = self.bn_T_enc_layer1(x_embed[self.batch_size:, :, :])
                x_embed = torch.cat([x_embed_src, x_embed_tar], axis=0)
            x_embed = self.relu(x_embed)

            if self.add_fc > 1:
                x_embed = self.enc_fc_layer2(x_embed)
                if self.use_bn == 'shared':
                    x_embed = self.bn_enc_layer2(x_embed)
                elif self.use_bn == 'separated':
                    x_embed_src = self.bn_S_enc_layer2(x_embed[:self.batch_size, :, :])
                    x_embed_tar = self.bn_T_enc_layer2(x_embed[self.batch_size:, :, :])
                    x_embed = torch.cat([x_embed_src, x_embed_tar], axis=0)
                x_embed = self.relu(x_embed)

            if self.add_fc > 2:
                x_embed = self.enc_fc_layer3(x_embed)
                if self.use_bn == 'shared':
                    x_embed = self.bn_enc_layer3(x_embed)
                elif self.use_bn == 'separated':
                    x_embed_src = self.bn_S_enc_layer3(x_embed[:self.batch_size, :, :])
                    x_embed_tar = self.bn_T_enc_layer3(x_embed[self.batch_size:, :, :])
                    x_embed = torch.cat([x_embed_src, x_embed_tar], axis=0)
                x_embed = self.relu(x_embed)

            return x_embed

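    # Standard VAE reparameterisation: z = mean + std * eps with std = exp(0.5 * logvar),
    # falling back to the mean when random_sampling is False (as in this demo).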
    def reparameterize(self, mean, logvar, random_sampling=True):
        if random_sampling:
            eps = torch.randn_like(logvar)
            std = torch.exp(0.5 * logvar)
            z = mean + eps * std
            return z
        else:
            return mean

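    # Prior rollouts over the dynamic latents. sample_z_prior_train conditions each step on the
    # posterior sample of the previous frame (teacher forcing); sample_z rolls the two-layer
    # LSTM prior out autoregressively from zeros. Tensors are created on CPU, matching the
    # CPU-only demo.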
    def sample_z_prior_train(self, z_post, random_sampling=True):
        z_out = None
        z_means = None
        z_logvars = None
        batch_size = z_post.shape[0]

        z_t = torch.zeros(batch_size, self.z_dim).cpu()
        h_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
        c_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
        h_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
        c_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()

        for i in range(self.frames):
            h_t_ly1, c_t_ly1 = self.z_prior_lstm_ly1(z_t, (h_t_ly1, c_t_ly1))
            h_t_ly2, c_t_ly2 = self.z_prior_lstm_ly2(h_t_ly1, (h_t_ly2, c_t_ly2))

            z_mean_t = self.z_prior_mean(h_t_ly2)
            z_logvar_t = self.z_prior_logvar(h_t_ly2)
            z_prior = self.reparameterize(z_mean_t, z_logvar_t, random_sampling)
            if z_out is None:
                z_out = z_prior.unsqueeze(1)
                z_means = z_mean_t.unsqueeze(1)
                z_logvars = z_logvar_t.unsqueeze(1)
            else:
                z_out = torch.cat((z_out, z_prior.unsqueeze(1)), dim=1)
                z_means = torch.cat((z_means, z_mean_t.unsqueeze(1)), dim=1)
                z_logvars = torch.cat((z_logvars, z_logvar_t.unsqueeze(1)), dim=1)
            z_t = z_post[:, i, :]
        return z_means, z_logvars, z_out

    def sample_z(self, batch_size, random_sampling=True):
        z_out = None
        z_means = None
        z_logvars = None

        z_t = torch.zeros(batch_size, self.z_dim).cpu()
        h_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
        c_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
        h_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
        c_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()

        for _ in range(self.frames):
            h_t_ly1, c_t_ly1 = self.z_prior_lstm_ly1(z_t, (h_t_ly1, c_t_ly1))
            h_t_ly2, c_t_ly2 = self.z_prior_lstm_ly2(h_t_ly1, (h_t_ly2, c_t_ly2))

            z_mean_t = self.z_prior_mean(h_t_ly2)
            z_logvar_t = self.z_prior_logvar(h_t_ly2)
            z_t = self.reparameterize(z_mean_t, z_logvar_t, random_sampling)
            if z_out is None:
                z_out = z_t.unsqueeze(1)
                z_means = z_mean_t.unsqueeze(1)
                z_logvars = z_logvar_t.unsqueeze(1)
            else:
                z_out = torch.cat((z_out, z_t.unsqueeze(1)), dim=1)
                z_means = torch.cat((z_means, z_mean_t.unsqueeze(1)), dim=1)
                z_logvars = torch.cat((z_logvars, z_logvar_t.unsqueeze(1)), dim=1)
        return z_means, z_logvars, z_out

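    # Full forward pass: encode the posterior, sample the prior, reconstruct the frames from
    # [z_t, f], and collect domain predictions (frame, relation, video, latent-f) together with
    # the video-level action prediction. beta = [beta_relation, beta_video, beta_frame] sets
    # the gradient-reversal strengths.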
    def forward(self, x, beta):
        f_mean, f_logvar, f_post, z_mean_post, z_logvar_post, z_post = self.encode_and_sample_post(x)

        if self.prior_sample == 'random':
            z_mean_prior, z_logvar_prior, z_prior = self.sample_z(z_post.size(0), random_sampling=False)
        elif self.prior_sample == 'post':
            z_mean_prior, z_logvar_prior, z_prior = self.sample_z_prior_train(z_post, random_sampling=False)

        # reconstruct every frame from [z_t, f]
        if isinstance(f_post, list):
            f_expand = f_post[0].unsqueeze(1).expand(-1, self.frames, self.f_dim)
        else:
            f_expand = f_post.unsqueeze(1).expand(-1, self.frames, self.f_dim)
        zf = torch.cat((z_post, f_expand), dim=2)
        recon_x = self.decoder_frame(zf)

        pred_domain_all = []

        # frame-level domain predictions on the dynamic latents
        z_post_feat = z_post.view(-1, z_post.size()[-1])
        z_post_feat = self.dropout_f(z_post_feat)
        pred_fc_domain_frame = self.domain_classifier_frame(z_post_feat, beta[2])
        pred_fc_domain_frame = pred_fc_domain_frame.view((z_post.size(0), self.frames) + pred_fc_domain_frame.size()[-1:])
        pred_domain_all.append(pred_fc_domain_frame)

        # aggregate the frame latents into a video-level feature
        if self.frame_aggregation == 'rnn':
            self.bilstm.flatten_parameters()
            z_post_video_feat, _ = self.bilstm(z_post)
            backward = z_post_video_feat[:, 0, self.z_dim:2 * self.z_dim]
            frontal = z_post_video_feat[:, self.frames - 1, 0:self.z_dim]
            z_post_video_feat = torch.cat((frontal, backward), dim=1)
            pred_fc_domain_relation = []
            pred_domain_all.append(pred_fc_domain_relation)

        elif self.frame_aggregation == 'trn':
            z_post_video_relation = self.TRN(z_post)

            # relation-level domain predictions
            pred_fc_domain_relation = self.domain_classifier_relation(z_post_video_relation, beta[0])
            pred_domain_all.append(pred_fc_domain_relation.view((z_post.size(0), z_post_video_relation.size()[1]) + pred_fc_domain_relation.size()[-1:]))

            # re-weight the relation features by transferability, then sum them up
            if self.use_attn != 'none':
                z_post_video_relation, _ = self.get_attn_feat_relation(z_post_video_relation, pred_fc_domain_relation, self.frames)
            z_post_video_feat = torch.sum(z_post_video_relation, 1)

        z_post_video_feat = self.dropout_v(z_post_video_feat)

        # video-level domain prediction and action prediction
        pred_fc_domain_video = self.domain_classifier_video(z_post_video_feat, beta[1])
        pred_fc_domain_video = pred_fc_domain_video.view((z_post.size(0),) + pred_fc_domain_video.size()[-1:])
        pred_domain_all.append(pred_fc_domain_video)

        pred_video_class = self.pred_classifier_video(z_post_video_feat)

        # domain prediction on the static latent f
        if isinstance(f_post, list):
            pred_fc_domain_latent = self.domain_classifier_latent(f_post[0])
        else:
            pred_fc_domain_latent = self.domain_classifier_latent(f_post)
        pred_domain_all.append(pred_fc_domain_latent)

        return f_mean, f_logvar, f_post, z_mean_post, z_logvar_post, z_post, z_mean_prior, z_logvar_prior, z_prior, recon_x, pred_domain_all, pred_video_class

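# Helpers for the demo: load an 8-frame Sprite sequence as a tensor, and write frame
# sequences out as gifs (optionally padded horizontally for display).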
def name2seq(file_name):
    images = []
    for frame in range(8):
        frame_name = '%d' % (frame)
        image_filename = file_name + frame_name + '.png'
        image = imageio.imread(image_filename)
        images.append(image[:, :, :3])

    images = np.asarray(images, dtype='f') / 256.0
    images = images.transpose((0, 3, 1, 2))
    print(images.shape)
    images = torch.Tensor(images).unsqueeze(dim=0)
    return images


def display_gif(file_name, save_name):
    images = []
    for frame in range(8):
        frame_name = '%d' % (frame)
        image_filename = file_name + frame_name + '.png'
        images.append(imageio.imread(image_filename))

    return imageio.mimsave(save_name, images)


def display_gif_pad(file_name, save_name):
    images = []
    for frame in range(8):
        frame_name = '%d' % (frame)
        image_filename = file_name + frame_name + '.png'
        image = imageio.imread(image_filename)
        image = image[:, :, :3]
        image_pad = cv2.copyMakeBorder(image, 0, 0, 125, 125, cv2.BORDER_CONSTANT, value=0)
        images.append(image_pad)

    return imageio.mimsave(save_name, images)


def display_image(file_name):
    image_filename = file_name + '0' + '.png'
    print(image_filename)
    image = imageio.imread(image_filename)
    imageio.imwrite('image.png', image)

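# Gradio callback: maps the UI choices to the Sprite file-name codes
# (front_<body><bottom><top><hair>_<frame>.png), writes the padded source/target gifs,
# and loads the pretrained TransferVAE checkpoint on CPU.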
def run(domain_source, action_source, hair_source, top_source, bottom_source, domain_target, action_target, hair_target, top_target, bottom_target):

    # source avatar (human) attribute codes
    body_source = '0'

    if hair_source == "green": hair_source = '0'
    elif hair_source == "yellow": hair_source = '2'
    elif hair_source == "rose": hair_source = '4'
    elif hair_source == "red": hair_source = '7'
    elif hair_source == "wine": hair_source = '8'

    if top_source == "brown": top_source = '0'
    elif top_source == "blue": top_source = '1'
    elif top_source == "white": top_source = '2'

    if bottom_source == "white": bottom_source = '0'
    elif bottom_source == "golden": bottom_source = '1'
    elif bottom_source == "red": bottom_source = '2'
    elif bottom_source == "silver": bottom_source = '3'

    file_name_source = './Sprite/frames/domain_1/' + action_source + '/'
    file_name_source = file_name_source + 'front' + '_' + str(body_source) + str(bottom_source) + str(top_source) + str(hair_source) + '_'

    gif = display_gif_pad(file_name_source, 'avatar_source.gif')

    # target avatar (alien) attribute codes
    body_target = '1'

    if hair_target == "violet": hair_target = '1'
    elif hair_target == "silver": hair_target = '3'
    elif hair_target == "purple": hair_target = '5'
    elif hair_target == "grey": hair_target = '6'
    elif hair_target == "golden": hair_target = '9'

    if top_target == "grey": top_target = '3'
    elif top_target == "khaki": top_target = '4'
    elif top_target == "linen": top_target = '5'
    elif top_target == "ocre": top_target = '6'

    if bottom_target == "denim": bottom_target = '4'
    elif bottom_target == "olive": bottom_target = '5'
    elif bottom_target == "brown": bottom_target = '6'

    file_name_target = './Sprite/frames/domain_2/' + action_target + '/'
    file_name_target = file_name_target + 'front' + '_' + str(body_target) + str(bottom_target) + str(top_target) + str(hair_target) + '_'

    gif_target = display_gif_pad(file_name_target, 'avatar_target.gif')

    # load the pretrained model on CPU
    model = TransferVAE_Video(opt)
    model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
    model.eval()

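    # NOTE: this demo stops after loading the checkpoint and returns a pre-rendered 'demo.gif'.
    # A rough sketch of the intended domain-transfer step (not executed here; the exact
    # decoder output shape depends on dcgan_64) would be:
    #   x_src, x_tar = name2seq(file_name_source), name2seq(file_name_target)
    #   with torch.no_grad():
    #       _, _, f_src, _, _, z_src = model.encode_and_sample_post(x_src)
    #       _, _, f_tar, _, _, z_tar = model.encode_and_sample_post(x_tar)
    #       # swap the static latent f across domains and decode with the target dynamics
    #       zf = torch.cat((z_tar, f_src.unsqueeze(1).expand(-1, model.frames, model.f_dim)), dim=2)
    #       frames = model.decoder_frame(zf)
    #   # then write `frames` out with imageio.mimsave('demo.gif', ...)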
    return 'demo.gif'


gr.Interface(
    run,
    inputs=[
        gr.Textbox(value="Source Avatar - Human", interactive=False),
        gr.Radio(choices=["slash", "spellcard", "walk"], value="slash"),
        gr.Radio(choices=["green", "yellow", "rose", "red", "wine"], value="green"),
        gr.Radio(choices=["brown", "blue", "white"], value="brown"),
        gr.Radio(choices=["white", "golden", "red", "silver"], value="white"),
        gr.Textbox(value="Target Avatar - Alien", interactive=False),
        gr.Radio(choices=["slash", "spellcard", "walk"], value="walk"),
        gr.Radio(choices=["violet", "silver", "purple", "grey", "golden"], value="golden"),
        gr.Radio(choices=["grey", "khaki", "linen", "ocre"], value="ocre"),
        gr.Radio(choices=["denim", "olive", "brown"], value="brown"),
    ],
    outputs=[
        gr.components.Image(type="file", label="Domain Disentanglement"),
    ],
    live=True,
    title="TransferVAE for Unsupervised Video Domain Adaptation",
).launch()