# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from inspiremusic.utils.mask import make_pad_mask
from inspiremusic.music_tokenizer.vqvae import VQVAE

class MaskedDiff(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 128,
                 output_type: str = "mel",
                 vocab_size: int = 4096,
                 input_frame_rate: int = 50,
                 only_mask_loss: bool = True,
                 encoder: torch.nn.Module = None,
                 length_regulator: torch.nn.Module = None,
                 decoder: torch.nn.Module = None,
                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80,
                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 128, 'sampling_rate': 48000,
                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 48000},
                 generator_model_dir: str = "pretrained_models/InspireMusic-Base/music_tokenizer",
                 num_codebooks: int = 4
                 ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = nn.Embedding(vocab_size, input_size)

        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss
        self.quantizer = VQVAE(f'{generator_model_dir}/config.json',
                               f'{generator_model_dir}/model.pt', with_encoder=True).quantizer
        self.quantizer.eval()
        self.num_codebooks = num_codebooks
        self.cond = None
        self.interpolate = False

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:

        audio_token = batch['acoustic_token'].to(device)
        audio_token_len = batch['acoustic_token_len'].to(device)
        # reshape flattened acoustic tokens to (batch, frames, codebooks)
        audio_token = audio_token.view(audio_token.size(0), -1, self.num_codebooks)
        if "semantic_token" not in batch:
            # fall back to the first acoustic codebook as the input token stream
            token = audio_token[:, :, 0]
            token_len = (audio_token_len / self.num_codebooks).long()
        else:
            token = batch['semantic_token'].to(device)
            token_len = batch['semantic_token_len'].to(device)

        # target acoustic features from the frozen VQ-VAE quantizer
        with torch.no_grad():
            feat = self.quantizer.embed(audio_token)
            feat_len = (audio_token_len / self.num_codebooks).long()

        token = self.input_embedding(token)
        h, h_lengths = self.encoder(token, token_len)
        h, h_lengths = self.length_regulator(h, feat_len)

        # get conditions: with 50% probability per item, expose up to the first
        # 30% of the target features as a conditioning prompt (only when
        # self.cond is enabled)
        if self.cond:
            conds = torch.zeros(feat.shape, device=token.device)
            for i, j in enumerate(feat_len):
                if random.random() < 0.5:
                    continue
                index = random.randint(0, int(0.3 * j))
                conds[i, :index] = feat[i, :index]
            conds = conds.transpose(1, 2)
        else:
            conds = None

        mask = (~make_pad_mask(feat_len)).to(h)

        loss, _ = self.decoder.compute_loss(
            feat,
            mask.unsqueeze(1),
            h.transpose(1, 2).contiguous(),
            None,
            cond=conds
        )

        return {'loss': loss}

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  sample_rate):
        assert token.shape[0] == 1

        # clamp away negative ids before the embedding lookup
        token = self.input_embedding(torch.clamp(token, min=0))
        h, h_lengths = self.encoder(token, token_len)

        # double the target feature length for 48 kHz output
        if sample_rate == 48000:
            token_len = 2 * token_len

        h, h_lengths = self.length_regulator(h, token_len)

        # get conditions
        conds = None

        mask = (~make_pad_mask(token_len)).to(h)
        feat = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=None,
            cond=conds,
            n_timesteps=10
        )
        return feat
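
# Minimal usage sketch: the acoustic-token reshape that forward() relies on,
# shown with dummy tensors. The 4-codebook layout and 4096-entry vocabulary
# follow the defaults above; the concrete batch and frame counts are
# illustrative assumptions only.
if __name__ == "__main__":
    num_codebooks = 4
    # flattened acoustic tokens: batch of 1, 6 frames x 4 codebooks = 24 ids
    acoustic_token = torch.randint(0, 4096, (1, 24))
    acoustic_token_len = torch.tensor([24])
    # view as (batch, frames, codebooks), as done at the start of MaskedDiff.forward
    audio_token = acoustic_token.view(acoustic_token.size(0), -1, num_codebooks)
    token = audio_token[:, :, 0]                              # first-codebook stream
    token_len = (acoustic_token_len / num_codebooks).long()   # frames per item
    print(audio_token.shape, token.shape, token_len)          # (1, 6, 4) (1, 6) tensor([6])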