Spaces:
Runtime error
Runtime error
File size: 2,943 Bytes
01df1d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
'''Various miscellaneous modules
'''
import torch
class View(torch.nn.Module):
    """Module wrapper around :meth:`torch.Tensor.view`.

    The target shape is fixed when the module is built and applied to every
    tensor passed through ``forward``.
    """

    def __init__(self, *shape):
        """
        Args:
            *shape: Target shape, forwarded unchanged to ``Tensor.view``.
        """
        super().__init__()
        self.shape = shape

    def forward(self, input):
        """Return ``input`` viewed with the stored shape.

        Args:
            input: Input tensor; must be view-compatible with ``self.shape``
                (``Tensor.view`` raises otherwise).

        Returns:
            torch.Tensor: ``input`` reshaped to ``self.shape`` (a view that
            shares storage with ``input``).
        """
        target = self.shape
        return input.view(*target)
class Unfold(torch.nn.Module):
    """Module for unfolding tensor.

    Performs strided square crops on 2d (image) tensors of shape
    ``(N, C, H, W)``. The stride is half the crop size (clamped to at least
    1), so neighbouring crops overlap.
    """

    def __init__(self, img_size, fold_size):
        """
        Args:
            img_size: Input (square) image size.
            fold_size: Crop size.

        Raises:
            ValueError: If ``fold_size < 1`` or ``img_size < fold_size``.
        """
        super().__init__()
        if fold_size < 1 or img_size < fold_size:
            raise ValueError(
                f"need 1 <= fold_size <= img_size, got "
                f"fold_size={fold_size}, img_size={img_size}")
        # fold_size // 2 is 0 for fold_size == 1; torch.nn.Unfold rejects a
        # zero stride, so clamp it to 1.
        fold_stride = max(fold_size // 2, 1)
        self.fold_size = fold_size
        self.fold_stride = fold_stride
        # Number of crop positions per spatial axis (standard sliding-window
        # count). This equals the former 2 * (img_size // fold_size) - 1
        # whenever fold_size is even and divides img_size, but stays correct
        # for odd crop sizes and non-divisible image sizes.
        self.n_locs = (img_size - fold_size) // fold_stride + 1
        self.unfold = torch.nn.Unfold((fold_size, fold_size),
                                      stride=(fold_stride, fold_stride))

    def forward(self, x):
        """Unfolds tensor.

        Args:
            x: Input tensor of shape ``(N, C, H, W)``.

        Returns:
            torch.Tensor: Crops of shape
            ``(N * L, C, fold_size, fold_size)``, where ``L`` is the number
            of crop positions (``n_locs ** 2`` for an ``img_size`` square
            input). Crops of one sample are contiguous along dim 0, ordered
            row-major over crop positions.
        """
        batch = x.size(0)
        size = self.fold_size
        patches = self.unfold(x)  # (N, C * size * size, L)
        # Use the actual number of positions rather than n_locs ** 2 so a
        # mismatched input size cannot silently scramble the reshape.
        n_patches = patches.size(-1)
        return (patches.reshape(batch, -1, size, size, n_patches)
                .permute(0, 4, 1, 2, 3)
                .reshape(batch * n_patches, -1, size, size))
class Fold(torch.nn.Module):
    """Module (re)folding tensor.

    Rearranges per-crop tensors (one crop per row of dim 0, grouped per
    sample as produced by ``Unfold``) back onto the ``n_locs x n_locs`` grid
    of crop positions. Intended for crops reduced to 1x1 spatial size; for
    larger crops the spatial dims are flattened into the channel axis.
    """

    def __init__(self, img_size, fold_size):
        """
        Args:
            img_size: Images size.
            fold_size: Crop size.

        Raises:
            ValueError: If ``fold_size < 1`` or ``img_size < fold_size``.
        """
        super().__init__()
        if fold_size < 1 or img_size < fold_size:
            raise ValueError(
                f"need 1 <= fold_size <= img_size, got "
                f"fold_size={fold_size}, img_size={img_size}")
        # Same stride as the unfolding step (half the crop, at least 1), so
        # the grid size matches what torch.nn.Unfold actually produces.
        fold_stride = max(fold_size // 2, 1)
        self.n_locs = (img_size - fold_size) // fold_stride + 1

    def forward(self, x):
        """(Re)folds tensor.

        Args:
            x: Input tensor of shape ``(N * n_locs**2, C, X, Y)``.

        Returns:
            torch.Tensor: Contiguous tensor of shape
            ``(N, C * X * Y, n_locs, n_locs)``.
        """
        dim_c, dim_x, dim_y = x.size()[1:]
        n_locs = self.n_locs
        # Group crops per sample, move the crop index last, then lay it out
        # as a square grid. contiguous() is needed because the final reshape
        # can return a non-contiguous view of the permuted tensor.
        return (x.reshape(-1, n_locs * n_locs, dim_c, dim_x * dim_y)
                .permute(0, 2, 3, 1)
                .reshape(-1, dim_c * dim_x * dim_y, n_locs, n_locs)
                .contiguous())
class Permute(torch.nn.Module):
    """Module wrapper around :meth:`torch.Tensor.permute`.

    The axis order is fixed when the module is built and applied to every
    tensor passed through ``forward``.
    """

    def __init__(self, *perm):
        """
        Args:
            *perm: New order of the axes, forwarded to ``Tensor.permute``.
        """
        super().__init__()
        self.perm = perm

    def forward(self, input):
        """Apply the stored axis permutation.

        Args:
            input: Input tensor whose number of dimensions matches
                ``len(self.perm)``.

        Returns:
            torch.Tensor: ``input`` with its axes reordered.
        """
        axes = self.perm
        return input.permute(*axes)
|