gavinyuan
update: app.py import FSGenerator
a104d3f
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.layers.smoothswap.resnet import resnet50
class IdentityHead(nn.Module):
    """Fully connected head producing a unit-length 512-d embedding.

    The input is expected to carry ``512 * 4`` (2048) features per sample.
    It passes through ``Linear -> BatchNorm1d -> LeakyReLU(0.2)`` down to
    1024 features, then ``Linear -> BatchNorm1d`` down to 512 features, and
    is finally L2-normalized along dimension 1.
    """

    def __init__(self):
        super(IdentityHead, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(512 * 4, 1024),
            nn.BatchNorm1d(num_features=1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(num_features=512),
        )

        # Reset the affine parameters of every batch-norm layer to identity.
        # This head only contains BatchNorm1d layers, so a check against
        # BatchNorm2d alone would never match; BatchNorm2d stays in the tuple
        # so any 2-D norm layer added later is still covered.
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map features ``x`` to an L2-normalized embedding.

        Args:
            x: feature tensor whose per-sample feature size is 2048
                (the input width of the first linear layer).

        Returns:
            Tensor with 512 features per sample, scaled to unit L2 norm
            along dimension 1.
        """
        x = self.fc1(x)
        x = self.fc2(x)
        # p=2, dim=1 are F.normalize's defaults; spelled out for clarity.
        x = F.normalize(x, p=2, dim=1)
        return x
class IdentityEmbedder(nn.Module):
    """Backbone-plus-head network that embeds a source image.

    The backbone is the project's ``resnet50`` built with
    ``pretrained=False``; its output is fed straight into ``IdentityHead``.
    NOTE(review): the head's first linear layer takes 2048 features, so the
    backbone presumably returns flattened (batch, 2048) features -- confirm
    against ``modules.layers.smoothswap.resnet``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys and the order in which parameters are created.
        self.backbone = resnet50(pretrained=False)
        self.head = IdentityHead()

    def forward(self, x_src):
        """Return the head's embedding of the backbone features of ``x_src``."""
        features = self.backbone(x_src)
        return self.head(features)
if __name__ == '__main__':
    # Smoke test: push a random batch through the embedder and print the
    # output shape. Use the GPU when one is present, otherwise fall back to
    # the CPU so the check also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    img = torch.randn((11, 3, 256, 256)).to(device)
    net = IdentityEmbedder().to(device)
    out = net(img)
    print(out.shape)