'''Various miscellaneous modules.
'''

import torch
class View(torch.nn.Module):
    """Basic reshape module.

    Wraps ``Tensor.view`` so that a fixed reshape can be used as a layer.
    """

    def __init__(self, *shape):
        """
        Args:
            *shape: Target shape, forwarded unchanged to ``Tensor.view``.
        """
        super().__init__()
        self.shape = shape

    def forward(self, input):
        """Reshapes tensor.

        Args:
            input: Input tensor.

        Returns:
            torch.Tensor: ``input`` viewed with the shape given at construction.
        """
        return input.view(*self.shape)
class Unfold(torch.nn.Module):
    """Module for unfolding tensor.

    Performs strided crops on 2d (image) tensors. The stride is half the crop
    size (``fold_size // 2``).
    """

    def __init__(self, img_size, fold_size):
        """
        Args:
            img_size: Input size (side length of the image).
            fold_size: Crop size (side length of each square crop).
        """
        super().__init__()

        fold_stride = fold_size // 2
        self.fold_size = fold_size
        self.fold_stride = fold_stride
        # Nominal number of crops per side. This closed form agrees with what
        # ``torch.nn.Unfold`` produces when ``fold_size`` is even and
        # ``img_size`` is a multiple of it. ``forward`` does not rely on it.
        self.n_locs = 2 * (img_size // fold_size) - 1
        self.unfold = torch.nn.Unfold((self.fold_size, self.fold_size),
                                      stride=(self.fold_stride, self.fold_stride))

    def forward(self, x):
        """Unfolds tensor.

        Args:
            x: Input tensor; its first dimension is treated as the batch.

        Returns:
            torch.Tensor: Unfolded tensor with one row per crop. The first
                dimension is ``batch * n_crops``, the last two are
                ``fold_size`` x ``fold_size``. Crops of the same image are
                adjacent.
        """
        N = x.size(0)
        # ``torch.nn.Unfold`` returns (N, C * fold_size * fold_size, L).
        columns = self.unfold(x)
        # Use the number of crops actually produced rather than
        # ``n_locs * n_locs``, so image sizes outside the closed form above
        # neither crash nor get silently scrambled by the ``-1`` dimension.
        n_crops = columns.size(-1)
        crops = columns.reshape(N, -1, self.fold_size, self.fold_size, n_crops)
        # Move the crop index next to the batch index before merging them.
        crops = crops.permute(0, 4, 1, 2, 3)
        return crops.reshape(N * n_crops, -1, self.fold_size, self.fold_size)
class Fold(torch.nn.Module):
    """Module (re)folding tensor.

    Undoes the strided crops above. Works only on 1x1.
    """

    def __init__(self, img_size, fold_size):
        """
        Args:
            img_size: Images size.
            fold_size: Crop size.
        """
        super().__init__()
        # Number of crop locations per side; must match the value used when
        # the crops were produced.
        self.n_locs = 2 * (img_size // fold_size) - 1

    def forward(self, x):
        """(Re)folds tensor.

        Args:
            x: Input tensor with four dimensions; the first is
                ``batch * n_locs * n_locs`` and the remaining three are
                read as ``(dim_c, dim_x, dim_y)``.

        Returns:
            torch.Tensor: Refolded tensor of shape
                ``(batch, dim_c * dim_x * dim_y, n_locs, n_locs)``.
        """
        dim_c, dim_x, dim_y = x.size()[1:]
        n_crops = self.n_locs * self.n_locs
        # Split the leading dimension into (batch, crop), move the crop index
        # last, then lay the crops out on an n_locs x n_locs grid.
        x = x.reshape(-1, n_crops, dim_c, dim_x * dim_y)\
            .permute(0, 2, 3, 1)\
            .reshape(-1, dim_c * dim_x * dim_y, self.n_locs, self.n_locs).contiguous()
        return x
class Permute(torch.nn.Module):
    """Module for permuting axes.

    Wraps ``Tensor.permute`` so that a fixed axis reordering can be used as a
    layer.
    """

    def __init__(self, *perm):
        """
        Args:
            *perm: Axis order, forwarded unchanged to ``Tensor.permute``.
        """
        super().__init__()
        self.perm = perm

    def forward(self, input):
        """Permutes axes of tensor.

        Args:
            input: Input tensor.

        Returns:
            torch.Tensor: ``input`` with its axes reordered according to the
                permutation given at construction.
        """
        return input.permute(*self.perm)