File size: 1,891 Bytes
20d6bb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from __future__ import annotations

from typing import Any, Dict, Tuple, Union, Optional

import torch
import yaml
from torch import nn
from .heads import ISTFTHead
from .models import VocosBackbone


class Vocos(nn.Module):
    """Fourier-based neural vocoder that turns feature tensors into audio.

    The module is a two-stage pipeline: a ``VocosBackbone`` processes the
    input features, and an ``ISTFTHead`` converts the backbone output into
    the final waveform. Both stages are configured from ``args.vocos``.
    """

    def __init__(self, args: Any) -> None:
        """Build the backbone and head from a nested configuration object.

        Args:
            args: Configuration object exposing ``args.vocos.backbone`` (with
                ``input_channels``, ``dim``, ``intermediate_dim`` and
                ``num_layers``) and ``args.vocos.head`` (with ``dim``,
                ``n_fft``, ``hop_length`` and ``padding``).
        """
        super().__init__()

        # Resolve each sub-config once instead of repeating the full path.
        backbone_cfg = args.vocos.backbone
        head_cfg = args.vocos.head

        self.backbone = VocosBackbone(
            input_channels=backbone_cfg.input_channels,
            dim=backbone_cfg.dim,
            intermediate_dim=backbone_cfg.intermediate_dim,
            num_layers=backbone_cfg.num_layers,
        )
        self.head = ISTFTHead(
            dim=head_cfg.dim,
            n_fft=head_cfg.n_fft,
            hop_length=head_cfg.hop_length,
            padding=head_cfg.padding,
        )

    def forward(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Decode an audio waveform from precomputed features.

        Args:
            features_input (Tensor): Features of shape (B, C, L), where B is
                the batch size, C the feature dimension and L the sequence
                length.
            **kwargs: Extra keyword arguments, forwarded unchanged to the
                backbone only; the head receives just the backbone output.

        Returns:
            Tensor: The reconstructed audio waveform of shape (B, T).
        """
        return self.head(self.backbone(features_input, **kwargs))