Spaces:
Runtime error
Runtime error
import torch | |
from swapae.models.networks import BaseNetwork | |
from swapae.models.networks.pyramidnet import PyramidNet | |
class PyramidNetClassifier(BaseNetwork):
    """Classifier that wraps a bottleneck PyramidNet with per-channel input normalization.

    Only dataset modes containing ``"cifar"`` are accepted. Inputs are
    standardized with fixed per-channel mean/std buffers before being passed
    to the underlying PyramidNet.
    """

    # Per-channel mean/std in 0-255 pixel units. These values match the
    # commonly quoted CIFAR-10 RGB statistics (channel order presumably RGB).
    _PIXEL_MEAN_255 = (125.3, 123.0, 113.9)
    _PIXEL_STD_255 = (63.0, 62.1, 66.7)

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register PyramidNet hyperparameters on ``parser``.

        Args:
            parser: argparse-style parser; ``add_argument`` is called on it.
            is_train: not used by this implementation; kept for the
                options-hook signature.

        Returns:
            The same ``parser`` with ``--pyramid_alpha`` (default 240) and
            ``--pyramid_depth`` (default 200) registered.
        """
        parser.add_argument("--pyramid_alpha", type=int, default=240)
        parser.add_argument("--pyramid_depth", type=int, default=200)
        return parser

    def __init__(self, opt):
        """Build the PyramidNet backbone and the normalization buffers.

        Args:
            opt: option namespace; reads ``dataset_mode``, ``pyramid_depth``,
                ``pyramid_alpha`` and ``num_classes``.

        Raises:
            AssertionError: if ``opt.dataset_mode`` does not contain "cifar".
        """
        super().__init__(opt)
        assert "cifar" in opt.dataset_mode, (
            "PyramidNetClassifier only supports CIFAR dataset modes, "
            f"got dataset_mode={opt.dataset_mode!r}"
        )
        self.net = PyramidNet(
            opt.dataset_mode, depth=opt.pyramid_depth, alpha=opt.pyramid_alpha,
            num_classes=opt.num_classes, bottleneck=True)

        # Re-express the 0-255 statistics in the space of pixels mapped via
        # p / 127.5 - 1:  mean' = m / 127.5 - 1,  std' = s / 127.5.
        # NOTE(review): this assumes forward() receives images scaled to
        # [-1, 1]; confirm against the data pipeline.
        mean = torch.tensor(
            [m / 127.5 - 1.0 for m in self._PIXEL_MEAN_255], dtype=torch.float)
        std = torch.tensor(
            [s / 127.5 for s in self._PIXEL_STD_255], dtype=torch.float)
        # Stored as (1, C, 1, 1) so they broadcast over a 4-D batch with
        # channels on dim 1. As buffers, they move with .to()/.cuda() and are
        # included in state_dict under the names "mean" and "std".
        self.register_buffer("mean", mean[None, :, None, None])
        self.register_buffer("std", std[None, :, None, None])

    def forward(self, x):
        """Standardize ``x`` per channel and return the PyramidNet output."""
        normalized = (x - self.mean) / self.std
        return self.net(normalized)