File size: 693 Bytes
2492d81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
from torch import nn
import torch
from torchvision import models
class KPDetector(nn.Module):
    """Regress ``num_tps * 5`` two-coordinate keypoints from an image.

    A torchvision ResNet-18 (randomly initialised, ``pretrained=False``) is
    used as the encoder. Its final ``fc`` layer is swapped for a linear head
    that emits ``num_tps * 5 * 2`` values per sample. ``forward`` squashes
    those values into ``[-1, 1]`` and reshapes them to
    ``(batch, num_tps * 5, 2)``.

    Args:
        num_tps: Multiplier for the keypoint count (``num_tps * 5``
            keypoints are produced); presumably the number of thin-plate-
            spline transforms -- confirm against callers.
        **kwargs: Accepted but ignored, so a larger config dict can be
            splatted into the constructor.
    """

    def __init__(self, num_tps, **kwargs):
        super().__init__()
        self.num_tps = num_tps

        backbone = models.resnet18(pretrained=False)
        head_in = backbone.fc.in_features
        # Replace the classification head: 5 keypoints per unit of num_tps,
        # 2 coordinates each.
        backbone.fc = nn.Linear(head_in, num_tps * 5 * 2)
        self.fg_encoder = backbone

    def forward(self, image):
        """Predict keypoints for a batch of images.

        Args:
            image: Batch fed straight to ``self.fg_encoder``; presumably
                ``(batch, 3, H, W)`` for the ResNet-18 encoder -- TODO confirm.

        Returns:
            dict: ``{'fg_kp': tensor}`` where the tensor has shape
            ``(batch, num_tps * 5, 2)`` with values in ``[-1, 1]``.

        Raises:
            ValueError: if the encoder output is not 2-D.
        """
        raw = self.fg_encoder(image)
        # Unpacking (not indexing) so a non-2-D encoder output fails loudly.
        batch_size, _ = raw.shape
        # Map sigmoid's (0, 1) output onto [-1, 1].
        squashed = torch.sigmoid(raw).mul(2).sub(1)
        keypoints = squashed.view(batch_size, self.num_tps * 5, -1)
        return {'fg_kp': keypoints}
|