Fucius's picture
Upload 52 files
ad5354d verified
raw
history blame
12.6 kB
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
import torch
import torch.nn as nn
from src.efficientvit.models.nn import (ConvLayer, DSConv, EfficientViTBlock,
FusedMBConv, IdentityLayer, MBConv,
OpSequential, ResBlock, ResidualBlock)
from src.efficientvit.models.utils import build_kwargs_from_config
# Public API of this module: the two backbone classes and, for each of them,
# the factory functions that build its predefined size variants.
__all__ = [
# EfficientViTBackbone and its b0-b3 factory functions
"EfficientViTBackbone",
"efficientvit_backbone_b0",
"efficientvit_backbone_b1",
"efficientvit_backbone_b2",
"efficientvit_backbone_b3",
# EfficientViTLargeBackbone and its l0-l3 factory functions
"EfficientViTLargeBackbone",
"efficientvit_backbone_l0",
"efficientvit_backbone_l1",
"efficientvit_backbone_l2",
"efficientvit_backbone_l3",
]
class EfficientViTBackbone(nn.Module):
    """EfficientViT backbone made of an input stem followed by a list of stages.

    Layout, as assembled in ``__init__``:

    * ``input_stem``: a stride-2 ``ConvLayer`` followed by ``depth_list[0]``
      residual local blocks (``expand_ratio=1``).
    * stages built from ``width_list[1:3]`` / ``depth_list[1:3]``: residual local
      blocks only; the first block of each stage uses ``stride=2``.
    * stages built from ``width_list[3:]`` / ``depth_list[3:]``: one stride-2
      local block (``fewer_norm=True``) followed by ``d`` ``EfficientViTBlock``s.

    Attributes:
        width_list: channel count recorded after the stem and after each stage.
        input_stem: ``OpSequential`` holding the stem.
        stages: ``nn.ModuleList`` with one ``OpSequential`` per stage.
    """

    def __init__(
        self,
        width_list: list[int],
        depth_list: list[int],
        in_channels=3,
        dim=32,
        expand_ratio=4,
        norm="bn2d",
        act_func="hswish",
    ) -> None:
        """Assemble the stem and the stages.

        Args:
            width_list: output channels of the stem (index 0) and of every
                following stage.
            depth_list: number of repeated blocks in the stem (index 0) and in
                every following stage; paired with ``width_list`` via ``zip``.
            in_channels: number of channels of the input tensor.
            dim: forwarded as ``dim`` to every ``EfficientViTBlock``.
            expand_ratio: expansion ratio used by the local blocks after the
                stem and by every ``EfficientViTBlock``.
            norm: normalization name forwarded to the layers.
            act_func: activation name forwarded to the layers.
        """
        super().__init__()

        self.width_list = []

        # input stem
        stem_ops = [
            ConvLayer(
                # Use the caller-supplied channel count; this used to be
                # hard-coded to 3, which silently ignored ``in_channels``.
                in_channels=in_channels,
                out_channels=width_list[0],
                stride=2,
                norm=norm,
                act_func=act_func,
            )
        ]
        for _ in range(depth_list[0]):
            block = self.build_local_block(
                in_channels=width_list[0],
                out_channels=width_list[0],
                stride=1,
                expand_ratio=1,
                norm=norm,
                act_func=act_func,
            )
            stem_ops.append(ResidualBlock(block, IdentityLayer()))
        in_channels = width_list[0]
        self.input_stem = OpSequential(stem_ops)
        self.width_list.append(in_channels)

        # stages
        stages = []

        # Stages made only of local blocks.
        for w, d in zip(width_list[1:3], depth_list[1:3]):
            stage = []
            for i in range(d):
                # Only the first block of a stage is strided.
                stride = 2 if i == 0 else 1
                block = self.build_local_block(
                    in_channels=in_channels,
                    out_channels=w,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    norm=norm,
                    act_func=act_func,
                )
                # No shortcut on the strided block.
                shortcut = IdentityLayer() if stride == 1 else None
                stage.append(ResidualBlock(block, shortcut))
                in_channels = w
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        # Stages made of one strided local block followed by EfficientViT blocks.
        for w, d in zip(width_list[3:], depth_list[3:]):
            block = self.build_local_block(
                in_channels=in_channels,
                out_channels=w,
                stride=2,
                expand_ratio=expand_ratio,
                norm=norm,
                act_func=act_func,
                fewer_norm=True,
            )
            stage = [ResidualBlock(block, None)]
            in_channels = w

            for _ in range(d):
                stage.append(
                    EfficientViTBlock(
                        in_channels=in_channels,
                        dim=dim,
                        expand_ratio=expand_ratio,
                        norm=norm,
                        act_func=act_func,
                    )
                )
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        self.stages = nn.ModuleList(stages)

    @staticmethod
    def build_local_block(
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: float,
        norm: str,
        act_func: str,
        fewer_norm: bool = False,
    ) -> nn.Module:
        """Create a ``DSConv`` (``expand_ratio == 1``) or an ``MBConv`` block.

        Args:
            in_channels: input channels of the block.
            out_channels: output channels of the block.
            stride: stride forwarded to the block.
            expand_ratio: ``1`` selects ``DSConv``; any other value selects
                ``MBConv`` and is forwarded to it.
            norm: normalization name.
            act_func: activation name; the last tuple position is always ``None``.
            fewer_norm: when true, ``norm`` is passed only in the last tuple
                position and ``use_bias`` is true in the earlier positions.
                (The tuple positions presumably map to the block's consecutive
                conv layers -- confirm against ``DSConv`` / ``MBConv``.)

        Returns:
            The constructed block, not yet wrapped in a ``ResidualBlock``.
        """
        if expand_ratio == 1:
            return DSConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        return MBConv(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False) if fewer_norm else False,
            norm=(None, None, norm) if fewer_norm else norm,
            act_func=(act_func, act_func, None),
        )

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the backbone and collect every intermediate output.

        Returns:
            A dict with the keys ``"input"``, ``"stage0"`` (stem output),
            ``"stage1"`` ... ``"stageN"`` (one per entry of ``self.stages``) and
            ``"stage_final"`` (same tensor as the last stage output).
        """
        output_dict = {"input": x}
        output_dict["stage0"] = x = self.input_stem(x)
        for stage_id, stage in enumerate(self.stages, 1):
            output_dict[f"stage{stage_id}"] = x = stage(x)
        output_dict["stage_final"] = x
        return output_dict
def efficientvit_backbone_b0(**kwargs) -> EfficientViTBackbone:
    """Build the B0 configuration of ``EfficientViTBackbone``.

    ``kwargs`` are passed through ``build_kwargs_from_config`` (together with
    the backbone class) and the result is forwarded to the constructor.
    """
    extra_kwargs = build_kwargs_from_config(kwargs, EfficientViTBackbone)
    return EfficientViTBackbone(
        width_list=[8, 16, 32, 64, 128],
        depth_list=[1, 2, 2, 2, 2],
        dim=16,
        **extra_kwargs,
    )
def efficientvit_backbone_b1(**kwargs) -> EfficientViTBackbone:
    """Build the B1 configuration of ``EfficientViTBackbone``.

    ``kwargs`` are passed through ``build_kwargs_from_config`` (together with
    the backbone class) and the result is forwarded to the constructor.
    """
    extra_kwargs = build_kwargs_from_config(kwargs, EfficientViTBackbone)
    return EfficientViTBackbone(
        width_list=[16, 32, 64, 128, 256],
        depth_list=[1, 2, 3, 3, 4],
        dim=16,
        **extra_kwargs,
    )
def efficientvit_backbone_b2(**kwargs) -> EfficientViTBackbone:
    """Build the B2 configuration of ``EfficientViTBackbone``.

    ``kwargs`` are passed through ``build_kwargs_from_config`` (together with
    the backbone class) and the result is forwarded to the constructor.
    """
    extra_kwargs = build_kwargs_from_config(kwargs, EfficientViTBackbone)
    return EfficientViTBackbone(
        width_list=[24, 48, 96, 192, 384],
        depth_list=[1, 3, 4, 4, 6],
        dim=32,
        **extra_kwargs,
    )
def efficientvit_backbone_b3(**kwargs) -> EfficientViTBackbone:
    """Build the B3 configuration of ``EfficientViTBackbone``.

    ``kwargs`` are passed through ``build_kwargs_from_config`` (together with
    the backbone class) and the result is forwarded to the constructor.
    """
    extra_kwargs = build_kwargs_from_config(kwargs, EfficientViTBackbone)
    return EfficientViTBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 4, 6, 6, 9],
        dim=32,
        **extra_kwargs,
    )
class EfficientViTLargeBackbone(nn.Module):
    """Large EfficientViT backbone whose block type is configurable per stage.

    Stage 0 is a stride-2 ``ConvLayer`` followed by ``depth_list[0]`` residual
    blocks. Every later stage starts with one stride-2 local block (no shortcut)
    and continues with ``d`` blocks of the kind named in ``block_list``: either
    ``EfficientViTBlock`` (names starting with ``"att"``) or a residual local
    block (``"res"``, ``"fmb"``, ``"mb"``).

    Attributes:
        width_list: channel count recorded after every stage.
        stages: ``nn.ModuleList`` with one ``OpSequential`` per stage.
    """

    def __init__(
        self,
        width_list: list[int],
        depth_list: list[int],
        # Quoted so the union is a real "or None" annotation (``X or None``
        # evaluates to just ``X``) without being evaluated at runtime.
        block_list: "list[str] | None" = None,
        expand_list: "list[float] | None" = None,
        fewer_norm_list: "list[bool] | None" = None,
        in_channels=3,
        qkv_dim=32,
        norm="bn2d",
        act_func="gelu",
    ) -> None:
        """Assemble all stages.

        Args:
            width_list: output channels of every stage.
            depth_list: number of repeated blocks of every stage; paired with
                ``width_list`` via ``zip``.
            block_list: block kind per stage; falls back to
                ``["res", "fmb", "fmb", "mb", "att"]`` when falsy.
            expand_list: expansion ratio per stage; falls back to
                ``[1, 4, 4, 4, 6]`` when falsy.
            fewer_norm_list: ``fewer_norm`` flag per stage; falls back to
                ``[False, False, False, True, True]`` when falsy.
            in_channels: number of channels of the input tensor.
            qkv_dim: forwarded as ``dim`` to every ``EfficientViTBlock``.
            norm: normalization name forwarded to the layers.
            act_func: activation name forwarded to the layers.

        Raises:
            ValueError: if a non-attention entry of ``block_list`` is not one
                of ``"res"``, ``"fmb"`` or ``"mb"``.
        """
        super().__init__()
        block_list = block_list or ["res", "fmb", "fmb", "mb", "att"]
        expand_list = expand_list or [1, 4, 4, 4, 6]
        fewer_norm_list = fewer_norm_list or [False, False, False, True, True]

        self.width_list = []
        stages = []

        # stage 0
        stage0 = [
            ConvLayer(
                # Use the caller-supplied channel count; this used to be
                # hard-coded to 3, which silently ignored ``in_channels``.
                in_channels=in_channels,
                out_channels=width_list[0],
                stride=2,
                norm=norm,
                act_func=act_func,
            )
        ]
        for _ in range(depth_list[0]):
            block = self.build_local_block(
                block=block_list[0],
                in_channels=width_list[0],
                out_channels=width_list[0],
                stride=1,
                expand_ratio=expand_list[0],
                norm=norm,
                act_func=act_func,
                fewer_norm=fewer_norm_list[0],
            )
            stage0.append(ResidualBlock(block, IdentityLayer()))
        in_channels = width_list[0]
        stages.append(OpSequential(stage0))
        self.width_list.append(in_channels)

        for stage_id, (w, d) in enumerate(zip(width_list[1:], depth_list[1:]), start=1):
            block_name = block_list[stage_id]
            stage_expand = expand_list[stage_id]
            stage_fewer_norm = fewer_norm_list[stage_id]

            # First block of the stage: strided, no shortcut. Stages whose kind
            # is neither "mb" nor "fmb" use an "mb" block here; the expansion
            # ratio of this block is four times the stage's ratio.
            first_block = self.build_local_block(
                block=block_name if block_name in ["mb", "fmb"] else "mb",
                in_channels=in_channels,
                out_channels=w,
                stride=2,
                expand_ratio=stage_expand * 4,
                norm=norm,
                act_func=act_func,
                fewer_norm=stage_fewer_norm,
            )
            stage = [ResidualBlock(first_block, None)]
            in_channels = w

            for _ in range(d):
                if block_name.startswith("att"):
                    stage.append(
                        EfficientViTBlock(
                            in_channels=in_channels,
                            dim=qkv_dim,
                            expand_ratio=stage_expand,
                            # "att@3" selects scale 3; any other "att*" name uses 5.
                            scales=(3,) if block_name == "att@3" else (5,),
                            norm=norm,
                            act_func=act_func,
                        )
                    )
                else:
                    block = self.build_local_block(
                        block=block_name,
                        in_channels=in_channels,
                        out_channels=in_channels,
                        stride=1,
                        expand_ratio=stage_expand,
                        norm=norm,
                        act_func=act_func,
                        fewer_norm=stage_fewer_norm,
                    )
                    stage.append(ResidualBlock(block, IdentityLayer()))
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        self.stages = nn.ModuleList(stages)

    @staticmethod
    def build_local_block(
        block: str,
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: float,
        norm: str,
        act_func: str,
        fewer_norm: bool = False,
    ) -> nn.Module:
        """Create the local block named by ``block``.

        Args:
            block: ``"res"`` (``ResBlock``), ``"fmb"`` (``FusedMBConv``) or
                ``"mb"`` (``MBConv``).
            in_channels: input channels of the block.
            out_channels: output channels of the block.
            stride: stride forwarded to the block.
            expand_ratio: forwarded to ``FusedMBConv`` / ``MBConv``; not used
                for ``"res"``.
            norm: normalization name.
            act_func: activation name; the last tuple position is always ``None``.
            fewer_norm: when true, ``norm`` is passed only in the last tuple
                position and ``use_bias`` is true in the earlier positions.

        Returns:
            The constructed block, not yet wrapped in a ``ResidualBlock``.

        Raises:
            ValueError: if ``block`` is not one of the supported names.
        """
        if block == "res":
            return ResBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        if block == "fmb":
            return FusedMBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        if block == "mb":
            return MBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False) if fewer_norm else False,
                norm=(None, None, norm) if fewer_norm else norm,
                act_func=(act_func, act_func, None),
            )
        raise ValueError(block)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the backbone and collect every intermediate output.

        Returns:
            A dict with the keys ``"input"``, ``"stage0"`` ... ``"stageN"`` (one
            per entry of ``self.stages``) and ``"stage_final"`` (same tensor as
            the last stage output).
        """
        output_dict = {"input": x}
        for stage_id, stage in enumerate(self.stages):
            output_dict[f"stage{stage_id}"] = x = stage(x)
        output_dict["stage_final"] = x
        return output_dict
