# File-viewer residue, kept for provenance only (not part of the program):
# Spaces: Build error | File size: 1,631 Bytes | 518857d
import sys
import torch
import torch.nn as nn
# Put the 'MobileStyleGAN.pytorch' directory first on the import search path
# before the `core.models.*` imports below run (presumably that directory is
# where the `core` package lives -- confirm). The path is relative, so it is
# resolved against the process's current working directory.
sys.path.insert(0, 'MobileStyleGAN.pytorch')
from core.models.mapping_network import MappingNetwork
from core.models.mobile_synthesis_network import MobileSynthesisNetwork
from core.models.synthesis_network import SynthesisNetwork
class Model(nn.Module):
    """Bundle one mapping network with a teacher and a student synthesizer.

    ``forward`` maps its input to a style tensor with ``mapping_net``, pulls
    that style towards ``style_mean`` and renders it with either the student
    ``MobileSynthesisNetwork`` or the teacher ``SynthesisNetwork``.
    """

    # Accepted values for the ``generator`` argument of ``forward``.
    _GENERATORS = ('student', 'teacher')

    def __init__(self):
        """Build the mapping network, both synthesizers and the style mean."""
        super().__init__()

        # teacher model
        mapping_net_params = {'style_dim': 512, 'n_layers': 8, 'lr_mlp': 0.01}
        synthesis_net_params = {
            'size': 1024,
            'style_dim': 512,
            'blur_kernel': [1, 3, 3, 1],
            'channels': [512, 512, 512, 512, 512, 256, 128, 64, 32],
        }
        # Both teacher-side networks are put into eval mode at construction.
        self.mapping_net = MappingNetwork(**mapping_net_params).eval()
        self.synthesis_net = SynthesisNetwork(**synthesis_net_params).eval()

        # student network: takes its style width from the mapping network and
        # every teacher channel width except the last one.
        self.student = MobileSynthesisNetwork(
            style_dim=self.mapping_net.style_dim,
            channels=synthesis_net_params['channels'][:-1])

        # Centre used by the truncation step in ``forward``. Stored as a
        # non-trainable parameter initialised to zeros.
        # NOTE(review): presumably overwritten when weights are loaded --
        # confirm against the checkpoint-loading code.
        self.style_mean = nn.Parameter(
            torch.zeros((1, mapping_net_params['style_dim'])),
            requires_grad=False)

    def forward(self,
                var: torch.Tensor,
                truncation_psi: float = 0.5,
                generator: str = 'student') -> torch.Tensor:
        """Render an image tensor from ``var``.

        Args:
            var: Input passed unchanged to ``mapping_net``.
            truncation_psi: Interpolation factor between ``style_mean`` and
                the mapped style: ``0`` yields ``style_mean`` itself, ``1``
                leaves the mapped style untouched.
            generator: ``'student'`` to render with ``self.student`` or
                ``'teacher'`` to render with ``self.synthesis_net``.

        Returns:
            The ``'img'`` entry of the selected synthesizer's output.

        Raises:
            ValueError: If ``generator`` is neither ``'student'`` nor
                ``'teacher'``.
        """
        # Reject a bad selector before doing any network work, and say which
        # value was rejected.
        if generator not in self._GENERATORS:
            raise ValueError(
                f'Unknown generator {generator!r}; '
                f'expected one of {self._GENERATORS}.')

        style = self.mapping_net(var)
        # Truncation: move the mapped style towards the stored mean.
        style = self.style_mean + truncation_psi * (style - self.style_mean)

        if generator == 'student':
            synthesizer = self.student
        else:
            synthesizer = self.synthesis_net
        return synthesizer(style)['img']
|