Spaces:
Runtime error
Runtime error
| # Copyright (c) 2024 Alibaba Inc | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import json | |
| import torch | |
| import torch.nn as nn | |
| from inspiremusic.music_tokenizer.env import AttrDict | |
| from inspiremusic.music_tokenizer.models import Encoder | |
| from inspiremusic.music_tokenizer.models import Generator | |
| from inspiremusic.music_tokenizer.models import Quantizer | |
class VQVAE(nn.Module):
    """Wrap a music-tokenizer VQ-VAE checkpoint: quantizer + generator, optional encoder.

    The sub-modules are built from a JSON config (wrapped in ``AttrDict``) and
    their weights are loaded from the ``'quantizer'``, ``'generator'`` and,
    when requested, ``'encoder'`` entries of a ``torch.load``-ed checkpoint.
    """

    def __init__(self,
                 config_path,
                 ckpt_path,
                 with_encoder=False):
        """Build the sub-modules and load their weights.

        Args:
            config_path: Path to a JSON file whose contents are wrapped in
                ``AttrDict`` and passed to each sub-module constructor.
            ckpt_path: Path to a checkpoint readable by ``torch.load`` holding
                state dicts under ``'generator'`` and ``'quantizer'`` (and
                ``'encoder'`` when ``with_encoder`` is true).
            with_encoder: Also build and load the encoder so that
                :meth:`encode` can be used.
        """
        super(VQVAE, self).__init__()
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
        # hosts. load_state_dict copies the tensors into the freshly built
        # modules, so the modules' own device placement is unaffected.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        with open(config_path, encoding='utf-8') as f:
            json_config = json.load(f)
        self.h = AttrDict(json_config)
        self.quantizer = Quantizer(self.h)
        self.generator = Generator(self.h)
        self.generator.load_state_dict(ckpt['generator'])
        self.quantizer.load_state_dict(ckpt['quantizer'])
        if with_encoder:
            self.encoder = Encoder(self.h)
            self.encoder.load_state_dict(ckpt['encoder'])

    def forward(self, x):
        """Decode codebook indices with the generator.

        Args:
            x: Codebook indices, passed to ``self.quantizer.embed``; presumably
                shaped ``(B, T, Nq)`` — TODO confirm against ``Quantizer.embed``.

        Returns:
            Whatever ``self.generator`` returns for the embedded codes.
        """
        quant_emb = self.quantizer.embed(x)
        return self.generator(quant_emb)

    def encode(self, x):
        """Encode an input tensor into stacked codebook indices.

        Args:
            x: Input tensor whose first dimension is the batch. A 3-D input
                with a trailing singleton dimension is squeezed first. A
                channel dimension is then inserted at position 1 before
                calling the encoder.

        Returns:
            The quantizer's per-codebook codes, each reshaped to
            ``(batch, -1)``, stacked along a new last dimension. The size of
            that dimension is the number of codes the quantizer returns
            (config-dependent).

        Raises:
            AttributeError: if the model was constructed without an encoder.
        """
        # nn.Module raises AttributeError for unregistered submodules, so this
        # also covers the with_encoder=False case with a clearer message.
        if getattr(self, 'encoder', None) is None:
            raise AttributeError(
                "VQVAE.encode requires an encoder; construct VQVAE with "
                "with_encoder=True")
        batch_size = x.size(0)
        if x.dim() == 3 and x.size(-1) == 1:
            x = x.squeeze(-1)
        latent = self.encoder(x.unsqueeze(1))
        # The quantizer returns (quantized, loss, codes); only the codes are needed.
        _, _, codes = self.quantizer(latent)
        codes = [code.reshape(batch_size, -1) for code in codes]
        return torch.stack(codes, -1)