def efficientvit_backbone_l0(**kwargs) -> EfficientViTLargeBackbone:
    """Build the L0 configuration of ``EfficientViTLargeBackbone``.

    ``kwargs`` are passed through ``build_kwargs_from_config`` (together with
    the backbone class) and the result is forwarded to the constructor.
    """
    extra_kwargs = build_kwargs_from_config(kwargs, EfficientViTLargeBackbone)
    return EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 1, 1, 4, 4],
        **extra_kwargs,
    )
def efficientvit_backbone_l1(**kwargs) -> EfficientViTLargeBackbone:
    """Build the L1 configuration of ``EfficientViTLargeBackbone``.

    ``kwargs`` are passed through ``build_kwargs_from_config`` (together with
    the backbone class) and the result is forwarded to the constructor.
    """
    extra_kwargs = build_kwargs_from_config(kwargs, EfficientViTLargeBackbone)
    return EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 1, 1, 6, 6],
        **extra_kwargs,
    )
def efficientvit_backbone_l2(**kwargs) -> EfficientViTLargeBackbone:
    """Build the L2 configuration of ``EfficientViTLargeBackbone``.

    ``kwargs`` are passed through ``build_kwargs_from_config`` (together with
    the backbone class) and the result is forwarded to the constructor.
    """
    extra_kwargs = build_kwargs_from_config(kwargs, EfficientViTLargeBackbone)
    return EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 2, 2, 8, 8],
        **extra_kwargs,
    )
def efficientvit_backbone_l3(**kwargs) -> EfficientViTLargeBackbone:
    """Build the L3 configuration of ``EfficientViTLargeBackbone``.

    ``kwargs`` are passed through ``build_kwargs_from_config`` (together with
    the backbone class) and the result is forwarded to the constructor.
    """
    extra_kwargs = build_kwargs_from_config(kwargs, EfficientViTLargeBackbone)
    return EfficientViTLargeBackbone(
        width_list=[64, 128, 256, 512, 1024],
        depth_list=[1, 2, 2, 8, 8],
        **extra_kwargs,
    )