MobileStyleGAN / model.py
hysts's picture
hysts HF staff
Update
83321d4
raw
history blame
1.63 kB
import sys
import torch
import torch.nn as nn
sys.path.insert(0, 'MobileStyleGAN.pytorch')
from core.models.mapping_network import MappingNetwork
from core.models.mobile_synthesis_network import MobileSynthesisNetwork
from core.models.synthesis_network import SynthesisNetwork
class Model(nn.Module):
    """Mapping network plus a teacher and a student synthesis network.

    The module owns three sub-networks built with fixed hyper-parameters:

    * ``mapping_net`` -- a ``MappingNetwork`` that turns the input into a style.
    * ``synthesis_net`` -- the teacher ``SynthesisNetwork``.
    * ``student`` -- the student ``MobileSynthesisNetwork``.

    ``style_mean`` is a non-trainable ``(1, style_dim)`` parameter initialised
    to zeros; it is the anchor used for truncation in :meth:`forward`.
    NOTE(review): it is presumably overwritten when a checkpoint is loaded --
    confirm against the loading code.
    """

    def __init__(self) -> None:
        """Build the teacher and student networks with their fixed configs."""
        super().__init__()

        # teacher model
        mapping_net_params = {'style_dim': 512, 'n_layers': 8, 'lr_mlp': 0.01}
        synthesis_net_params = {
            'size': 1024,
            'style_dim': 512,
            'blur_kernel': [1, 3, 3, 1],
            'channels': [512, 512, 512, 512, 512, 256, 128, 64, 32],
        }
        # Both teacher parts are put in eval mode at construction time. A later
        # call to ``Model.train()`` would switch them back, as for any
        # sub-module.
        self.mapping_net = MappingNetwork(**mapping_net_params).eval()
        self.synthesis_net = SynthesisNetwork(**synthesis_net_params).eval()

        # student network: same channel list as the teacher minus its last
        # entry.
        self.student = MobileSynthesisNetwork(
            style_dim=self.mapping_net.style_dim,
            channels=synthesis_net_params['channels'][:-1])

        # Width comes from the mapping config instead of a repeated literal.
        # Kept as a frozen Parameter so it stays part of the state dict.
        self.style_mean = nn.Parameter(
            torch.zeros((1, mapping_net_params['style_dim'])),
            requires_grad=False)

    def forward(self,
                var: torch.Tensor,
                truncation_psi: float = 0.5,
                generator: str = 'student') -> torch.Tensor:
        """Generate an image from ``var`` with the chosen generator.

        Args:
            var: Input passed unchanged to the mapping network.
                (Assumed to be a batch of latent vectors -- TODO confirm.)
            truncation_psi: Interpolation factor between ``style_mean`` and the
                mapped style; ``1.0`` leaves the style unchanged and ``0.0``
                collapses it onto ``style_mean``.
            generator: ``'student'`` to use the mobile synthesis network or
                ``'teacher'`` to use the full synthesis network.

        Returns:
            The ``'img'`` entry of the selected network's output dict.

        Raises:
            ValueError: If ``generator`` is neither ``'student'`` nor
                ``'teacher'``.
        """
        # Validate first so a bad argument fails before any network runs.
        if generator not in ('student', 'teacher'):
            raise ValueError(
                f"Unknown generator {generator!r}; "
                "expected 'student' or 'teacher'.")

        style = self.mapping_net(var)
        # Truncation: pull the style toward the mean by a factor of psi.
        style = self.style_mean + truncation_psi * (style - self.style_mean)

        network = self.student if generator == 'student' else self.synthesis_net
        return network(style)['img']