Spaces:
Sleeping
Sleeping
| from .module import Module | |
| from .. import functional as F | |
| from torch import Tensor | |
| __all__ = ['ChannelShuffle'] | |
class ChannelShuffle(Module):
    r"""Divides and rearranges the channels in a tensor.

    This operation divides the channels in a tensor of shape
    :math:`(*, C , H, W)` into g groups, views them as
    :math:`(*, \frac{C}{g}, g, H, W)` and shuffles them, while keeping the
    original tensor shape in the output.

    The module holds no parameters; it only stores ``groups`` and forwards
    to :func:`~torch.nn.functional.channel_shuffle`.

    Args:
        groups (int): number of groups to divide channels in.

    Examples::

        >>> channel_shuffle = nn.ChannelShuffle(2)
        >>> # Deterministic input: channel k holds the values 4k+1 .. 4k+4.
        >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
        >>> input[0, :, 0, 0]  # first element of every channel
        tensor([ 1.,  5.,  9., 13.])
        >>> output = channel_shuffle(input)
        >>> output.shape == input.shape
        True
        >>> output[0, :, 0, 0]  # channels are now ordered 0, 2, 1, 3
        tensor([ 1.,  9.,  5., 13.])
    """

    # Names of attributes treated as constants by the module machinery
    # (presumably consumed by TorchScript -- confirm against Module docs).
    __constants__ = ['groups']
    groups: int

    def __init__(self, groups: int) -> None:
        """Store the number of channel groups used by :meth:`forward`.

        Args:
            groups: number of groups to divide channels in. The value is
                stored as-is; it is not validated here.
        """
        super().__init__()
        self.groups = groups

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with its channels shuffled across ``self.groups``.

        Args:
            input: tensor whose channels are to be shuffled.

        Returns:
            The result of ``F.channel_shuffle(input, self.groups)``.
        """
        return F.channel_shuffle(input, self.groups)

    def extra_repr(self) -> str:
        """Return the ``groups=<n>`` fragment describing this module."""
        return f'groups={self.groups}'