import torch
import torch.nn as nn
from .attention import SpatialAttention, TemporalAttention
from .common import ResidualBlock3D
from .downsamplers import (SpatialDownsampler3D, SpatialTemporalDownsampler3D,
TemporalDownsampler3D)
from .gc_block import GlobalContextBlock


def get_down_block(
down_block_type: str,
in_channels: int,
out_channels: int,
num_layers: int,
act_fn: str,
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_downsample: bool = True,
) -> nn.Module:
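    """Build a down block of the requested type.

    Dispatches on ``down_block_type`` and forwards the shared constructor
    arguments. For the attention variants, the per-head dimension is derived
    as ``out_channels // num_attention_heads``. Raises ``ValueError`` for an
    unknown block type.
    """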
if down_block_type == "DownBlock3D":
return DownBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
)
elif down_block_type == "SpatialDownBlock3D":
return SpatialDownBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_downsample=add_downsample,
)
elif down_block_type == "SpatialAttnDownBlock3D":
return SpatialAttnDownBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
attention_head_dim=out_channels // num_attention_heads,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_downsample=add_downsample,
)
elif down_block_type == "TemporalDownBlock3D":
return TemporalDownBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_downsample=add_downsample,
)
elif down_block_type == "TemporalAttnDownBlock3D":
return TemporalAttnDownBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
attention_head_dim=out_channels // num_attention_heads,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_downsample=add_downsample,
)
elif down_block_type == "SpatialTemporalDownBlock3D":
return SpatialTemporalDownBlock3D(
in_channels=in_channels,
out_channels=out_channels,
num_layers=num_layers,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
add_gc_block=add_gc_block,
add_downsample=add_downsample,
)
else:
raise ValueError(f"Unknown down block type: {down_block_type}")


class DownBlock3D(nn.Module):
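    """A stack of 3D residual blocks with an optional global-context block.

    This block performs no downsampling: both ``spatial_downsample_factor``
    and ``temporal_downsample_factor`` stay 1.
    """
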
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
):
super().__init__()
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
self.spatial_downsample_factor = 1
self.temporal_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv in self.convs:
x = conv(x)
if self.gc_block is not None:
x = self.gc_block(x)
return x


class SpatialDownBlock3D(nn.Module):
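    """3D residual blocks followed by an optional 2x spatial downsampler.

    With ``add_downsample=True``, ``SpatialDownsampler3D`` halves the spatial
    resolution (``spatial_downsample_factor = 2``); the temporal resolution
    is left unchanged.
    """
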
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_downsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_downsample:
self.downsampler = SpatialDownsampler3D(out_channels, out_channels)
self.spatial_downsample_factor = 2
else:
self.downsampler = None
self.spatial_downsample_factor = 1
self.temporal_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv in self.convs:
x = conv(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.downsampler is not None:
x = self.downsampler(x)
return x


class TemporalDownBlock3D(nn.Module):
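    """3D residual blocks followed by an optional 2x temporal downsampler.

    With ``add_downsample=True``, ``TemporalDownsampler3D`` halves the number
    of frames (``temporal_downsample_factor = 2``); the spatial resolution is
    left unchanged.
    """
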
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_downsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_downsample:
self.downsampler = TemporalDownsampler3D(out_channels, out_channels)
self.temporal_downsample_factor = 2
else:
self.downsampler = None
self.temporal_downsample_factor = 1
self.spatial_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv in self.convs:
x = conv(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.downsampler is not None:
x = self.downsampler(x)
return x


class SpatialTemporalDownBlock3D(nn.Module):
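    """3D residual blocks followed by an optional joint 2x downsampler.

    With ``add_downsample=True``, ``SpatialTemporalDownsampler3D`` halves
    both the spatial and the temporal resolution (both factors are 2).
    """
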
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_downsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_downsample:
self.downsampler = SpatialTemporalDownsampler3D(out_channels, out_channels)
self.spatial_downsample_factor = 2
self.temporal_downsample_factor = 2
else:
self.downsampler = None
self.spatial_downsample_factor = 1
self.temporal_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv in self.convs:
x = conv(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.downsampler is not None:
x = self.downsampler(x)
return x


class SpatialAttnDownBlock3D(nn.Module):
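    """3D residual blocks interleaved with spatial self-attention.

    Each ``ResidualBlock3D`` is followed by a ``SpatialAttention`` layer with
    ``out_channels // attention_head_dim`` heads, then an optional global
    context block and an optional 2x spatial downsampler.
    """
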
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_downsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
self.attentions = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
self.attentions.append(
SpatialAttention(
out_channels,
nheads=out_channels // attention_head_dim,
head_dim=attention_head_dim,
bias=True,
upcast_softmax=True,
norm_num_groups=norm_num_groups,
eps=norm_eps,
rescale_output_factor=output_scale_factor,
residual_connection=True,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_downsample:
self.downsampler = SpatialDownsampler3D(out_channels, out_channels)
self.spatial_downsample_factor = 2
else:
self.downsampler = None
self.spatial_downsample_factor = 1
self.temporal_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv, attn in zip(self.convs, self.attentions):
x = conv(x)
x = attn(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.downsampler is not None:
x = self.downsampler(x)
return x


class TemporalAttnDownBlock3D(nn.Module):
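    """3D residual blocks interleaved with temporal self-attention.

    Each ``ResidualBlock3D`` is followed by a ``TemporalAttention`` layer
    with ``out_channels // attention_head_dim`` heads, then an optional
    global context block and an optional 2x temporal downsampler.
    """
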
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
dropout: float = 0.0,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
add_gc_block: bool = False,
add_downsample: bool = True,
):
super().__init__()
self.convs = nn.ModuleList([])
self.attentions = nn.ModuleList([])
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.convs.append(
ResidualBlock3D(
in_channels=in_channels,
out_channels=out_channels,
non_linearity=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
dropout=dropout,
output_scale_factor=output_scale_factor,
)
)
self.attentions.append(
TemporalAttention(
out_channels,
nheads=out_channels // attention_head_dim,
head_dim=attention_head_dim,
bias=True,
upcast_softmax=True,
norm_num_groups=norm_num_groups,
eps=norm_eps,
rescale_output_factor=output_scale_factor,
residual_connection=True,
)
)
if add_gc_block:
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
else:
self.gc_block = None
if add_downsample:
self.downsampler = TemporalDownsampler3D(out_channels, out_channels)
self.temporal_downsample_factor = 2
else:
self.downsampler = None
self.temporal_downsample_factor = 1
self.spatial_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
for conv, attn in zip(self.convs, self.attentions):
x = conv(x)
x = attn(x)
if self.gc_block is not None:
x = self.gc_block(x)
if self.downsampler is not None:
x = self.downsampler(x)
return x
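

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module). It assumes 5D video tensors shaped
    # (batch, channels, frames, height, width) and channel counts divisible
    # by norm_num_groups=32. Run it as a module, e.g.
    # `python -m <package>.<this_module>`, so the relative imports resolve.
    block = get_down_block(
        "SpatialDownBlock3D",
        in_channels=32,
        out_channels=64,
        num_layers=2,
        act_fn="silu",
    )
    x = torch.randn(1, 32, 8, 64, 64)
    y = block(x)
    # SpatialDownsampler3D is expected to halve height and width only,
    # so y should be roughly (1, 64, 8, 32, 32).
    print(y.shape)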