|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from src.efficientvit.models.nn import (ConvLayer, DSConv, EfficientViTBlock, |
|
FusedMBConv, IdentityLayer, MBConv, |
|
OpSequential, ResBlock, ResidualBlock) |
|
from src.efficientvit.models.utils import build_kwargs_from_config |
|
|
|
# Public API of this module: the two backbone classes plus one factory
# function per preset configuration.
__all__ = [
    # Base backbone and its presets.
    "EfficientViTBackbone",
    "efficientvit_backbone_b0",
    "efficientvit_backbone_b1",
    "efficientvit_backbone_b2",
    "efficientvit_backbone_b3",
    # Large backbone and its presets.
    "EfficientViTLargeBackbone",
    "efficientvit_backbone_l0",
    "efficientvit_backbone_l1",
    "efficientvit_backbone_l2",
    "efficientvit_backbone_l3",
]
|
|
|
|
|
class EfficientViTBackbone(nn.Module):
    """Multi-stage backbone mixing local conv blocks and ``EfficientViTBlock``s.

    Layout, as assembled in ``__init__``:

    * ``input_stem`` - a stride-2 ``ConvLayer`` followed by ``depth_list[0]``
      residual local blocks (``expand_ratio=1``) at ``width_list[0]`` channels.
    * the first two entries of ``stages`` - built from ``width_list[1:3]`` and
      ``depth_list[1:3]``; each stage holds ``d`` residual local blocks, the
      first of which uses stride 2 and has no shortcut.
    * the remaining entries of ``stages`` - built from ``width_list[3:]`` and
      ``depth_list[3:]``; each stage is one stride-2 local block
      (``fewer_norm=True``, no shortcut) followed by ``d``
      ``EfficientViTBlock`` modules.

    Attributes set by the constructor:
        width_list: output channel count recorded after the stem and after
            every stage, in order.
        input_stem: ``OpSequential`` wrapping the stem operations.
        stages: ``nn.ModuleList`` of ``OpSequential`` stages.
    """

    def __init__(
        self,
        width_list: list[int],
        depth_list: list[int],
        in_channels: int = 3,
        dim: int = 32,
        expand_ratio: float = 4,
        norm: str = "bn2d",
        act_func: str = "hswish",
    ) -> None:
        """Assemble the stem and all stages.

        Args:
            width_list: channel width per stage; index 0 is the stem.
            depth_list: block count per stage, aligned with ``width_list``
                (the two lists are zipped, so extra entries in the longer
                one are ignored).
            in_channels: channel count of the input tensor fed to the stem.
            dim: forwarded as ``dim`` to every ``EfficientViTBlock``.
            expand_ratio: expansion ratio for the local blocks after the stem
                and for every ``EfficientViTBlock``.
            norm: normalization identifier forwarded to the layers.
            act_func: activation identifier forwarded to the layers.
        """
        super().__init__()

        self.width_list = []

        # --- input stem ---------------------------------------------------
        stem_width = width_list[0]
        stem_ops = [
            ConvLayer(
                # Honor the caller-supplied channel count; the default (3)
                # reproduces the previously hard-coded value.
                in_channels=in_channels,
                out_channels=stem_width,
                stride=2,
                norm=norm,
                act_func=act_func,
            )
        ]
        for _ in range(depth_list[0]):
            local_block = self.build_local_block(
                in_channels=stem_width,
                out_channels=stem_width,
                stride=1,
                expand_ratio=1,
                norm=norm,
                act_func=act_func,
            )
            stem_ops.append(ResidualBlock(local_block, IdentityLayer()))
        # From here on ``in_channels`` tracks the running feature width.
        in_channels = stem_width
        self.input_stem = OpSequential(stem_ops)
        self.width_list.append(in_channels)

        # --- stages made only of local blocks -------------------------------
        stages = []
        for w, d in zip(width_list[1:3], depth_list[1:3]):
            stage = []
            for i in range(d):
                # Only the first block of a stage is strided.
                stride = 2 if i == 0 else 1
                local_block = self.build_local_block(
                    in_channels=in_channels,
                    out_channels=w,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    norm=norm,
                    act_func=act_func,
                )
                # The strided block gets no shortcut branch.
                shortcut = IdentityLayer() if stride == 1 else None
                stage.append(ResidualBlock(local_block, shortcut))
                in_channels = w
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        # --- stages made of a strided local block + EfficientViT blocks -----
        for w, d in zip(width_list[3:], depth_list[3:]):
            local_block = self.build_local_block(
                in_channels=in_channels,
                out_channels=w,
                stride=2,
                expand_ratio=expand_ratio,
                norm=norm,
                act_func=act_func,
                fewer_norm=True,
            )
            stage = [ResidualBlock(local_block, None)]
            in_channels = w

            for _ in range(d):
                stage.append(
                    EfficientViTBlock(
                        in_channels=in_channels,
                        dim=dim,
                        expand_ratio=expand_ratio,
                        norm=norm,
                        act_func=act_func,
                    )
                )
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        self.stages = nn.ModuleList(stages)

    @staticmethod
    def build_local_block(
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: float,
        norm: str,
        act_func: str,
        fewer_norm: bool = False,
    ) -> nn.Module:
        """Create one local conv block.

        Returns a ``DSConv`` when ``expand_ratio == 1`` and an ``MBConv``
        (which additionally receives ``expand_ratio``) otherwise.

        Args:
            in_channels: input channels of the block.
            out_channels: output channels of the block.
            stride: stride forwarded to the block.
            expand_ratio: selects the block type, see above.
            norm: normalization identifier.
            act_func: activation identifier; the last tuple slot passed to
                the block is always ``None``.
            fewer_norm: when true, ``norm`` is passed as ``None`` and
                ``use_bias`` as ``True`` for every tuple slot except the
                last. How each slot maps to a sub-layer is defined by
                ``DSConv`` / ``MBConv``.
        """
        if expand_ratio == 1:
            return DSConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        return MBConv(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False) if fewer_norm else False,
            norm=(None, None, norm) if fewer_norm else norm,
            act_func=(act_func, act_func, None),
        )

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the stem and every stage, keeping each intermediate output.

        Returns:
            A dict with keys ``"input"`` (the argument as given), ``"stage0"``
            (stem output), ``"stage1"`` ... ``"stageN"`` (one per entry of
            ``self.stages``) and ``"stage_final"`` (same tensor as the last
            stage output).
        """
        output_dict = {"input": x}
        output_dict["stage0"] = x = self.input_stem(x)
        for stage_id, stage in enumerate(self.stages, 1):
            output_dict[f"stage{stage_id}"] = x = stage(x)
        output_dict["stage_final"] = x
        return output_dict
|
|
|
|
|
def efficientvit_backbone_b0(**kwargs) -> EfficientViTBackbone:
    """Build the ``b0`` preset of ``EfficientViTBackbone``.

    Preset values: widths ``[8, 16, 32, 64, 128]``, depths
    ``[1, 2, 2, 2, 2]``, ``dim=16``.

    Args:
        **kwargs: extra options; they go through ``build_kwargs_from_config``
            before reaching the constructor (presumably filtered down to the
            constructor's own parameters - confirm in ``models.utils``).
    """
    extra = build_kwargs_from_config(kwargs, EfficientViTBackbone)
    return EfficientViTBackbone(
        width_list=[8, 16, 32, 64, 128],
        depth_list=[1, 2, 2, 2, 2],
        dim=16,
        **extra,
    )
|
|
|
|
|
def efficientvit_backbone_b1(**kwargs) -> EfficientViTBackbone:
    """Build the ``b1`` preset of ``EfficientViTBackbone``.

    Preset values: widths ``[16, 32, 64, 128, 256]``, depths
    ``[1, 2, 3, 3, 4]``, ``dim=16``.

    Args:
        **kwargs: extra options; they go through ``build_kwargs_from_config``
            before reaching the constructor (presumably filtered down to the
            constructor's own parameters - confirm in ``models.utils``).
    """
    extra = build_kwargs_from_config(kwargs, EfficientViTBackbone)
    return EfficientViTBackbone(
        width_list=[16, 32, 64, 128, 256],
        depth_list=[1, 2, 3, 3, 4],
        dim=16,
        **extra,
    )
|
|
|
|
|
def efficientvit_backbone_b2(**kwargs) -> EfficientViTBackbone:
    """Build the ``b2`` preset of ``EfficientViTBackbone``.

    Preset values: widths ``[24, 48, 96, 192, 384]``, depths
    ``[1, 3, 4, 4, 6]``, ``dim=32``.

    Args:
        **kwargs: extra options; they go through ``build_kwargs_from_config``
            before reaching the constructor (presumably filtered down to the
            constructor's own parameters - confirm in ``models.utils``).
    """
    extra = build_kwargs_from_config(kwargs, EfficientViTBackbone)
    return EfficientViTBackbone(
        width_list=[24, 48, 96, 192, 384],
        depth_list=[1, 3, 4, 4, 6],
        dim=32,
        **extra,
    )
|
|
|
|
|
def efficientvit_backbone_b3(**kwargs) -> EfficientViTBackbone:
    """Build the ``b3`` preset of ``EfficientViTBackbone``.

    Preset values: widths ``[32, 64, 128, 256, 512]``, depths
    ``[1, 4, 6, 6, 9]``, ``dim=32``.

    Args:
        **kwargs: extra options; they go through ``build_kwargs_from_config``
            before reaching the constructor (presumably filtered down to the
            constructor's own parameters - confirm in ``models.utils``).
    """
    extra = build_kwargs_from_config(kwargs, EfficientViTBackbone)
    return EfficientViTBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 4, 6, 6, 9],
        dim=32,
        **extra,
    )
|
|
|
|
|
class EfficientViTLargeBackbone(nn.Module):
    """Backbone whose per-stage block type is selected by name.

    Layout, as assembled in ``__init__``:

    * ``stages[0]`` - a stride-2 ``ConvLayer`` followed by ``depth_list[0]``
      residual blocks of type ``block_list[0]``.
    * ``stages[i]`` for ``i >= 1`` - one stride-2 downsampling block without
      shortcut, followed by ``depth_list[i]`` blocks: ``EfficientViTBlock``
      when ``block_list[i]`` starts with ``"att"``, otherwise residual local
      blocks of type ``block_list[i]``.

    Attributes set by the constructor:
        width_list: output channel count recorded after every stage.
        stages: ``nn.ModuleList`` of ``OpSequential`` stages.
    """

    def __init__(
        self,
        width_list: list[int],
        depth_list: list[int],
        # Quoted so the optional-ness is actually recorded: the former
        # ``list[str] or None`` expression evaluated to plain ``list[str]``.
        block_list: "list[str] | None" = None,
        expand_list: "list[float] | None" = None,
        fewer_norm_list: "list[bool] | None" = None,
        in_channels: int = 3,
        qkv_dim: int = 32,
        norm: str = "bn2d",
        act_func: str = "gelu",
    ) -> None:
        """Assemble all stages.

        Args:
            width_list: channel width per stage.
            depth_list: block count per stage, aligned with ``width_list``.
            block_list: block type name per stage; a falsy value selects
                ``["res", "fmb", "fmb", "mb", "att"]``.
            expand_list: expansion ratio per stage; a falsy value selects
                ``[1, 4, 4, 4, 6]``.
            fewer_norm_list: ``fewer_norm`` flag per stage; a falsy value
                selects ``[False, False, False, True, True]``.
            in_channels: channel count of the input tensor fed to stage 0.
            qkv_dim: forwarded as ``dim`` to every ``EfficientViTBlock``.
            norm: normalization identifier forwarded to the layers.
            act_func: activation identifier forwarded to the layers.

        Raises:
            ValueError: if a local block is requested with a type name other
                than ``"res"``, ``"fmb"`` or ``"mb"``.
        """
        super().__init__()
        block_list = block_list or ["res", "fmb", "fmb", "mb", "att"]
        expand_list = expand_list or [1, 4, 4, 4, 6]
        fewer_norm_list = fewer_norm_list or [False, False, False, True, True]

        self.width_list = []
        stages = []

        # --- stage 0 --------------------------------------------------------
        first_width = width_list[0]
        stage0 = [
            ConvLayer(
                # Honor the caller-supplied channel count; the default (3)
                # reproduces the previously hard-coded value.
                in_channels=in_channels,
                out_channels=first_width,
                stride=2,
                norm=norm,
                act_func=act_func,
            )
        ]
        for _ in range(depth_list[0]):
            local_block = self.build_local_block(
                block=block_list[0],
                in_channels=first_width,
                out_channels=first_width,
                stride=1,
                expand_ratio=expand_list[0],
                norm=norm,
                act_func=act_func,
                fewer_norm=fewer_norm_list[0],
            )
            stage0.append(ResidualBlock(local_block, IdentityLayer()))
        # From here on ``in_channels`` tracks the running feature width.
        in_channels = first_width
        stages.append(OpSequential(stage0))
        self.width_list.append(in_channels)

        # --- stages 1..N ----------------------------------------------------
        for stage_id, (w, d) in enumerate(zip(width_list[1:], depth_list[1:]), start=1):
            block_type = block_list[stage_id]
            stage_expand = expand_list[stage_id]
            stage_fewer_norm = fewer_norm_list[stage_id]

            # Downsampling block: anything that is not "mb"/"fmb" (e.g. an
            # attention stage) falls back to "mb"; expansion is 4x the
            # stage's own ratio.
            down_block = self.build_local_block(
                block=block_type if block_type in ("mb", "fmb") else "mb",
                in_channels=in_channels,
                out_channels=w,
                stride=2,
                expand_ratio=stage_expand * 4,
                norm=norm,
                act_func=act_func,
                fewer_norm=stage_fewer_norm,
            )
            stage = [ResidualBlock(down_block, None)]
            in_channels = w

            for _ in range(d):
                if block_type.startswith("att"):
                    stage.append(
                        EfficientViTBlock(
                            in_channels=in_channels,
                            dim=qkv_dim,
                            expand_ratio=stage_expand,
                            # "att@3" selects scale 3; every other
                            # "att*" name selects scale 5.
                            scales=(3,) if block_type == "att@3" else (5,),
                            norm=norm,
                            act_func=act_func,
                        )
                    )
                else:
                    local_block = self.build_local_block(
                        block=block_type,
                        in_channels=in_channels,
                        out_channels=in_channels,
                        stride=1,
                        expand_ratio=stage_expand,
                        norm=norm,
                        act_func=act_func,
                        fewer_norm=stage_fewer_norm,
                    )
                    stage.append(ResidualBlock(local_block, IdentityLayer()))
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        self.stages = nn.ModuleList(stages)

    @staticmethod
    def build_local_block(
        block: str,
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: float,
        norm: str,
        act_func: str,
        fewer_norm: bool = False,
    ) -> nn.Module:
        """Create one local conv block selected by name.

        Args:
            block: ``"res"`` -> ``ResBlock`` (``expand_ratio`` is not passed),
                ``"fmb"`` -> ``FusedMBConv``, ``"mb"`` -> ``MBConv``.
            in_channels: input channels of the block.
            out_channels: output channels of the block.
            stride: stride forwarded to the block.
            expand_ratio: expansion ratio for ``"fmb"`` and ``"mb"``.
            norm: normalization identifier.
            act_func: activation identifier; the last tuple slot passed to
                the block is always ``None``.
            fewer_norm: when true, ``norm`` is passed as ``None`` and
                ``use_bias`` as ``True`` for every tuple slot except the
                last. How each slot maps to a sub-layer is defined by the
                block classes.

        Raises:
            ValueError: for any other ``block`` name; the name is the sole
                exception argument.
        """
        if block == "res":
            return ResBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        if block == "fmb":
            return FusedMBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        if block == "mb":
            return MBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False) if fewer_norm else False,
                norm=(None, None, norm) if fewer_norm else norm,
                act_func=(act_func, act_func, None),
            )
        raise ValueError(block)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run every stage, keeping each intermediate output.

        Returns:
            A dict with keys ``"input"`` (the argument as given),
            ``"stage0"`` ... ``"stageN"`` (one per entry of ``self.stages``)
            and ``"stage_final"`` (same tensor as the last stage output).
        """
        output_dict = {"input": x}
        for stage_id, stage in enumerate(self.stages):
            output_dict[f"stage{stage_id}"] = x = stage(x)
        output_dict["stage_final"] = x
        return output_dict
|
|
|
|
|
def efficientvit_backbone_l0(**kwargs) -> EfficientViTLargeBackbone:
    """Build the ``l0`` preset of ``EfficientViTLargeBackbone``.

    Preset values: widths ``[32, 64, 128, 256, 512]``, depths
    ``[1, 1, 1, 4, 4]``.

    Args:
        **kwargs: extra options; they go through ``build_kwargs_from_config``
            before reaching the constructor (presumably filtered down to the
            constructor's own parameters - confirm in ``models.utils``).
    """
    extra = build_kwargs_from_config(kwargs, EfficientViTLargeBackbone)
    return EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 1, 1, 4, 4],
        **extra,
    )
|
|
|
|
|
def efficientvit_backbone_l1(**kwargs) -> EfficientViTLargeBackbone:
    """Build the ``l1`` preset of ``EfficientViTLargeBackbone``.

    Preset values: widths ``[32, 64, 128, 256, 512]``, depths
    ``[1, 1, 1, 6, 6]``.

    Args:
        **kwargs: extra options; they go through ``build_kwargs_from_config``
            before reaching the constructor (presumably filtered down to the
            constructor's own parameters - confirm in ``models.utils``).
    """
    extra = build_kwargs_from_config(kwargs, EfficientViTLargeBackbone)
    return EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 1, 1, 6, 6],
        **extra,
    )
