|
from torch import nn |
|
import torch |
|
from torchvision import models |
|
|
|
class KPDetector(nn.Module):
    """Regress ``num_tps * 5`` 2-D keypoints from an image.

    A ResNet-18 backbone has its final fully connected layer replaced by a
    linear head with ``num_tps * 5 * 2`` outputs. The raw outputs are
    squashed with a sigmoid and rescaled to the range [-1, 1], then reshaped
    into one (x, y)-style pair per keypoint.

    Args:
        num_tps: Number of keypoint groups; the detector predicts five
            keypoints per group.
        **kwargs: Accepted so callers can pass a wider config dict; ignored.
    """

    def __init__(self, num_tps, **kwargs):
        super().__init__()
        self.num_tps = num_tps

        # Randomly initialised backbone. No weight argument is passed on
        # purpose: the default is "no pretrained weights" both in the legacy
        # torchvision API (``pretrained=False``) and in the current one
        # (``weights=None``), and this avoids the deprecation warning that
        # ``pretrained=`` triggers on torchvision >= 0.13.
        self.fg_encoder = models.resnet18()

        # Swap the classification head for a keypoint-regression head:
        # two coordinates for each of the num_tps * 5 keypoints.
        num_features = self.fg_encoder.fc.in_features
        self.fg_encoder.fc = nn.Linear(num_features, num_tps * 5 * 2)

    def forward(self, image):
        """Predict keypoints for a batch of images.

        Args:
            image: Input batch fed directly to the backbone
                (presumably ``(batch, 3, H, W)`` -- confirm against caller).

        Returns:
            dict with a single key ``'fg_kp'`` holding a tensor of shape
            ``(batch, num_tps * 5, 2)`` whose values lie in [-1, 1].
        """
        raw = self.fg_encoder(image)
        batch_size = raw.shape[0]

        # sigmoid maps to (0, 1); the affine rescale moves that to (-1, 1).
        coords = torch.sigmoid(raw) * 2 - 1

        return {'fg_kp': coords.view(batch_size, self.num_tps * 5, -1)}
|
|