Update app.py

app.py (CHANGED)
```diff
@@ -59,18 +59,6 @@ class RelationModuleMultiScale(torch.nn.Module):
         return list(itertools.combinations([i for i in range(num_frames)], num_frames_relation))
 
 
-class GradReverse(Function):
-    @staticmethod
-    def forward(ctx, x, beta):
-        ctx.beta = beta
-        return x.view_as(x)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        grad_input = grad_output.neg() * ctx.beta
-        return grad_input, None
-
-
 class TransferVAE_Video(nn.Module):
 
     def __init__(self):
```
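The removed `GradReverse` is the standard gradient reversal layer from adversarial domain adaptation: an identity map in the forward pass whose backward pass multiplies incoming gradients by `-beta`, so the feature extractor is trained to confuse the domain classifiers it feeds. A minimal self-contained sketch of that behavior (the feature batch size is hypothetical):

```python
import torch
from torch.autograd import Function

class GradReverse(Function):
    """Identity forward; backward scales the gradient by -beta."""
    @staticmethod
    def forward(ctx, x, beta):
        ctx.beta = beta
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: reversed gradient for x, None for beta.
        return grad_output.neg() * ctx.beta, None

feat = torch.randn(4, 16, requires_grad=True)  # hypothetical feature batch
GradReverse.apply(feat, 1.0).sum().backward()
assert torch.allclose(feat.grad, -torch.ones_like(feat))  # gradient is reversed
```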
```diff
@@ -133,86 +121,18 @@ class TransferVAE_Video(nn.Module):
             self.relation_domain_classifier_all += [relation_domain_classifier]
 
         self.pred_classifier_video = nn.Linear(self.feat_aggregated_dim, self.num_class)
-
         self.fc_feature_domain_latent = nn.Linear(self.f_dim, self.f_dim)
         self.fc_classifier_doamin_latent = nn.Linear(self.f_dim, 2)
 
-
-    def domain_classifier_frame(self, feat, beta):
-        feat_fc_domain_frame = GradReverse.apply(feat, beta)
-        feat_fc_domain_frame = self.fc_feature_domain_frame(feat_fc_domain_frame)
-        feat_fc_domain_frame = self.relu(feat_fc_domain_frame)
-        pred_fc_domain_frame = self.fc_classifier_domain_frame(feat_fc_domain_frame)
-        return pred_fc_domain_frame
-
-
-    def domain_classifier_video(self, feat_video, beta):
-        feat_fc_domain_video = GradReverse.apply(feat_video, beta)
-        feat_fc_domain_video = self.fc_feature_domain_video(feat_fc_domain_video)
-        feat_fc_domain_video = self.relu(feat_fc_domain_video)
-        pred_fc_domain_video = self.fc_classifier_domain_video(feat_fc_domain_video)
-        return pred_fc_domain_video
-
-
-    def domain_classifier_latent(self, f):
-        feat_fc_domain_latent = self.fc_feature_domain_latent(f)
-        feat_fc_domain_latent = self.relu(feat_fc_domain_latent)
-        pred_fc_domain_latent = self.fc_classifier_doamin_latent(feat_fc_domain_latent)
-        return pred_fc_domain_latent
-
-
-    def domain_classifier_relation(self, feat_relation, beta):
-        pred_fc_domain_relation_video = None
-        for i in range(len(self.relation_domain_classifier_all)):
-            feat_relation_single = feat_relation[:,i,:].squeeze(1)
-            feat_fc_domain_relation_single = GradReverse.apply(feat_relation_single, beta)
-
-            pred_fc_domain_relation_single = self.relation_domain_classifier_all[i](feat_fc_domain_relation_single)
-
-            if pred_fc_domain_relation_video is None:
-                pred_fc_domain_relation_video = pred_fc_domain_relation_single.view(-1,1,2)
-            else:
-                pred_fc_domain_relation_video = torch.cat((pred_fc_domain_relation_video, pred_fc_domain_relation_single.view(-1,1,2)), 1)
-
-        pred_fc_domain_relation_video = pred_fc_domain_relation_video.view(-1,2)
-
-        return pred_fc_domain_relation_video
-
-
-    def get_trans_attn(self, pred_domain):
-        softmax = nn.Softmax(dim=1)
-        logsoftmax = nn.LogSoftmax(dim=1)
-        entropy = torch.sum(-softmax(pred_domain) * logsoftmax(pred_domain), 1)
-        weights = 1 - entropy
-        return weights
-
-
-    def get_general_attn(self, feat):
-        num_segments = feat.size()[1]
-        feat = feat.view(-1, feat.size()[-1]) # reshape features: 128x4x256 --> (128x4)x256
-        weights = self.attn_layer(feat) # e.g. (128x4)x1
-        weights = weights.view(-1, num_segments, weights.size()[-1]) # reshape attention weights: (128x4)x1 --> 128x4x1
-        weights = F.softmax(weights, dim=1)  # softmax over segments ==> 128x4x1
-        return weights
-
-
-    def get_attn_feat_relation(self, feat_fc, pred_domain, num_segments):
-        weights_attn = self.get_trans_attn(pred_domain)
-        weights_attn = weights_attn.view(-1, num_segments-1, 1).repeat(1,1,feat_fc.size()[-1]) # reshape & repeat weights (e.g. 16 x 4 x 256)
-        feat_fc_attn = (weights_attn+1) * feat_fc
-        return feat_fc_attn, weights_attn[:,:,0]
-
 
     def encode_and_sample_post(self, x):
         if isinstance(x, list):
             conv_x = self.encoder_frame(x[0])
         else:
             conv_x = self.encoder_frame(x)
-
-        # pass the bidirectional lstm
+
         lstm_out, _ = self.z_lstm(conv_x)
-
-        # get f:
+
         backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
         frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
         lstm_out_f = torch.cat((frontal, backward), dim=1)
```
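Note on the removals above: `get_trans_attn` weighted each relation feature by `1 - H(softmax(pred_domain))`, so relations whose domain prediction is confident (low entropy) receive weights near 1, while a uniform two-class prediction receives `1 - ln 2 ≈ 0.31`; `get_attn_feat_relation` then applied this residually as `(weights + 1) * feat`. A small sketch of the weighting:

```python
import torch
import torch.nn as nn

# Sketch of the removed entropy-based attention weighting.
def trans_attn(pred_domain: torch.Tensor) -> torch.Tensor:
    softmax, logsoftmax = nn.Softmax(dim=1), nn.LogSoftmax(dim=1)
    entropy = torch.sum(-softmax(pred_domain) * logsoftmax(pred_domain), 1)
    return 1 - entropy

pred = torch.tensor([[5.0, -5.0],   # confident -> entropy ~ 0    -> weight ~ 1.0
                     [0.0,  0.0]])  # uniform   -> entropy = ln 2 -> weight ~ 0.31
print(trans_attn(pred))             # ~ tensor([0.9995, 0.3069])
```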
```diff
@@ -220,7 +140,6 @@ class TransferVAE_Video(nn.Module):
         f_logvar = self.f_logvar(lstm_out_f)
         f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
 
-        # pass to one direction rnn
         features, _ = self.z_rnn(lstm_out)
         z_mean = self.z_mean(features)
         z_logvar = self.z_logvar(features)
@@ -232,7 +151,6 @@ class TransferVAE_Video(nn.Module):
         for t in range(1,3,1):
             conv_x = self.encoder_frame(x[t])
             lstm_out, _ = self.z_lstm(conv_x)
-            # get f:
             backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
             frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
             lstm_out_f = torch.cat((frontal, backward), dim=1)
@@ -243,7 +161,6 @@ class TransferVAE_Video(nn.Module):
             f_post_list.append(f_post)
         f_mean = f_mean_list
         f_post = f_post_list
-        # f_mean and f_post are list if triple else not
         return f_mean, f_logvar, f_post, z_mean, z_logvar, z_post
 
 
```
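The retained posterior encoder slices the bidirectional LSTM output by hand, and the convention recurs in the hunks above: the forward direction's final step and the backward direction's first step each summarize the whole clip. A sketch with hypothetical sizes:

```python
import torch
import torch.nn as nn

B, frames, feat_dim, hidden_dim = 2, 8, 64, 32   # hypothetical sizes
lstm = nn.LSTM(feat_dim, hidden_dim, batch_first=True, bidirectional=True)
lstm_out, _ = lstm(torch.randn(B, frames, feat_dim))  # [B, frames, 2*hidden_dim]

backward = lstm_out[:, 0, hidden_dim:2 * hidden_dim]  # backward dir, first frame
frontal = lstm_out[:, frames - 1, 0:hidden_dim]       # forward dir, last frame
lstm_out_f = torch.cat((frontal, backward), dim=1)    # clip-level summary
assert lstm_out_f.shape == (B, 2 * hidden_dim)
```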
```diff
@@ -260,7 +177,6 @@ class TransferVAE_Video(nn.Module):
 
 
     def reparameterize(self, mean, logvar, random_sampling=True):
-        # Reparametrization occurs only if random sampling is set to true, otherwise mean is returned
         if random_sampling is True:
             eps = torch.randn_like(logvar)
             std = torch.exp(0.5 * logvar)
```
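The lines elided between this hunk and the next contain the sampling branch itself; the visible code implements the usual reparameterization trick. A sketch of the standard form (not the verbatim elided code):

```python
import torch

def reparameterize(mean, logvar, random_sampling=True):
    if random_sampling:
        eps = torch.randn_like(logvar)  # eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)   # logvar = log(sigma^2)
        return mean + eps * std         # differentiable sample from N(mean, sigma^2)
    return mean
```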
```diff
@@ -269,88 +185,20 @@ class TransferVAE_Video(nn.Module):
         else:
             return mean
 
-    def sample_z_prior_train(self, z_post, random_sampling=True):
-        z_out = None
-        z_means = None
-        z_logvars = None
-        batch_size = z_post.shape[0]
-
-        z_t = torch.zeros(batch_size, self.z_dim).cpu()
-        h_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        c_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        h_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        c_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
-
-        for i in range(self.frames):
-            # two layer LSTM and two one-layer FC
-            h_t_ly1, c_t_ly1 = self.z_prior_lstm_ly1(z_t, (h_t_ly1, c_t_ly1))
-            h_t_ly2, c_t_ly2 = self.z_prior_lstm_ly2(h_t_ly1, (h_t_ly2, c_t_ly2))
-
-            z_mean_t = self.z_prior_mean(h_t_ly2)
-            z_logvar_t = self.z_prior_logvar(h_t_ly2)
-            z_prior = self.reparameterize(z_mean_t, z_logvar_t, random_sampling)
-            if z_out is None:
-                # If z_out is none it means z_t is z_1, hence store it in the format [batch_size, 1, z_dim]
-                z_out = z_prior.unsqueeze(1)
-                z_means = z_mean_t.unsqueeze(1)
-                z_logvars = z_logvar_t.unsqueeze(1)
-            else:
-                # If z_out is not none, z_t is not the initial z and hence append it to the previous z_ts collected in z_out
-                z_out = torch.cat((z_out, z_prior.unsqueeze(1)), dim=1)
-                z_means = torch.cat((z_means, z_mean_t.unsqueeze(1)), dim=1)
-                z_logvars = torch.cat((z_logvars, z_logvar_t.unsqueeze(1)), dim=1)
-            z_t = z_post[:,i,:]
-        return z_means, z_logvars, z_out
-
-    # If random sampling is true, reparametrization occurs else z_t is just set to the mean
-    def sample_z(self, batch_size, random_sampling=True):
-        z_out = None # This will ultimately store all z_s in the format [batch_size, frames, z_dim]
-        z_means = None
-        z_logvars = None
-
-        # All states are initially set to 0, especially z_0 = 0
-        z_t = torch.zeros(batch_size, self.z_dim).cpu()
-        # z_mean_t = torch.zeros(batch_size, self.z_dim)
-        # z_logvar_t = torch.zeros(batch_size, self.z_dim)
-        h_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        c_t_ly1 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        h_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        c_t_ly2 = torch.zeros(batch_size, self.hidden_dim).cpu()
-        for _ in range(self.frames):
-            # h_t, c_t = self.z_prior_lstm(z_t, (h_t, c_t))
-            # two layer LSTM and two one-layer FC
-            h_t_ly1, c_t_ly1 = self.z_prior_lstm_ly1(z_t, (h_t_ly1, c_t_ly1))
-            h_t_ly2, c_t_ly2 = self.z_prior_lstm_ly2(h_t_ly1, (h_t_ly2, c_t_ly2))
-
-            z_mean_t = self.z_prior_mean(h_t_ly2)
-            z_logvar_t = self.z_prior_logvar(h_t_ly2)
-            z_t = self.reparameterize(z_mean_t, z_logvar_t, random_sampling)
-            if z_out is None:
-                # If z_out is none it means z_t is z_1, hence store it in the format [batch_size, 1, z_dim]
-                z_out = z_t.unsqueeze(1)
-                z_means = z_mean_t.unsqueeze(1)
-                z_logvars = z_logvar_t.unsqueeze(1)
-            else:
-                # If z_out is not none, z_t is not the initial z and hence append it to the previous z_ts collected in z_out
-                z_out = torch.cat((z_out, z_t.unsqueeze(1)), dim=1)
-                z_means = torch.cat((z_means, z_mean_t.unsqueeze(1)), dim=1)
-                z_logvars = torch.cat((z_logvars, z_logvar_t.unsqueeze(1)), dim=1)
-        return z_means, z_logvars, z_out
 
     def forward(self, x, beta):
         _, _, f_post, _, _, z_post = self.encode_and_sample_post(x)
-
         if isinstance(f_post, list):
             f_expand = f_post[0].unsqueeze(1).expand(-1, self.frames, self.f_dim)
         else:
             f_expand = f_post.unsqueeze(1).expand(-1, self.frames, self.f_dim)
         zf = torch.cat((z_post, f_expand), dim=2)
-
         recon_x = self.decoder_frame(zf)
-
         return f_post, z_post, recon_x
 
 
+
+
 def name2seq(file_name):
     images = []
 
```
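With `sample_z_prior_train` and `sample_z` removed, only the posterior path survives: `forward` broadcasts the time-invariant code `f` across frames and concatenates it with the per-frame dynamics `z` before decoding. A shape sketch with hypothetical dimensions:

```python
import torch

B, frames, f_dim, z_dim = 2, 8, 256, 32    # hypothetical sizes
z_post = torch.randn(B, frames, z_dim)     # per-frame dynamics
f_post = torch.randn(B, f_dim)             # time-invariant content code
f_expand = f_post.unsqueeze(1).expand(-1, frames, f_dim)
zf = torch.cat((z_post, f_expand), dim=2)  # decoder input
assert zf.shape == (B, frames, z_dim + f_dim)
```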
```diff
@@ -520,7 +368,7 @@ def run(domain_source, action_source, hair_source, top_source, bottom_source, do
 
     # == Forward ==
     with torch.no_grad():
-
+        f_post, z_post, recon_x = model(x, [0]*3)
 
         src_orig_sample = x[0, :, :, :, :]
         src_recon_sample = recon_x[0, :, :, :, :]
```
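Since no `GradReverse` branch remains, `beta` never influences the forward pass after this change; the `[0]*3` passed at inference appears to be a placeholder kept only to satisfy the `forward(self, x, beta)` signature.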