Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from modules.layers.smoothswap.resnet import resnet50 | |
class IdentityHead(nn.Module):
    """Projection head that turns backbone features into a unit-norm embedding.

    The head applies two fully connected stages:

    * ``fc1``: Linear(2048 -> 1024) -> BatchNorm1d -> LeakyReLU(0.2)
    * ``fc2``: Linear(1024 -> 512) -> BatchNorm1d

    and finally L2-normalises the result along dimension 1.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(512 * 4, 1024),
            nn.BatchNorm1d(num_features=1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(num_features=512),
        )

        # Explicitly reset the affine parameters of every batch-norm layer.
        # BatchNorm1d must be matched here: it is the only norm type this head
        # contains, so a check against BatchNorm2d alone would never fire.
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project ``x`` and return the L2-normalised embedding.

        Args:
            x: Feature tensor whose last dimension is 2048 (``512 * 4``), as
                required by the first linear layer. A 2-D ``(batch, 2048)``
                input yields a ``(batch, 512)`` output.

        Returns:
            Tensor of embeddings with unit L2 norm along dimension 1.
        """
        x = self.fc1(x)
        x = self.fc2(x)
        # p=2, dim=1 are F.normalize's defaults; spelled out for clarity.
        x = F.normalize(x, p=2, dim=1)
        return x
class IdentityEmbedder(nn.Module):
    """Identity encoder: a ResNet-50 backbone followed by an ``IdentityHead``.

    The backbone is built with ``pretrained=False``, so no pretrained
    weights are requested at construction time.
    """

    def __init__(self):
        super().__init__()
        self.backbone = resnet50(pretrained=False)
        self.head = IdentityHead()

    def forward(self, x_src):
        """Embed the source image batch.

        Args:
            x_src: Input passed straight to the backbone.
                NOTE(review): presumably an image batch of shape
                (batch, 3, H, W) -- confirm against the resnet50 definition.

        Returns:
            The head's output for the backbone features of ``x_src``.
        """
        # The backbone output must be compatible with the head's first
        # Linear layer (512 * 4 input features).
        features = self.backbone(x_src)
        return self.head(features)
if __name__ == '__main__':
    # Smoke test: push a random batch through the embedder and print the
    # output shape. Use the GPU when one is available, otherwise fall back to
    # the CPU so the check also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Sample on the CPU first, then move, so the random draw does not depend
    # on which device ends up being used.
    img = torch.randn((11, 3, 256, 256)).to(device)
    net = IdentityEmbedder().to(device)
    out = net(img)
    print(out.shape)