sunshineatnoon
Add application file
1b2a9b1
raw
history blame
1.07 kB
import torch
from swapae.models.networks import BaseNetwork
from swapae.models.networks.pyramidnet import PyramidNet
class PyramidNetClassifier(BaseNetwork):
@staticmethod
def modify_commandline_options(parser, is_train):
parser.add_argument("--pyramid_alpha", type=int, default=240)
parser.add_argument("--pyramid_depth", type=int, default=200)
return parser
def __init__(self, opt):
super().__init__(opt)
assert "cifar" in opt.dataset_mode
self.net = PyramidNet(
opt.dataset_mode, depth=opt.pyramid_depth, alpha=opt.pyramid_alpha,
num_classes=opt.num_classes, bottleneck=True)
mean = torch.tensor([x / 127.5 - 1.0 for x in [125.3, 123.0, 113.9]], dtype=torch.float)
std = torch.tensor([x / 127.5 for x in [63.0, 62.1, 66.7]], dtype=torch.float)
self.register_buffer("mean", mean[None, :, None, None])
self.register_buffer("std", std[None, :, None, None])
def forward(self, x):
x = (x - self.mean) / self.std
return self.net(x)