Spaces:
Running
on
T4
Running
on
T4
import math | |
import torch.nn as nn | |
class Upsample(nn.Module):
    """Upsample module built from stacked conv + PixelShuffle stages.

    For ``scale = 2^n`` the module stacks ``n`` stages of
    ``Conv2d(num_feat -> 4 * num_feat, 3x3) + PixelShuffle(2)``; for
    ``scale = 3`` it uses a single ``Conv2d(num_feat -> 9 * num_feat, 3x3) +
    PixelShuffle(3)`` stage. Every stage returns to ``num_feat`` channels, so
    the output has ``num_feat`` channels and ``scale`` times the spatial size.
    ``scale = 1`` (2^0) yields an empty ``nn.Sequential``, i.e. an identity.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is not a positive power of two or 3.
    """

    def __init__(self, scale, num_feat):
        super().__init__()
        m = []
        # ``scale >= 1`` is needed because 0 also passes the bit test
        # (0 & -1 == 0) and would otherwise crash inside the logarithm with an
        # unrelated "math domain error".
        if scale >= 1 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # math.log2 is more accurate than math.log(x, 2); one x2 stage
            # (conv + PixelShuffle(2)) per factor of two.
            for _ in range(int(math.log2(scale))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(
                f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
            )
        # Keep the layers in one ``self.up`` Sequential so state_dict keys
        # (``up.0.weight``, ...) stay compatible with existing checkpoints.
        self.up = nn.Sequential(*m)

    def forward(self, x):
        """Upsample ``x``.

        Assumes ``x`` is an ``(N, num_feat, H, W)`` feature map; returns
        ``(N, num_feat, H * scale, W * scale)``.
        """
        return self.up(x)
class UpsampleOneStep(nn.Module):
    """Single-stage upsampler: exactly one conv followed by one PixelShuffle.

    The difference with ``Upsample`` is that this module always has only
    1 conv + 1 PixelShuffle, and the conv projects directly to
    ``scale**2 * num_out_ch`` channels so the PixelShuffle output already has
    ``num_out_ch`` channels. Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Any positive integer is accepted, since a
            single ``PixelShuffle(scale)`` handles arbitrary integer factors.
        num_feat (int): Channel number of the input features.
        num_out_ch (int): Channel number of the output.

    Raises:
        ValueError: If ``scale`` is smaller than 1.
    """

    def __init__(self, scale, num_feat, num_out_ch):
        super().__init__()
        # A non-positive scale would build a zero-channel conv (scale=0) or a
        # PixelShuffle with a negative factor; fail early instead.
        if scale < 1:
            raise ValueError(
                f"scale {scale} is not supported. "
                "Supported scales: positive integers."
            )
        self.num_feat = num_feat
        # Keep the layers in one ``self.up`` Sequential so state_dict keys
        # (``up.0.weight``, ``up.0.bias``) stay compatible with checkpoints.
        self.up = nn.Sequential(
            nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1),
            nn.PixelShuffle(scale),
        )

    def forward(self, x):
        """Upsample ``x``.

        Assumes ``x`` is an ``(N, num_feat, H, W)`` feature map; returns
        ``(N, num_out_ch, H * scale, W * scale)``.
        """
        return self.up(x)