# ConvNeXt V2 building blocks for 1D (audio/sequence) features:
# LayerNorm (channels_first/last), GRN, interpolation, blocks and stages.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
class ConvNextV2LayerNorm(nn.Module):
    r"""LayerNorm supporting two data layouts: channels_last (default) or channels_first.

    channels_last expects the channel axis to be the trailing dimension and
    delegates to the fused ``layer_norm``; channels_first expects inputs of
    shape (batch, channels, length) and normalizes over dim 1 by hand.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(f"Unsupported data format: {data_format}")
        # Affine parameters, initialized to the identity transform.
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        # functional.layer_norm expects a tuple of normalized dims.
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_first":
            # Compute statistics in float32 for stability, then cast back
            # before applying the affine parameters.
            orig_dtype = x.dtype
            xf = x.float()
            mean = xf.mean(1, keepdim=True)
            var = (xf - mean).pow(2).mean(1, keepdim=True)
            xf = (xf - mean) / torch.sqrt(var + self.eps)
            normed = xf.to(dtype=orig_dtype)
            return self.weight[None, :, None] * normed + self.bias[None, :, None]
        # channels_last: use the fused functional implementation.
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
class GRN(nn.Module):
    """Global Response Normalization (ConvNeXt V2) for (batch, seq, dim) inputs."""

    def __init__(self, dim):
        super().__init__()
        # Zero-initialized scale/shift so the layer starts as the identity.
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        # Per-channel L2 response over the sequence axis, normalized by its
        # mean across channels (epsilon guards against division by zero).
        response = torch.norm(x, p=2, dim=1, keepdim=True)
        scale = response / (response.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * scale) + self.beta + x
class InterpolationLayer(nn.Module):
    """Parameter-free resampler: linearly interpolates (B, C, T) features to a target length."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, target_len: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Linear interpolation along the trailing (time) axis.
        return F.interpolate(x, size=target_len, mode='linear')
