KingNish committed on
Commit
b4511c9
·
verified ·
1 Parent(s): 4e8aca1

Upload ./vocos/models.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vocos/models.py +156 -0
vocos/models.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.utils import weight_norm
6
+
7
+ from vocos.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm
8
+
9
+
10
class Backbone(nn.Module):
    """Abstract parent of every generator backbone.

    A backbone keeps the temporal resolution unchanged from its first layer to its last;
    concrete subclasses supply the actual computation by overriding `forward`.
    """

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Map input features to hidden representations (to be overridden).

        Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                        C is the number of feature channels, and L is the sequence length.
            **kwargs: Extra, subclass-specific inputs; unused by this base class.

        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
                    and H denotes the model dimension.

        Raises:
            NotImplementedError: Always, since the base class defines no computation.
        """
        message = "Subclasses must implement the forward method."
        raise NotImplementedError(message)
24
+
25
+
26
class VocosBackbone(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization.

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Any falsy value
            (None or 0) falls back to `1 / num_layers`. Defaults to None.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional model. Defaults to None.
        ckpt (str, optional): Path to a checkpoint readable by `torch.load`. Entries whose key starts
            with `backbone.` and whose shape matches the corresponding parameter of this module are
            loaded; all other parameters keep their fresh initialisation. Defaults to None.
    """

    # Prefix under which the backbone's tensors are stored inside a checkpoint.
    _CKPT_PREFIX = 'backbone.'

    def __init__(
        self,
        input_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
        ckpt: Optional[str] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    adanorm_num_embeddings=adanorm_num_embeddings,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)

        # Random initialisation must happen BEFORE the checkpoint is loaded; running it
        # afterwards would overwrite every pretrained Conv1d/Linear weight.
        self.apply(self._init_weights)

        if ckpt is not None:
            # NOTE(security): torch.load unpickles the file; only pass checkpoints from a
            # trusted source (or restrict loading with `weights_only=True` where supported).
            state_dict = torch.load(ckpt, map_location='cpu')
            self.load_state_dict(self._fuzzy_load_state_dict(state_dict))

    def _fuzzy_load_state_dict(self, state_dict):
        """
        Build a state dict for this module from a (possibly larger or mismatched) checkpoint.

        Args:
            state_dict (dict): Mapping of checkpoint keys to tensors.

        Returns:
            dict: Mapping of this module's parameter names to tensors. Only checkpoint entries whose
            key starts with `backbone.` and names a parameter of this module are considered. When the
            shapes agree the checkpoint tensor is used; otherwise the module's current tensor is kept,
            so the parameter retains its existing initialisation.
        """
        prefix = self._CKPT_PREFIX
        # Take the module's state once instead of rebuilding it for every checkpoint key.
        own_state = self.state_dict()

        new_state_dict = {}
        for key, tensor in state_dict.items():
            if not key.startswith(prefix):
                continue
            name = key[len(prefix):]
            if name not in own_state:
                # Checkpoint entry has no counterpart in this module; ignore it.
                continue
            current = own_state[name]
            new_state_dict[name] = tensor if tensor.shape == current.shape else current
        return new_state_dict

    def _init_weights(self, m):
        """Initialise Conv1d/Linear layers: truncated normal (std=0.02) weights and zero biases."""
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, L).
            **kwargs: May contain `bandwidth_id`, the conditioning id forwarded to the adaptive
                normalisation layers. Required when the model was built with `adanorm_num_embeddings`.

        Returns:
            Tensor: Output of shape (B, L, H), where H is the model dimension.
        """
        bandwidth_id = kwargs.get('bandwidth_id', None)
        x = self.embed(x)
        # The normalisation layers act on the last axis, hence the transposes around them.
        if self.adanorm:
            assert bandwidth_id is not None, "bandwidth_id is required for a conditional model"
            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
        else:
            x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x, cond_embedding_id=bandwidth_id)
        x = self.final_layer_norm(x.transpose(1, 2))
        return x
111
+
112
+
113
class VocosResNetBackbone(Backbone):
    """
    Backbone for Vocos assembled from a stack of ResBlock1 modules.

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        num_blocks (int): Number of ResBlock1 blocks.
        layer_scale_init_value (float, optional): Initial value for layer scaling. A falsy value
            (None or 0) is replaced by `1 / num_blocks / 3`. Defaults to None.
    """

    def __init__(
        self, input_channels, dim, num_blocks, layer_scale_init_value=None,
    ):
        super().__init__()
        self.input_channels = input_channels

        # Weight-normalised input projection from feature channels to the hidden dimension.
        projection = nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
        self.embed = weight_norm(projection)

        if not layer_scale_init_value:
            layer_scale_init_value = 1 / num_blocks / 3

        blocks = []
        for _ in range(num_blocks):
            blocks.append(ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value))
        self.resnet = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Embed `x`, run it through the residual stack and swap the last two axes of the result."""
        hidden = self.resnet(self.embed(x))
        return hidden.transpose(1, 2)
140
+
141
if __name__ == '__main__':
    # Smoke test: build a backbone from a pretrained checkpoint and push a random batch through it.
    backbone_config = dict(
        input_channels=1024,
        dim=512,
        intermediate_dim=1536,
        num_layers=8,
        ckpt="/root/OpenMusicVoco/vocos/pretrained.pth",
    )
    backbone = VocosBackbone(**backbone_config)

    # Random batch of 2 sequences, 1024 feature channels, 100 frames each.
    dummy_features = torch.randn(2, 1024, 100)

    # Forward pass
    hidden = backbone(dummy_features)
    print(hidden.shape)  # torch.Size([2, 100, 512])