Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
''' | |
# -------------------------------------------- | |
# Kai Zhang (github: https://github.com/cszn) | |
# 03/Mar/2019 | |
# -------------------------------------------- | |
''' | |
# -------------------------------------------- | |
# SVD Orthogonal Regularization | |
# -------------------------------------------- | |
def regularizer_orth(m, s_min=0.5, s_max=1.5, eps=1e-4):
    """Nudge the singular values of a convolution kernel towards [s_min, s_max].

    Meant to be passed to ``torch.nn.Module.apply``, which calls it on every
    sub-module of a network::

        net.apply(regularizer_orth)

    A module is processed when its class name contains ``'Conv'`` and it owns
    a ``weight`` tensor with at least three dimensions, read as
    ``(c_out, c_in, *kernel)``.  The weight is flattened to a
    ``(prod(kernel) * c_in, c_out)`` matrix and decomposed with SVD.  Singular
    values above ``s_max`` are decreased by ``eps``, values below ``s_min``
    are increased by ``eps``, and the rebuilt matrix is written back into the
    module.  Every other module is left untouched.

    Args:
        m: Module visited by ``apply``.
        s_min: Singular values below this bound are increased by ``eps``.
        s_max: Singular values above this bound are decreased by ``eps``.
        eps: Size of the correction applied per call.

    Returns:
        None.  The module's weight is updated in place.
    """
    if m.__class__.__name__.find('Conv') == -1:
        return

    weight = getattr(m, 'weight', None)
    # Containers or wrappers that merely have 'Conv' in their class name carry
    # no kernel of their own; skip them instead of failing inside net.apply().
    if not torch.is_tensor(weight) or weight.dim() < 3:
        return

    w = weight.data
    n_dim = w.dim()
    c_out, c_in = w.size(0), w.size(1)
    kernel = tuple(w.size()[2:])

    # (c_out, c_in, *kernel) -> (*kernel, c_in, c_out) -> 2-D matrix
    to_matrix = tuple(range(2, n_dim)) + (1, 0)
    mat = w.permute(*to_matrix).contiguous().view(-1, c_out)

    u, s, v = torch.svd(mat)
    s = torch.where(s > s_max, s - eps, s)
    s = torch.where(s < s_min, s + eps, s)
    # u * s scales the columns of u, i.e. u @ diag(s) without building diag(s).
    mat = torch.mm(u * s, v.t())

    # (*kernel, c_in, c_out) -> (c_out, c_in, *kernel)
    from_matrix = (n_dim - 1, n_dim - 2) + tuple(range(n_dim - 2))
    restored = mat.view(*(kernel + (c_in, c_out))).permute(*from_matrix)

    # Copy into the existing storage: assigning the permuted view directly
    # would leave the parameter with a non-contiguous memory layout.
    w.copy_(restored)
# --------------------------------------------
# SVD Orthogonal Regularization
# (thresholds relative to the mean singular value)
# --------------------------------------------
def regularizer_orth2(m, ratio_min=0.5, ratio_max=1.5, eps=1e-4):
    """Nudge a convolution kernel's singular values towards their own mean.

    Meant to be passed to ``torch.nn.Module.apply``, which calls it on every
    sub-module of a network::

        net.apply(regularizer_orth2)

    A module is processed when its class name contains ``'Conv'`` and it owns
    a ``weight`` tensor with at least three dimensions, read as
    ``(c_out, c_in, *kernel)``.  The weight is flattened to a
    ``(prod(kernel) * c_in, c_out)`` matrix and decomposed with SVD.  With
    ``s_mean`` the mean singular value, values above ``ratio_max * s_mean``
    are decreased by ``eps`` and values below ``ratio_min * s_mean`` are
    increased by ``eps``; the rebuilt matrix is written back into the module.
    Every other module is left untouched.

    Args:
        m: Module visited by ``apply``.
        ratio_min: Lower bound, as a multiple of the mean singular value.
        ratio_max: Upper bound, as a multiple of the mean singular value.
        eps: Size of the correction applied per call.

    Returns:
        None.  The module's weight is updated in place.
    """
    if m.__class__.__name__.find('Conv') == -1:
        return

    weight = getattr(m, 'weight', None)
    # Containers or wrappers that merely have 'Conv' in their class name carry
    # no kernel of their own; skip them instead of failing inside net.apply().
    if not torch.is_tensor(weight) or weight.dim() < 3:
        return

    w = weight.data
    n_dim = w.dim()
    c_out, c_in = w.size(0), w.size(1)
    kernel = tuple(w.size()[2:])

    # (c_out, c_in, *kernel) -> (*kernel, c_in, c_out) -> 2-D matrix
    to_matrix = tuple(range(2, n_dim)) + (1, 0)
    mat = w.permute(*to_matrix).contiguous().view(-1, c_out)

    u, s, v = torch.svd(mat)
    # The mean is taken once, before any singular value is adjusted.
    s_mean = s.mean()
    s = torch.where(s > ratio_max * s_mean, s - eps, s)
    s = torch.where(s < ratio_min * s_mean, s + eps, s)
    # u * s scales the columns of u, i.e. u @ diag(s) without building diag(s).
    mat = torch.mm(u * s, v.t())

    # (*kernel, c_in, c_out) -> (c_out, c_in, *kernel)
    from_matrix = (n_dim - 1, n_dim - 2) + tuple(range(n_dim - 2))
    restored = mat.view(*(kernel + (c_in, c_out))).permute(*from_matrix)

    # Copy into the existing storage: assigning the permuted view directly
    # would leave the parameter with a non-contiguous memory layout.
    w.copy_(restored)
def regularizer_clip(m, eps=1e-4, c_min=-1.5, c_max=1.5):
    """Pull out-of-range weights and biases back towards [c_min, c_max].

    Meant to be passed to ``torch.nn.Module.apply``::

        net.apply(regularizer_clip)

    For every module whose class name contains ``'Conv'`` or ``'Linear'``,
    each entry of ``weight`` and ``bias`` that is greater than ``c_max`` is
    decreased by ``eps`` and each entry smaller than ``c_min`` is increased
    by ``eps``.  This is a soft correction of size ``eps`` per call, not a
    hard clamp.  Modules without a ``weight`` or ``bias`` tensor (for example
    layers built with ``bias=False`` or containers that only share the name)
    are skipped for that attribute.

    Args:
        m: Module visited by ``apply``.
        eps: Size of the correction applied per call.
        c_min: Lower bound of the accepted value range.
        c_max: Upper bound of the accepted value range.

    Returns:
        None.  The module's parameters are updated in place.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') == -1 and classname.find('Linear') == -1:
        return

    for attr in ('weight', 'bias'):
        param = getattr(m, attr, None)
        # Missing attribute or bias=None: nothing to adjust.
        if not torch.is_tensor(param):
            continue
        values = param.data
        values[values > c_max] -= eps
        values[values < c_min] += eps

    # NOTE: the original file carried an unfinished, commented-out sketch for
    # BatchNorm2d layers, kept here for reference:
    # elif classname.find('BatchNorm2d') != -1:
    #
    #     rv = m.running_var.data.clone()
    #     rm = m.running_mean.data.clone()
    #
    #     if m.affine:
    #         m.weight.data
    #         m.bias.data