import torch
import torch.nn as nn
import numpy as np


class AbstractPermuter(nn.Module):
    """Base class: permutes a flattened H*W token sequence; reverse=True undoes it."""
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, reverse=False):
        raise NotImplementedError


class Identity(AbstractPermuter):
    def __init__(self):
        super().__init__()

    def forward(self, x, reverse=False):
        return x


class Subsample(AbstractPermuter):
    """Orders the flattened grid by recursively splitting it into 2x2 blocks
    (requires H == W, a power of two)."""
    def __init__(self, H, W):
        super().__init__()
        C = 1
        indices = np.arange(H*W).reshape(C, H, W)
        while min(H, W) > 1:
            indices = indices.reshape(C, H//2, 2, W//2, 2)
            indices = indices.transpose(0, 2, 4, 1, 3)
            indices = indices.reshape(C*4, H//2, W//2)
            H = H//2
            W = W//2
            C = C*4
        assert H == W == 1
        idx = torch.tensor(indices.ravel())
        # plain tensors suffice for buffers; no need to wrap them in nn.Parameter
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
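
# Hand-traced example (derived from the reshape/transpose steps above): for a
# 4x4 grid the forward order is 0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15,
# i.e. the stride-2 subsampled grid comes first, followed by the other offsets.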


def mortonify(i, j):
    """(i,j) index to linear morton code"""
    i = np.uint64(i)
    j = np.uint64(j)
    z = np.uint64(0)
    for pos in range(32):
        z = (z |
             ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
             ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1)))
    return z
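
# Spot check of the interleaving (values follow directly from the bit shifts
# above): mortonify(0, 1) == 1, mortonify(1, 0) == 2, mortonify(1, 1) == 3,
# mortonify(2, 3) == 13. Bits of j land on the even positions and bits of i on
# the odd positions, i.e. the usual Z-order (Morton) encoding.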


class ZCurve(AbstractPermuter):
    """Z-order (Morton) traversal of the H x W grid."""
    def __init__(self, H, W):
        super().__init__()
        reverseidx = [np.int64(mortonify(i, j)) for i in range(H) for j in range(W)]
        idx = np.argsort(reverseidx)
        idx = torch.tensor(idx)
        reverseidx = torch.tensor(reverseidx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', reverseidx)

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
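
# Hand-traced example: for a 4x4 grid the forward order visits raster positions
# 0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15, sweeping each 2x2 block
# completely before moving on to the next one.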


class SpiralOut(AbstractPermuter):
    """Spiral that starts just left of the grid centre and winds outwards.
    Assumes a square grid (H == W) with even side length."""
    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        indices = np.arange(size*size).reshape(size, size)

        i0 = size//2
        j0 = size//2-1

        i = i0
        j = j0

        idx = [indices[i0, j0]]
        step_mult = 0
        for c in range(1, size//2+1):
            step_mult += 1
            # decrease the row index (move up)
            for k in range(step_mult):
                i = i - 1
                j = j
                idx.append(indices[i, j])

            # increase the column index (move right)
            for k in range(step_mult):
                i = i
                j = j + 1
                idx.append(indices[i, j])

            step_mult += 1
            if c < size//2:
                # increase the row index (move down)
                for k in range(step_mult):
                    i = i + 1
                    j = j
                    idx.append(indices[i, j])

                # decrease the column index (move left)
                for k in range(step_mult):
                    i = i
                    j = j - 1
                    idx.append(indices[i, j])
            else:
                # outermost ring reached: finish the remaining column downwards
                for k in range(step_mult-1):
                    i = i + 1
                    idx.append(indices[i, j])

        assert len(idx) == size*size
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
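
# Hand-traced example: for a 4x4 grid the forward order is
# 9, 5, 6, 10, 14, 13, 12, 8, 4, 0, 1, 2, 3, 7, 11, 15, i.e. a clockwise spiral
# that starts just left of the centre and ends in the bottom-right corner.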


class SpiralIn(AbstractPermuter):
    """Same path as SpiralOut, but reversed: starts on the outermost ring and
    winds inwards towards the centre. Assumes a square grid (H == W) with even
    side length."""
    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        indices = np.arange(size*size).reshape(size, size)

        i0 = size//2
        j0 = size//2-1

        i = i0
        j = j0

        idx = [indices[i0, j0]]
        step_mult = 0
        for c in range(1, size//2+1):
            step_mult += 1
            # decrease the row index (move up)
            for k in range(step_mult):
                i = i - 1
                j = j
                idx.append(indices[i, j])

            # increase the column index (move right)
            for k in range(step_mult):
                i = i
                j = j + 1
                idx.append(indices[i, j])

            step_mult += 1
            if c < size//2:
                # increase the row index (move down)
                for k in range(step_mult):
                    i = i + 1
                    j = j
                    idx.append(indices[i, j])

                # decrease the column index (move left)
                for k in range(step_mult):
                    i = i
                    j = j - 1
                    idx.append(indices[i, j])
            else:
                # outermost ring reached: finish the remaining column downwards
                for k in range(step_mult-1):
                    i = i + 1
                    idx.append(indices[i, j])

        assert len(idx) == size*size
        # reverse the outward spiral to obtain the inward ordering
        idx = idx[::-1]
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


class Random(AbstractPermuter):
    """Fixed random permutation (seeded with 1, so it is identical across runs)."""
    def __init__(self, H, W):
        super().__init__()
        indices = np.random.RandomState(1).permutation(H*W)
        idx = torch.tensor(indices.ravel())
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]


class AlternateParsing(AbstractPermuter):
    """Boustrophedon (snake) order: even rows left-to-right, odd rows right-to-left."""
    def __init__(self, H, W):
        super().__init__()
        indices = np.arange(W*H).reshape(H, W)
        # reverse every second row in place
        for i in range(1, H, 2):
            indices[i, :] = indices[i, ::-1]
        idx = indices.flatten()
        assert len(idx) == H*W
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
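
# Hand-traced example: for a 4x4 grid the forward order is
# 0, 1, 2, 3, 7, 6, 5, 4, 8, 9, 10, 11, 15, 14, 13, 12.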


if __name__ == "__main__":
    p0 = AlternateParsing(16, 16)
    print(p0.forward_shuffle_idx)
    print(p0.backward_shuffle_idx)

    x = torch.randint(0, 768, size=(11, 256))
    y = p0(x)
    xre = p0(y, reverse=True)
    assert torch.equal(x, xre)

    p1 = SpiralOut(2, 2)
    print(p1.forward_shuffle_idx)
    print(p1.backward_shuffle_idx)
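
    # Extra sanity check (illustrative addition, not part of the original demo):
    # every permuter on a 16x16 grid should round-trip a batch of length-256 sequences.
    seq = torch.randint(0, 768, size=(11, 256))
    for perm in [Identity(), Subsample(16, 16), ZCurve(16, 16),
                 SpiralOut(16, 16), SpiralIn(16, 16), Random(16, 16),
                 AlternateParsing(16, 16)]:
        restored = perm(perm(seq), reverse=True)
        assert torch.equal(seq, restored), type(perm).__name__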