|
|
|
import torch
from torch import nn
|
|
|
|
|
class AVDNetwork(nn.Module):
    """
    Animation via Disentanglement network.

    Two parallel MLP encoders embed flattened keypoints: ``id_encoder`` is
    applied to the source keypoints and ``pose_encoder`` to the random
    (driving) keypoints. The two embeddings are concatenated and passed
    through ``decoder``, which maps them back to a vector of the same size
    as the flattened input (``5 * 2 * num_tps``).
    """

    def __init__(self, num_tps, id_bottle_size=64, pose_bottle_size=64):
        """
        Args:
            num_tps: Number of TPS transformations; the flattened keypoint
                vector has ``5 * 2 * num_tps`` entries (presumably 5
                two-dimensional keypoints per transformation -- confirm
                against the keypoint detector).
            id_bottle_size: Output width of the identity encoder.
            pose_bottle_size: Output width of the pose encoder.
        """
        super(AVDNetwork, self).__init__()
        input_size = 5 * 2 * num_tps
        self.num_tps = num_tps

        # Construction order (id encoder, pose encoder, decoder) is kept so
        # parameter initialisation consumes the RNG in the same sequence.
        self.id_encoder = self._mlp(
            (input_size, 256, 512, 1024), id_bottle_size, inplace=True
        )
        self.pose_encoder = self._mlp(
            (input_size, 256, 512, 1024), pose_bottle_size, inplace=True
        )
        # The decoder uses out-of-place ReLU, unlike the encoders.
        self.decoder = self._mlp(
            (pose_bottle_size + id_bottle_size, 1024, 512, 256),
            input_size,
            inplace=False,
        )

    @staticmethod
    def _mlp(sizes, out_features, inplace):
        """Build Linear -> BatchNorm1d -> ReLU stages over ``sizes`` plus a final Linear."""
        layers = []
        for in_features, hidden in zip(sizes, sizes[1:]):
            layers.append(nn.Linear(in_features, hidden))
            layers.append(nn.BatchNorm1d(hidden))
            layers.append(nn.ReLU(inplace=inplace))
        layers.append(nn.Linear(sizes[-1], out_features))
        return nn.Sequential(*layers)

    def forward(self, kp_source, kp_random):
        """
        Recombine the identity of ``kp_source`` with the pose of ``kp_random``.

        Args:
            kp_source: Mapping with an ``'fg_kp'`` tensor whose first
                dimension is the batch size; fed to the identity encoder.
            kp_random: Mapping with an ``'fg_kp'`` tensor; fed to the pose
                encoder. It is flattened using the batch size taken from
                ``kp_source``.

        Returns:
            dict with a single key ``'fg_kp'`` holding the decoder output
            reshaped to ``(batch, num_tps * 5, -1)``.
        """
        bs = kp_source['fg_kp'].shape[0]

        # reshape (rather than view) so non-contiguous keypoint tensors,
        # e.g. transposed or sliced ones, are accepted; for contiguous
        # inputs the result is identical.
        pose_emb = self.pose_encoder(kp_random['fg_kp'].reshape(bs, -1))
        id_emb = self.id_encoder(kp_source['fg_kp'].reshape(bs, -1))

        rec = self.decoder(torch.cat([pose_emb, id_emb], dim=1))

        return {'fg_kp': rec.reshape(bs, self.num_tps * 5, -1)}
|
|