Spaces:
Build error
Build error
| from modules.commons.common_layers import * | |
| import random | |
class MixStyle(nn.Module):
    """MixStyle conditioned on a speaker embedding.

    On a training forward pass (with probability ``p``), each feature vector
    is normalized over its last dimension and then re-scaled/shifted by a
    gain and bias predicted from ``spk_embed``. Each batch item's gain/bias
    is linearly mixed with that of a randomly chosen item in the batch, using
    a weight drawn from Beta(alpha, alpha).

    Reference:
        Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, hidden_size=256):
        """
        Args:
            p (float): probability of applying MixStyle on a training forward pass.
            alpha (float): concentration parameter of the symmetric
                Beta(alpha, alpha) distribution the mixing weight is drawn from.
            eps (float): added to the standard deviation before dividing, to
                avoid numerical issues.
            hidden_size (int): input size of ``affine_layer`` (the speaker
                embedding's last dimension) and the size of each of the
                predicted bias/gain halves.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self._activated = True
        self.hidden_size = hidden_size
        # Output is split in forward(): first half -> bias, second half -> gain.
        self.affine_layer = LinearNorm(
            hidden_size,
            2 * hidden_size,  # For both b (bias) and g (gain)
        )

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})'

    def set_activation_status(self, status=True):
        """Enable or disable MixStyle regardless of training mode."""
        self._activated = status

    def forward(self, x, spk_embed):
        """Apply speaker-conditioned MixStyle.

        Args:
            x: input features; per the shape comments below, presumably
                [B, T, H_m] with H_m == hidden_size.
            spk_embed: speaker embedding fed to ``affine_layer``; per the shape
                comments below, presumably [B, 1, H_m].

        Returns:
            ``x`` itself, unchanged, when the module is not training, when it
            was deactivated via ``set_activation_status(False)``, or with
            probability ``1 - p``. Otherwise, the normalized features re-scaled
            and shifted by the mixed speaker gain/bias.
        """
        # NOTE: outside this stochastic training path the speaker affine is not
        # applied at all; the input is returned as-is.
        if not self.training or not self._activated:
            return x
        if random.random() > self.p:
            return x

        B = x.size(0)
        # Normalize over the last dimension. torch.std applies Bessel's
        # correction (unbiased estimate) by default.
        mu = torch.mean(x, dim=-1, keepdim=True)
        sig = torch.std(x, dim=-1, keepdim=True)
        # Use the configured eps (was previously hard-coded to 1e-6).
        x_normed = (x - mu) / (sig + self.eps)  # [B, T, H_m]

        # One mixing weight per batch item, broadcast over the remaining dims.
        lmda = self.beta.sample((B, 1, 1)).to(x.device)

        # Get bias and gain from the speaker embedding.
        mu1, sig1 = torch.split(self.affine_layer(spk_embed), self.hidden_size, dim=-1)  # [B, 1, 2 * H_m] --> 2 * [B, 1, H_m]

        # MixStyle: blend each item's bias/gain with those of a random partner.
        perm = torch.randperm(B)
        mu2, sig2 = mu1[perm], sig1[perm]
        mu_mix = mu1 * lmda + mu2 * (1 - lmda)
        sig_mix = sig1 * lmda + sig2 * (1 - lmda)

        # Perform scaling and shifting.
        return sig_mix * x_normed + mu_mix  # [B, T, H_m]