|
|
|
|
|
def efficientvit_backbone_l2(**kwargs) -> EfficientViTLargeBackbone:
    """Build the ``l2`` preset of ``EfficientViTLargeBackbone``.

    Preset values: widths ``[32, 64, 128, 256, 512]``, depths
    ``[1, 2, 2, 8, 8]``.

    Args:
        **kwargs: extra options; they go through ``build_kwargs_from_config``
            before reaching the constructor (presumably filtered down to the
            constructor's own parameters - confirm in ``models.utils``).
    """
    extra = build_kwargs_from_config(kwargs, EfficientViTLargeBackbone)
    return EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 2, 2, 8, 8],
        **extra,
    )
|
|
|
|
|
def efficientvit_backbone_l3(**kwargs) -> EfficientViTLargeBackbone:
    """Build the ``l3`` preset of ``EfficientViTLargeBackbone``.

    Preset values: widths ``[64, 128, 256, 512, 1024]``, depths
    ``[1, 2, 2, 8, 8]``.

    Args:
        **kwargs: extra options; they go through ``build_kwargs_from_config``
            before reaching the constructor (presumably filtered down to the
            constructor's own parameters - confirm in ``models.utils``).
    """
    extra = build_kwargs_from_config(kwargs, EfficientViTLargeBackbone)
    return EfficientViTLargeBackbone(
        width_list=[64, 128, 256, 512, 1024],
        depth_list=[1, 2, 2, 8, 8],
        **extra,
    )
|
|