Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json | |
import torch | |
import torch.nn as nn | |
from inspiremusic.music_tokenizer.env import AttrDict | |
from inspiremusic.music_tokenizer.models import Encoder | |
from inspiremusic.music_tokenizer.models import Generator | |
from inspiremusic.music_tokenizer.models import Quantizer | |
class VQVAE(nn.Module):
    """VQ-VAE wrapper restored from a JSON config and a single checkpoint.

    Builds a ``Quantizer`` and a ``Generator`` (and, optionally, an
    ``Encoder``) from the hyper-parameters in ``config_path``. It then loads
    their weights from the ``'quantizer'``, ``'generator'`` and ``'encoder'``
    entries of the checkpoint dict at ``ckpt_path``.

    Args:
        config_path: Path to a JSON config file. It is parsed into an
            ``AttrDict`` and kept as ``self.h``.
        ckpt_path: Path to a checkpoint readable by ``torch.load`` whose
            top-level keys include ``'generator'`` and ``'quantizer'``
            (and ``'encoder'`` when ``with_encoder`` is True).
        with_encoder: When True, also build and load the encoder so that
            :meth:`encode` can be used. :meth:`forward` does not need it.
    """

    def __init__(self,
                 config_path,
                 ckpt_path,
                 with_encoder=False):
        super().__init__()
        # load_state_dict() copies tensors into the parameters of the modules
        # built below, so the device the checkpoint was saved from does not
        # matter. Mapping to CPU avoids failures on CUDA-less hosts when the
        # checkpoint was saved from GPU.
        # NOTE(security): torch.load unpickles its input; only pass trusted
        # checkpoint files.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(config_path, encoding="utf-8") as f:
            json_config = json.load(f)
        self.h = AttrDict(json_config)
        self.quantizer = Quantizer(self.h)
        self.generator = Generator(self.h)
        self.generator.load_state_dict(ckpt['generator'])
        self.quantizer.load_state_dict(ckpt['quantizer'])
        if with_encoder:
            self.encoder = Encoder(self.h)
            self.encoder.load_state_dict(ckpt['encoder'])

    def forward(self, x):
        """Decode discrete codes with the generator.

        Args:
            x: Code indices handed to ``self.quantizer.embed``. The original
                author noted the shape as (B, T, Nq); TODO confirm against
                ``Quantizer.embed``.

        Returns:
            Whatever ``self.generator`` returns for the embedded codes
            (presumably audio; TODO confirm).
        """
        quant_emb = self.quantizer.embed(x)
        return self.generator(quant_emb)

    def encode(self, x):
        """Encode a batch of signals into stacked code indices.

        Args:
            x: Tensor whose first axis is the batch. A 3-D input whose last
                axis has size 1, i.e. (B, T, 1), is first squeezed to (B, T).

        Returns:
            Tensor of shape (B, L, Nq). Each of the Nq code tensors returned
            by the quantizer is flattened to (B, L) and the results are
            stacked along the last axis. (The original comment suggested
            Nq == 4; that depends on the config.)

        Raises:
            AttributeError: If the model was built with ``with_encoder=False``.
        """
        # Fail with an actionable message instead of nn.Module's generic
        # "object has no attribute 'encoder'" (same exception type).
        encoder = getattr(self, "encoder", None)
        if encoder is None:
            raise AttributeError(
                "VQVAE.encode() requires the encoder; construct the model "
                "with with_encoder=True")
        batch_size = x.size(0)
        if x.dim() == 3 and x.shape[-1] == 1:
            x = x.squeeze(-1)
        # Insert a singleton axis at dim 1: a (B, T) input becomes (B, 1, T).
        c = encoder(x.unsqueeze(1))
        # The quantizer returns a 3-tuple; only the codes (third item) are
        # used here.
        _, _, codes = self.quantizer(c)
        codes = [code.reshape(batch_size, -1) for code in codes]
        return torch.stack(codes, -1)