class ConvNeXtV2Stage(nn.Module):
    """A stack of ConvNeXtV2Blocks with optional down/up-sampling and interpolation.

    Resampling layers run *before* the main block at the given block indices.
    Optional 1x1 projections adapt input/output channel counts, and an optional
    global conditioning tensor (``kwargs['g']``) is added after the input
    projection when ``gin_channels > 0``.
    """

    def __init__(
        self,
        dim: int = 512,
        intermediate_dim: int = 2048,
        num_blocks: int = 1,
        dilation: int = 1,
        downsample_layer_indices: List[int] = None,
        downsample_factors: List[int] = None,
        upsample_layer_indices: List[int] = None,
        upsample_factors: List[int] = None,
        interpolation_layer_indices: List[int] = None,
        input_dim: int = None,
        output_dim: int = None,
        gin_channels: int = 0,
    ):
        super().__init__()
        # Optional strided-conv downsampling layers (norm + conv per site).
        if downsample_layer_indices is not None:
            assert downsample_factors is not None
            self.downsample_blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
                        nn.Conv1d(
                            dim, dim, kernel_size=downsample_factor, stride=downsample_factor
                        ),
                    ) for _, downsample_factor in zip(downsample_layer_indices, downsample_factors)
                ]
            )
            self.downsample_layer_indices = downsample_layer_indices
        else:
            self.downsample_blocks = nn.ModuleList()
            self.downsample_layer_indices = []
        # Optional transposed-conv upsampling layers.
        if upsample_layer_indices is not None:
            assert upsample_factors is not None
            self.upsample_blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        ConvNextV2LayerNorm(dim, data_format="channels_first"),
                        nn.ConvTranspose1d(
                            dim, dim, kernel_size=upsample_factor, stride=upsample_factor
                        ),
                    ) for _, upsample_factor in zip(upsample_layer_indices, upsample_factors)
                ]
            )
            self.upsample_layer_indices = upsample_layer_indices
        else:
            self.upsample_blocks = nn.ModuleList()
            self.upsample_layer_indices = []
        # Optional length-matching interpolation layers (require kwargs['target_len']).
        if interpolation_layer_indices is not None:
            self.interpolation_blocks = nn.ModuleList(
                [
                    InterpolationLayer()
                    for _ in interpolation_layer_indices
                ]
            )
            self.interpolation_layer_indices = interpolation_layer_indices
        else:
            self.interpolation_blocks = nn.ModuleList()
            self.interpolation_layer_indices = []
        # Main residual blocks.
        self.blocks = nn.ModuleList(
            [
                ConvNeXtV2Block(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    dilation=dilation,
                )
                for _ in range(num_blocks)
            ]
        )
        # Optional 1x1 input/output channel projections (identity when dims match).
        if input_dim is not None and input_dim != dim:
            self.input_projection = nn.Conv1d(input_dim, dim, kernel_size=1)
        else:
            self.input_projection = nn.Identity()
        if output_dim is not None and output_dim != dim:
            self.output_projection = nn.Conv1d(dim, output_dim, kernel_size=1)
        else:
            self.output_projection = nn.Identity()
        # Optional global conditioning projection.
        if gin_channels > 0:
            self.gin = nn.Conv1d(gin_channels, dim, kernel_size=1)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Run the stage on ``x`` of shape (B, D, T).

        Requires ``kwargs['g']`` when conditioning is enabled, and
        ``kwargs['target_len']`` when interpolation layers are configured.
        """
        x = self.input_projection(x)  # B, D, T
        if hasattr(self, 'gin'):
            g = kwargs['g']  # assumes g is (B, gin_channels, 1) or broadcastable — TODO confirm
            x = x + self.gin(g)
        # Right-pad with zeros so T divides the product of all downsample strides.
        if len(self.downsample_blocks) > 0:
            total_factor = 1
            for blk in self.downsample_blocks:
                total_factor *= blk[1].stride[0]
            # Fix: the previous `factor - T % factor` formula padded a full extra
            # `factor` frames when T was already a multiple; (-T) % factor is 0
            # in that case. F.pad is also safe when pad_len exceeds T, unlike
            # the old zeros_like-slice trick which silently truncated the pad.
            pad_len = (-x.size(-1)) % total_factor
            if pad_len > 0:
                x = F.pad(x, (0, pad_len))
        # Main blocks, with resampling applied before the block at its index.
        for layer_idx, block in enumerate(self.blocks):
            if layer_idx in self.downsample_layer_indices:
                x = self.downsample_blocks[self.downsample_layer_indices.index(layer_idx)](x)
            if layer_idx in self.upsample_layer_indices:
                x = self.upsample_blocks[self.upsample_layer_indices.index(layer_idx)](x)
            if layer_idx in self.interpolation_layer_indices:
                x = self.interpolation_blocks[self.interpolation_layer_indices.index(layer_idx)](x, target_len=kwargs['target_len'])
            x = block(x)
        x = self.output_projection(x)
        return x

    def setup_caches(self, *args, **kwargs):
        # No-op: kept for interface compatibility with cache-based modules.
        pass
class ConvNeXtV2Block(nn.Module):
    """One ConvNeXt V2 block: depthwise conv -> LayerNorm -> MLP (GELU + GRN), with residual."""

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        # "Same" padding for the dilated kernel-7 depthwise convolution.
        pad = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=pad, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = ConvNextV2LayerNorm(dim, data_format="channels_first")
        # Pointwise (1x1) convs implemented as linear layers on (B, T, D).
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = x
        h = self.dwconv(x)
        h = self.norm(h)
        h = h.transpose(1, 2)  # (B, D, T) -> (B, T, D) for the linear layers
        h = self.pwconv1(h)
        h = self.act(h)
        h = self.grn(h)
        h = self.pwconv2(h)
        h = h.transpose(1, 2)  # back to (B, D, T)
        return skip + h