# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...models.nn.act import build_act
from ...models.nn.norm import build_norm
from ...models.utils import get_same_padding, list_sum, resize, val2list, val2tuple

__all__ = [
    "ConvLayer",
    "UpSampleLayer",
    "ConvPixelUnshuffleDownSampleLayer",
    "PixelUnshuffleChannelAveragingDownSampleLayer",
    "ConvPixelShuffleUpSampleLayer",
    "InterpolateConvUpSampleLayer",
    "ChannelDuplicatingPixelUnshuffleUpSampleLayer",
    "LinearLayer",
    "IdentityLayer",
    "DSConv",
    "MBConv",
    "FusedMBConv",
    "GLUMBConv",
    "ResBlock",
    "LiteMLA",
    "EfficientViTBlock",
    "ResidualBlock",
    "DAGBlock",
    "OpSequential",
]

#################################################################################
#                                 Basic Layers                                  #
#################################################################################

class ConvLayer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        use_bias=False,
        dropout=0,
        norm="bn2d",
        act_func="relu",
    ):
        super().__init__()
        padding = get_same_padding(kernel_size)
        padding *= dilation

        self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            stride=(stride, stride),
            padding=padding,
            dilation=(dilation, dilation),
            groups=groups,
            bias=use_bias,
        )
        self.norm = build_norm(norm, num_features=out_channels)
        self.act = build_act(act_func)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.conv(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x
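
# Example (minimal sketch, assuming build_norm/build_act accept the default
# "bn2d"/"relu" keys):
#
#     layer = ConvLayer(3, 32, kernel_size=3, stride=2)
#     y = layer(torch.randn(1, 3, 224, 224))  # -> (1, 32, 112, 112)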

class UpSampleLayer(nn.Module):
    def __init__(
        self,
        mode="bicubic",
        size: Optional[int | tuple[int, int] | list[int]] = None,
        factor=2,
        align_corners=False,
    ):
        super().__init__()
        self.mode = mode
        self.size = val2list(size, 2) if size is not None else None
        self.factor = None if self.size is not None else factor
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if (self.size is not None and tuple(x.shape[-2:]) == tuple(self.size)) or self.factor == 1:
            return x
        if x.dtype in [torch.float16, torch.bfloat16]:
            x = x.float()
        return resize(x, self.size, self.factor, self.mode, self.align_corners)
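
# Example (minimal sketch). Half-precision inputs are upcast to float32 before
# interpolation and returned as float32, so callers needing fp16/bf16 must cast
# back themselves:
#
#     up = UpSampleLayer(mode="bicubic", factor=2)
#     y = up(torch.randn(1, 16, 32, 32))  # -> (1, 16, 64, 64)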

class ConvPixelUnshuffleDownSampleLayer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        factor: int,
    ):
        super().__init__()
        self.factor = factor
        out_ratio = factor**2
        assert out_channels % out_ratio == 0
        self.conv = ConvLayer(
            in_channels=in_channels,
            out_channels=out_channels // out_ratio,
            kernel_size=kernel_size,
            use_bias=True,
            norm=None,
            act_func=None,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = F.pixel_unshuffle(x, self.factor)
        return x
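
# Example (minimal sketch): the conv first projects to out_channels // factor**2
# channels, then pixel_unshuffle trades spatial resolution for channels:
#
#     down = ConvPixelUnshuffleDownSampleLayer(64, 128, kernel_size=3, factor=2)
#     y = down(torch.randn(1, 64, 32, 32))  # -> (1, 128, 16, 16)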

class PixelUnshuffleChannelAveragingDownSampleLayer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor = factor
        assert in_channels * factor**2 % out_channels == 0
        self.group_size = in_channels * factor**2 // out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pixel_unshuffle(x, self.factor)
        B, C, H, W = x.shape
        x = x.view(B, self.out_channels, self.group_size, H, W)
        x = x.mean(dim=2)
        return x
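
# Example (minimal sketch): pixel_unshuffle expands channels by factor**2, then
# every group_size consecutive channels are averaged down to out_channels; the
# layer is parameter-free, which suits shortcut branches:
#
#     down = PixelUnshuffleChannelAveragingDownSampleLayer(64, 128, factor=2)
#     y = down(torch.randn(1, 64, 32, 32))  # -> (1, 128, 16, 16)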

class ConvPixelShuffleUpSampleLayer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        factor: int,
    ):
        super().__init__()
        self.factor = factor
        out_ratio = factor**2
        self.conv = ConvLayer(
            in_channels=in_channels,
            out_channels=out_channels * out_ratio,
            kernel_size=kernel_size,
            use_bias=True,
            norm=None,
            act_func=None,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = F.pixel_shuffle(x, self.factor)
        return x
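
# Example (minimal sketch): the conv expands to out_channels * factor**2
# channels, then pixel_shuffle trades channels for spatial resolution:
#
#     up = ConvPixelShuffleUpSampleLayer(128, 64, kernel_size=3, factor=2)
#     y = up(torch.randn(1, 128, 16, 16))  # -> (1, 64, 32, 32)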

class InterpolateConvUpSampleLayer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        factor: int,
        mode: str = "nearest",
    ) -> None:
        super().__init__()
        self.factor = factor
        self.mode = mode
        self.conv = ConvLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            use_bias=True,
            norm=None,
            act_func=None,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, scale_factor=self.factor, mode=self.mode)
        x = self.conv(x)
        return x

class ChannelDuplicatingPixelUnshuffleUpSampleLayer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor = factor
        assert out_channels * factor**2 % in_channels == 0
        self.repeats = out_channels * factor**2 // in_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.repeat_interleave(self.repeats, dim=1)
        x = F.pixel_shuffle(x, self.factor)
        return x
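
# Example (minimal sketch): the parameter-free dual of the channel-averaging
# downsampler; channels are duplicated `repeats` times so pixel_shuffle lands
# exactly on out_channels:
#
#     up = ChannelDuplicatingPixelUnshuffleUpSampleLayer(128, 64, factor=2)
#     y = up(torch.randn(1, 128, 16, 16))  # -> (1, 64, 32, 32)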

class LinearLayer(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        use_bias=True,
        dropout=0,
        norm=None,
        act_func=None,
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout, inplace=False) if dropout > 0 else None
        self.linear = nn.Linear(in_features, out_features, bias=use_bias)
        self.norm = build_norm(norm, num_features=out_features)
        self.act = build_act(act_func)

    def _try_squeeze(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._try_squeeze(x)
        if self.dropout:
            x = self.dropout(x)
        x = self.linear(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x


class IdentityLayer(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

#################################################################################
#                                 Basic Blocks                                  #
#################################################################################

class DSConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        use_bias=False,
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super().__init__()
        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)

        self.depth_conv = ConvLayer(
            in_channels,
            in_channels,
            kernel_size,
            stride,
            groups=in_channels,
            norm=norm[0],
            act_func=act_func[0],
            use_bias=use_bias[0],
        )
        self.point_conv = ConvLayer(
            in_channels,
            out_channels,
            1,
            norm=norm[1],
            act_func=act_func[1],
            use_bias=use_bias[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x
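
# Example (minimal sketch): a depthwise k x k conv followed by a pointwise 1x1,
# the standard depthwise-separable factorization:
#
#     block = DSConv(32, 64)
#     y = block(torch.randn(1, 32, 56, 56))  # -> (1, 64, 56, 56)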

class MBConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        use_bias=False,
        norm=("bn2d", "bn2d", "bn2d"),
        act_func=("relu6", "relu6", None),
    ):
        super().__init__()
        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act_func = val2tuple(act_func, 3)
        mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels

        self.inverted_conv = ConvLayer(
            in_channels,
            mid_channels,
            1,
            stride=1,
            norm=norm[0],
            act_func=act_func[0],
            use_bias=use_bias[0],
        )
        self.depth_conv = ConvLayer(
            mid_channels,
            mid_channels,
            kernel_size,
            stride=stride,
            groups=mid_channels,
            norm=norm[1],
            act_func=act_func[1],
            use_bias=use_bias[1],
        )
        self.point_conv = ConvLayer(
            mid_channels,
            out_channels,
            1,
            norm=norm[2],
            act_func=act_func[2],
            use_bias=use_bias[2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.inverted_conv(x)
        x = self.depth_conv(x)
        x = self.point_conv(x)
        return x
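
# Example (minimal sketch): expand -> depthwise -> project, as in MobileNetV2;
# with the default expand_ratio=6 the hidden width below is round(32 * 6) = 192:
#
#     block = MBConv(32, 64, stride=2)
#     y = block(torch.randn(1, 32, 56, 56))  # -> (1, 64, 28, 28)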

class FusedMBConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        groups=1,
        use_bias=False,
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super().__init__()
        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)
        mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels

        self.spatial_conv = ConvLayer(
            in_channels,
            mid_channels,
            kernel_size,
            stride,
            groups=groups,
            use_bias=use_bias[0],
            norm=norm[0],
            act_func=act_func[0],
        )
        self.point_conv = ConvLayer(
            mid_channels,
            out_channels,
            1,
            use_bias=use_bias[1],
            norm=norm[1],
            act_func=act_func[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.spatial_conv(x)
        x = self.point_conv(x)
        return x

class GLUMBConv(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        use_bias=False,
        norm=(None, None, "ln2d"),
        act_func=("silu", "silu", None),
    ):
        super().__init__()
        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act_func = val2tuple(act_func, 3)
        mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels

        self.glu_act = build_act(act_func[1], inplace=False)
        self.inverted_conv = ConvLayer(
            in_channels,
            mid_channels * 2,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act_func=act_func[0],
        )
        self.depth_conv = ConvLayer(
            mid_channels * 2,
            mid_channels * 2,
            kernel_size,
            stride=stride,
            groups=mid_channels * 2,
            use_bias=use_bias[1],
            norm=norm[1],
            act_func=None,
        )
        self.point_conv = ConvLayer(
            mid_channels,
            out_channels,
            1,
            use_bias=use_bias[2],
            norm=norm[2],
            act_func=act_func[2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.inverted_conv(x)
        x = self.depth_conv(x)

        x, gate = torch.chunk(x, 2, dim=1)
        gate = self.glu_act(gate)
        x = x * gate

        x = self.point_conv(x)
        return x
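
# Example (minimal sketch, assuming build_norm accepts the "ln2d" default): the
# expansion runs at twice the hidden width, and one half gates the other
# (a gated linear unit) before the pointwise projection:
#
#     block = GLUMBConv(32, 32)
#     y = block(torch.randn(1, 32, 28, 28))  # -> (1, 32, 28, 28)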

class ResBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=1,
        use_bias=False,
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super().__init__()
        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)
        mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels

        self.conv1 = ConvLayer(
            in_channels,
            mid_channels,
            kernel_size,
            stride,
            use_bias=use_bias[0],
            norm=norm[0],
            act_func=act_func[0],
        )
        self.conv2 = ConvLayer(
            mid_channels,
            out_channels,
            kernel_size,
            1,
            use_bias=use_bias[1],
            norm=norm[1],
            act_func=act_func[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class LiteMLA(nn.Module):
    r"""Lightweight multi-scale linear attention."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: Optional[int] = None,
        heads_ratio: float = 1.0,
        dim=8,
        use_bias=False,
        norm=(None, "bn2d"),
        act_func=(None, None),
        kernel_func="relu",
        scales: tuple[int, ...] = (5,),
        eps=1.0e-15,
    ):
        super().__init__()
        self.eps = eps
        heads = int(in_channels // dim * heads_ratio) if heads is None else heads

        total_dim = heads * dim

        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)

        self.dim = dim
        self.qkv = ConvLayer(
            in_channels,
            3 * total_dim,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act_func=act_func[0],
        )
        self.aggreg = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        3 * total_dim,
                        3 * total_dim,
                        scale,
                        padding=get_same_padding(scale),
                        groups=3 * total_dim,
                        bias=use_bias[0],
                    ),
                    nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
                )
                for scale in scales
            ]
        )
        self.kernel_func = build_act(kernel_func, inplace=False)

        self.proj = ConvLayer(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            use_bias=use_bias[1],
            norm=norm[1],
            act_func=act_func[1],
        )

    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
        B, _, H, W = list(qkv.size())

        if qkv.dtype == torch.float16:
            qkv = qkv.float()

        qkv = torch.reshape(
            qkv,
            (
                B,
                -1,
                3 * self.dim,
                H * W,
            ),
        )
        q, k, v = (
            qkv[:, :, 0 : self.dim],
            qkv[:, :, self.dim : 2 * self.dim],
            qkv[:, :, 2 * self.dim :],
        )

        # lightweight linear attention
        q = self.kernel_func(q)
        k = self.kernel_func(k)

        # linear matmul: padding v with a constant row of ones lets one matmul
        # chain produce both the numerator (v @ k^T @ q) and, in the extra row,
        # the normalizer (the column sums of k^T @ q)
        trans_k = k.transpose(-1, -2)
        v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
        vk = torch.matmul(v, trans_k)
        out = torch.matmul(vk, q)
        if out.dtype == torch.bfloat16:
            out = out.float()
        # split off the ones-row and divide; eps guards against a zero normalizer
        out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)

        out = torch.reshape(out, (B, -1, H, W))
        return out

    def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
        B, _, H, W = list(qkv.size())

        qkv = torch.reshape(
            qkv,
            (
                B,
                -1,
                3 * self.dim,
                H * W,
            ),
        )
        q, k, v = (
            qkv[:, :, 0 : self.dim],
            qkv[:, :, self.dim : 2 * self.dim],
            qkv[:, :, 2 * self.dim :],
        )

        q = self.kernel_func(q)
        k = self.kernel_func(k)

        att_map = torch.matmul(k.transpose(-1, -2), q)  # b h n n
        original_dtype = att_map.dtype
        if original_dtype in [torch.float16, torch.bfloat16]:
            att_map = att_map.float()
        att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps)  # b h n n
        att_map = att_map.to(original_dtype)
        out = torch.matmul(v, att_map)  # b h d n

        out = torch.reshape(out, (B, -1, H, W))
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # generate multi-scale q, k, v
        qkv = self.qkv(x)
        multi_scale_qkv = [qkv]
        for op in self.aggreg:
            multi_scale_qkv.append(op(qkv))
        qkv = torch.cat(multi_scale_qkv, dim=1)

        H, W = list(qkv.size())[-2:]
        # linear attention costs O(N * d^2) while the quadratic form costs
        # O(N^2 * d), so the quadratic path pays off only when the token count
        # N = H * W is very small
        if H * W > self.dim:
            out = self.relu_linear_att(qkv).to(qkv.dtype)
        else:
            out = self.relu_quadratic_att(qkv)
        out = self.proj(out)

        return out
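
# Example (minimal sketch): with the defaults (dim=8, heads_ratio=1.0), 64 input
# channels yield 8 heads, and the module is shape-preserving:
#
#     attn = LiteMLA(64, 64)
#     y = attn(torch.randn(1, 64, 16, 16))  # -> (1, 64, 16, 16)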

class EfficientViTBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        heads_ratio: float = 1.0,
        dim=32,
        expand_ratio: float = 4,
        scales: tuple[int, ...] = (5,),
        norm: str = "bn2d",
        act_func: str = "hswish",
        context_module: str = "LiteMLA",
        local_module: str = "MBConv",
    ):
        super().__init__()
        if context_module == "LiteMLA":
            self.context_module = ResidualBlock(
                LiteMLA(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    heads_ratio=heads_ratio,
                    dim=dim,
                    norm=(None, norm),
                    scales=scales,
                ),
                IdentityLayer(),
            )
        else:
            raise ValueError(f"context_module {context_module} is not supported")
        if local_module == "MBConv":
            self.local_module = ResidualBlock(
                MBConv(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    expand_ratio=expand_ratio,
                    use_bias=(True, True, False),
                    norm=(None, None, norm),
                    act_func=(act_func, act_func, None),
                ),
                IdentityLayer(),
            )
        elif local_module == "GLUMBConv":
            self.local_module = ResidualBlock(
                GLUMBConv(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    expand_ratio=expand_ratio,
                    use_bias=(True, True, False),
                    norm=(None, None, norm),
                    act_func=(act_func, act_func, None),
                ),
                IdentityLayer(),
            )
        else:
            raise NotImplementedError(f"local_module {local_module} is not supported")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.context_module(x)
        x = self.local_module(x)
        return x
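
# Example (minimal sketch): a shape-preserving block pairing global context
# (LiteMLA) with a local convolutional module, each wrapped in an identity
# residual:
#
#     block = EfficientViTBlock(in_channels=64)
#     y = block(torch.randn(1, 64, 14, 14))  # -> (1, 64, 14, 14)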

#################################################################################
#                               Functional Blocks                               #
#################################################################################

class ResidualBlock(nn.Module):
    def __init__(
        self,
        main: Optional[nn.Module],
        shortcut: Optional[nn.Module],
        post_act=None,
        pre_norm: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.main = main
        self.shortcut = shortcut
        self.post_act = build_act(post_act)

    def forward_main(self, x: torch.Tensor) -> torch.Tensor:
        if self.pre_norm is None:
            return self.main(x)
        else:
            return self.main(self.pre_norm(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.main is None:
            res = x
        elif self.shortcut is None:
            res = self.forward_main(x)
        else:
            res = self.forward_main(x) + self.shortcut(x)
            if self.post_act:
                res = self.post_act(res)
        return res
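
# Example (minimal sketch): with an IdentityLayer shortcut this computes
# main(x) + x; pre_norm, if given, is applied to the main branch only:
#
#     block = ResidualBlock(ConvLayer(32, 32), IdentityLayer())
#     y = block(torch.randn(1, 32, 8, 8))  # -> (1, 32, 8, 8)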

class DAGBlock(nn.Module):
    def __init__(
        self,
        inputs: dict[str, nn.Module],
        merge: str,
        post_input: Optional[nn.Module],
        middle: nn.Module,
        outputs: dict[str, nn.Module],
    ):
        super().__init__()
        self.input_keys = list(inputs.keys())
        self.input_ops = nn.ModuleList(list(inputs.values()))
        self.merge = merge
        self.post_input = post_input

        self.middle = middle

        self.output_keys = list(outputs.keys())
        self.output_ops = nn.ModuleList(list(outputs.values()))

    def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        feat = [op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)]
        if self.merge == "add":
            feat = list_sum(feat)
        elif self.merge == "cat":
            feat = torch.cat(feat, dim=1)
        else:
            raise NotImplementedError
        if self.post_input is not None:
            feat = self.post_input(feat)
        feat = self.middle(feat)
        for key, op in zip(self.output_keys, self.output_ops):
            feature_dict[key] = op(feat)
        return feature_dict
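
# Example (minimal sketch): per-input ops are applied to the named entries of
# the feature dict, merged ("add" or "cat"), passed through `middle`, and the
# results are written back under the output keys:
#
#     block = DAGBlock(
#         inputs={"f1": IdentityLayer(), "f2": UpSampleLayer(factor=2)},
#         merge="add",
#         post_input=None,
#         middle=ConvLayer(32, 32),
#         outputs={"out": IdentityLayer()},
#     )
#     feats = {"f1": torch.randn(1, 32, 16, 16), "f2": torch.randn(1, 32, 8, 8)}
#     out = block(feats)["out"]  # -> (1, 32, 16, 16)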

class OpSequential(nn.Module):
    def __init__(self, op_list: list[Optional[nn.Module]]):
        super().__init__()
        valid_op_list = []
        for op in op_list:
            if op is not None:
                valid_op_list.append(op)
        self.op_list = nn.ModuleList(valid_op_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for op in self.op_list:
            x = op(x)
        return x
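
# Example (minimal sketch): OpSequential behaves like nn.Sequential but silently
# drops None entries, which is convenient for optional, config-driven stages:
#
#     head = OpSequential([
#         ConvLayer(32, 64),
#         None,  # e.g. an optional stage disabled by config
#         UpSampleLayer(factor=2),
#         ConvLayer(64, 16, act_func=None),
#     ])
#     y = head(torch.randn(1, 32, 8, 8))  # -> (1, 16, 16, 16)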