import onnxruntime
import torch
import torch.nn.functional as F

from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from tools.vqgan.extract_vq import get_model

# Module-level so other tools can import it alongside the exporters.
PAD_TOKEN_ID = torch.LongTensor([CODEBOOK_PAD_TOKEN_ID])


class Encoder(torch.nn.Module):
    """ONNX-exportable wrapper around the firefly GAN-VQ encoder path.

    Pipeline: audio -> mel spectrogram -> backbone features -> downsample ->
    grouped residual-FSQ code indices.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # ONNX has no complex tensor type; force the STFT to return real pairs.
        self.model.spec_transform.spectrogram.return_complex = False

    def forward(self, audios):
        """Encode raw audio into flattened codebook indices.

        Args:
            audios: waveform tensor; presumably (batch, 1, samples) as in
                `main` below — confirm against the spec_transform contract.

        Returns:
            Long tensor of shape (batch, groups * residuals, frames).
        """
        mels = self.model.spec_transform(audios)
        encoded_features = self.model.backbone(mels)
        z = self.model.quantizer.downsample(encoded_features)
        # residual_fsq returns (quantized, indices); indices is (g, b, l, r).
        _, indices = self.model.quantizer.residual_fsq(z.transpose(-2, -1))
        _, b, l, _ = indices.shape
        # Flatten (group, residual) into a single codebook axis: (b, g*r, l).
        return indices.permute(1, 0, 3, 2).long().view(b, -1, l)


class Decoder(torch.nn.Module):
    """ONNX-exportable wrapper around the firefly GAN-VQ decoder path.

    Reimplements GroupedResidualFSQ's index-to-latent lookup with plain
    tensor ops so that torch.onnx.export can trace it, then runs the
    upsampler and vocoder head to produce a waveform.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Force inference-only behaviour in the traced graph.
        self.model.head.training = False
        self.model.head.checkpointing = False

    def get_codes_from_indices(self, cur_index, indices):
        """Gather codebook vectors for one FSQ group.

        Args:
            cur_index: index of the group's residual quantizer in `rvqs`.
            indices: (batch, quantizers, frames) code indices; -1 marks
                levels dropped by quantizer dropout.

        Returns:
            (quantizers, batch, frames, code_dim) codebook vectors, zeroed
            at positions whose index was -1, scaled per-level.
        """
        # Hoist the repeated deep attribute chain; ModuleList indexing has
        # no side effects, so this is purely a readability/speed win.
        rvq = self.model.quantizer.residual_fsq.rvqs[cur_index]
        _, quantize_dim, _ = indices.shape
        d_dim = rvq.codebooks.shape[2]

        if quantize_dim < rvq.num_quantizers:
            # Fewer code levels than quantizers is only valid for models
            # trained with quantizer dropout; pad the missing levels with -1.
            assert (
                rvq.quantize_dropout > 0.0
            ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            indices = F.pad(
                indices, (0, rvq.num_quantizers - quantize_dim), value=-1
            )

        mask = indices == -1
        # Replace -1 with a dummy index 0 so gather is in-bounds; the
        # gathered vectors at those positions are zeroed just below.
        indices = indices.masked_fill(mask, 0)

        all_codes = torch.gather(
            rvq.codebooks.unsqueeze(1),
            dim=2,
            index=indices.permute(2, 0, 1).unsqueeze(-1).repeat(1, 1, 1, d_dim),
        )
        all_codes = all_codes.masked_fill(mask.permute(2, 0, 1).unsqueeze(-1), 0.0)

        scales = rvq.scales.unsqueeze(1).unsqueeze(1)
        return all_codes * scales

    def get_output_from_indices(self, cur_index, indices):
        """Sum the residual code levels and project to the group's output dim."""
        codes = self.get_codes_from_indices(cur_index, indices)
        codes_summed = codes.sum(dim=0)
        return self.model.quantizer.residual_fsq.rvqs[cur_index].project_out(
            codes_summed
        )

    def forward(self, indices) -> torch.Tensor:
        """Decode flattened codebook indices back into a waveform.

        Args:
            indices: (batch, groups * residuals, frames) long tensor, as
                produced by Encoder.forward.

        Returns:
            Waveform tensor from the vocoder head.
        """
        batch_size, _, length = indices.shape
        rfsq = self.model.quantizer.residual_fsq
        dims = rfsq.dim
        groups = rfsq.groups
        dim_per_group = dims // groups

        # (b, g*r, l) -> (g, b, l, r): inverse of Encoder's flattening.
        # indices = rearrange(indices, "b (g r) l -> g b l r", g=groups)
        indices = indices.view(batch_size, groups, -1, length).permute(1, 0, 3, 2)

        # Dequantize each group into its slice of the latent.
        # z_q = self.model.quantizer.residual_fsq.get_output_from_indices(indices)
        z_q = torch.empty((batch_size, length, dims))
        for i in range(groups):
            z_q[:, :, i * dim_per_group : (i + 1) * dim_per_group] = (
                self.get_output_from_indices(i, indices[i])
            )

        z = self.model.quantizer.upsample(z_q.transpose(1, 2))
        return self.model.head(z)


def main(firefly_gan_vq_path, llama_path, export_prefix):
    """Export the GAN-VQ encoder/decoder to ONNX and sanity-check the graphs.

    Args:
        firefly_gan_vq_path: checkpoint path for the firefly GAN-VQ model.
        llama_path: unused; kept for call-site compatibility.
        export_prefix: filename prefix for the emitted .onnx files.
    """
    gan_model = get_model("firefly_gan_vq", firefly_gan_vq_path, device="cpu")
    enc = Encoder(gan_model)
    dec = Decoder(gan_model)

    audio_example = torch.randn(1, 1, 96000)
    indices = enc(audio_example)

    torch.onnx.export(
        enc,
        audio_example,
        f"{export_prefix}encoder.onnx",
        dynamic_axes={
            "audio": {0: "batch_size", 2: "audio_length"},
        },
        do_constant_folding=False,
        opset_version=18,
        verbose=False,
        input_names=["audio"],
        output_names=["prompt"],
    )

    torch.onnx.export(
        dec,
        indices,
        f"{export_prefix}decoder.onnx",
        dynamic_axes={
            "prompt": {0: "batch_size", 2: "frame_count"},
        },
        do_constant_folding=False,
        opset_version=18,
        verbose=False,
        input_names=["prompt"],
        output_names=["audio"],
    )

    # Use a longer input than the export example to exercise the dynamic axes.
    test_example = torch.randn(1, 1, 96000 * 5)

    encoder_session = onnxruntime.InferenceSession(f"{export_prefix}encoder.onnx")
    decoder_session = onnxruntime.InferenceSession(f"{export_prefix}decoder.onnx")

    # check graph has no error
    onnx_enc_out = encoder_session.run(["prompt"], {"audio": test_example.numpy()})[0]
    torch_enc_out = enc(test_example)
    onnx_dec_out = decoder_session.run(["audio"], {"prompt": onnx_enc_out})[0]
    torch_dec_out = dec(torch_enc_out)

    # FIX: the original computed the ONNX and PyTorch outputs but never
    # compared them, so a numerically broken export would pass silently.
    # Report the divergence (non-fatal, keeping the best-effort behaviour).
    enc_diff = (torch_enc_out - torch.from_numpy(onnx_enc_out)).abs().max().item()
    dec_diff = (
        (torch_dec_out.detach() - torch.from_numpy(onnx_dec_out)).abs().max().item()
    )
    print(f"max |torch - onnx|: encoder={enc_diff}, decoder={dec_diff}")


if __name__ == "__main__":
    main(
        "checkpoints/pre/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
        None,
        "test_",
    )