Spaces:
Runtime error
Runtime error
File size: 2,215 Bytes
1b2a9b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import torch
class BaseNetwork(torch.nn.Module):
    """Base ``torch.nn.Module`` with option storage and parameter/noise helpers.

    The default ``forward`` returns its input unchanged; subclasses presumably
    override it.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for adding command-line options; the base returns ``parser`` unchanged."""
        return parser

    def __init__(self, opt):
        super().__init__()
        # Options object kept for subclasses; its type is not constrained here.
        self.opt = opt

    @staticmethod
    def _num_params(module):
        """Return the total element count of ``module.parameters()``."""
        return sum(p.numel() for p in module.parameters())

    def print_architecture(self, verbose=False):
        """Print a parameter-count summary of this network to stdout.

        Args:
            verbose: If true, also print the parameter count (in millions) of
                each direct child module and of each child's direct children.
        """
        # Kept distinct from the loop variables below; reusing one name made the
        # summary line report the last child's name instead of the network's.
        network_name = type(self).__name__
        parts = ['-------------------%s---------------------\n' % network_name]
        if verbose:
            for child_name, child in self.named_children():
                parts.append("%s: %3.3fM\n" % (child_name, self._num_params(child) / 1e6))
                for grandchild_name, grandchild in child.named_children():
                    parts.append(
                        "\t%s: %3.3fM\n" % (grandchild_name, self._num_params(grandchild) / 1e6)
                    )
        # Count over self.parameters() so parameters registered directly on this
        # module are included and parameters shared between children count once.
        total_num_params = self._num_params(self)
        parts.append(
            '[Network %s] Total number of parameters : %.3f M\n' % (network_name, total_num_params / 1e6)
        )
        parts.append('-----------------------------------------------\n')
        print(''.join(parts))

    def set_requires_grad(self, requires_grad):
        """Set ``requires_grad`` on every parameter of this network."""
        for param in self.parameters():
            param.requires_grad = requires_grad

    def collect_parameters(self, name):
        """Return the parameters of every submodule whose class name equals ``name``.

        Matching is by exact class name (``type(m).__name__``), not isinstance.
        ``self`` itself is included in the search, since ``modules()`` yields it.
        """
        params = []
        for m in self.modules():
            if type(m).__name__ == name:
                params += list(m.parameters())
        return params

    def fix_and_gather_noise_parameters(self):
        """Replace each ``NoiseInjection`` module's noise with a learnable fixed tensor.

        For every submodule whose class name is ``NoiseInjection``, a new
        ``torch.nn.Parameter`` of shape
        ``(image_size[0], 1, image_size[2], image_size[3])`` is drawn from
        ``torch.randn`` and assigned to ``fixed_noise``.

        Returns:
            List of the newly created noise parameters.

        Raises:
            AssertionError: If a ``NoiseInjection`` module has ``image_size`` None
                (i.e. no forward pass has recorded it yet).
        """
        params = []
        # Place the noise on the same device as existing parameters; fall back
        # to CPU for a network that owns no parameters (next() would otherwise
        # raise StopIteration).
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        for m in self.modules():
            if type(m).__name__ == "NoiseInjection":
                assert m.image_size is not None, "One forward call should be made to determine size of noise parameters"
                # image_size is indexed at 0, 2 and 3 — presumably (N, C, H, W); TODO confirm.
                m.fixed_noise = torch.nn.Parameter(
                    torch.randn(m.image_size[0], 1, m.image_size[2], m.image_size[3], device=device)
                )
                params.append(m.fixed_noise)
        return params

    def remove_noise_parameters(self, name=None):
        """Reset ``fixed_noise`` to None on every ``NoiseInjection`` submodule.

        Args:
            name: Unused; kept (now optional) for backward compatibility.
        """
        for m in self.modules():
            if type(m).__name__ == "NoiseInjection":
                m.fixed_noise = None

    def forward(self, x):
        """Identity forward; return ``x`` unchanged."""
        return x
|