Spaces:
Runtime error
Runtime error
from torch import nn | |
import torch | |
from torchvision import models | |
class KPDetector(nn.Module):
    """
    Predict K*5 keypoints (K = ``num_tps``) with a ResNet-18 regressor.

    The stock torchvision ResNet-18 (randomly initialised, no pretrained
    weights) has its final ``fc`` layer replaced so that it emits
    ``num_tps * 5 * 2`` values per sample.  These are squashed into the open
    interval (-1, 1) and reshaped to ``(batch, num_tps * 5, 2)``.
    """

    def __init__(self, num_tps, **kwargs):
        """
        Args:
            num_tps: K, the number of keypoint groups; 5 keypoints with two
                values each are predicted per group.
            **kwargs: accepted and ignored (presumably so a shared config
                dict can be splatted in by the caller -- TODO confirm).
        """
        super().__init__()
        self.num_tps = num_tps
        try:
            # torchvision >= 0.13: ``weights=None`` is the non-deprecated way
            # to request a randomly initialised network (``pretrained=`` only
            # emits a deprecation warning there).
            self.fg_encoder = models.resnet18(weights=None)
        except TypeError:
            # Older torchvision forwards unknown kwargs to ``ResNet.__init__``,
            # which rejects ``weights``; fall back to the legacy flag.
            self.fg_encoder = models.resnet18(pretrained=False)
        num_features = self.fg_encoder.fc.in_features
        # Replace the classification head with a regression head: two values
        # for each of the num_tps * 5 keypoints.
        self.fg_encoder.fc = nn.Linear(num_features, num_tps * 5 * 2)

    def forward(self, image):
        """
        Regress keypoints from a batch of images.

        Args:
            image: batch fed straight to the ResNet-18 backbone; presumably
                ``(bs, 3, H, W)`` RGB -- TODO confirm against the caller.

        Returns:
            dict with the single key ``'fg_kp'``: a tensor of shape
            ``(bs, num_tps * 5, 2)`` with every entry in (-1, 1).
        """
        fg_kp = self.fg_encoder(image)
        # The encoder ends in a Linear layer, so its output is (bs, features).
        bs, _ = fg_kp.shape
        # sigmoid(x) * 2 - 1 maps the unbounded outputs into (-1, 1).
        fg_kp = torch.sigmoid(fg_kp) * 2 - 1
        return {'fg_kp': fg_kp.reshape(bs, self.num_tps * 5, -1)}