LambdaSuperRes / KAIR /models /network_feature.py
cooperll
LambdaSuperRes initial commit
2514fb4
raw
history blame
1.59 kB
import torch
import torch.nn as nn
import torchvision
"""
# --------------------------------------------
# VGG Feature Extractor
# --------------------------------------------
"""
# --------------------------------------------
# VGG features
# Assume input range is [0, 1]
# --------------------------------------------
class VGGFeatureExtractor(nn.Module):
    """Truncated, frozen VGG-19 used as a fixed feature extractor.

    The first ``feature_layer + 1`` modules of torchvision's pretrained
    VGG-19 ``features`` stack are kept, and every parameter of that stack has
    ``requires_grad`` set to ``False``.

    Args:
        feature_layer: Index (inclusive) of the last module of
            ``model.features`` that is kept.
        use_bn: If true, build from ``torchvision.models.vgg19_bn`` instead
            of ``torchvision.models.vgg19``.
        use_input_norm: If true, ``forward`` normalizes its input with the
            per-channel ``mean``/``std`` buffers before running the network.
            The constants assume the input is in the range [0, 1].
        device: Device on which the ``mean``/``std`` buffers are created.
            Only these buffers are placed on ``device``; the VGG layers are
            not moved here, so callers should still move the whole module
            (e.g. ``.to(device)``) before use.
    """

    def __init__(self,
                 feature_layer=34,
                 use_bn=False,
                 use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()

        # NOTE(review): ``pretrained=True`` is the legacy torchvision spelling
        # (newer releases prefer ``weights=...``); kept so that older
        # torchvision versions keep working.
        build_vgg = torchvision.models.vgg19_bn if use_bn else torchvision.models.vgg19
        model = build_vgg(pretrained=True)

        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            # Per-channel statistics for inputs in [0, 1], shaped (1, 3, 1, 1)
            # so they broadcast over a channel axis of size 3 at dim 1.
            # For inputs in [-1, 1], x01 = (x + 1) / 2, hence the equivalent
            # constants would be mean = 2 * m - 1 and std = 2 * s.
            mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # Buffers (not parameters): they follow the module across
            # devices but are never optimized.
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)

        # Keep modules 0..feature_layer (inclusive) of the VGG feature stack.
        kept_layers = list(model.features.children())[:feature_layer + 1]
        self.features = nn.Sequential(*kept_layers)

        # The extractor is fixed: its weights never need gradients.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the truncated-VGG features of ``x``.

        Args:
            x: Input tensor. When ``use_input_norm`` is set it must broadcast
                against the (1, 3, 1, 1) ``mean``/``std`` buffers, i.e.
                presumably an (N, 3, H, W) batch with values in [0, 1].

        Returns:
            The output of ``self.features`` applied to the (optionally
            normalized) input.
        """
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        return self.features(x)