import torch
from torch import nn

from dac.model import DAC
from transformers import PreTrainedModel
from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput

from .configuration_dac import DACConfig


# The model doesn't support batching yet.
class DACModel(PreTrainedModel):
    config_class = DACConfig

    # Set the main input to `input_values` for voice steering.
    main_input_name = "input_values"

    def __init__(self, config):
        super().__init__(config)
        self.model = DAC(
            n_codebooks=config.num_codebooks,
            latent_dim=config.latent_dim,
            codebook_size=config.codebook_size,
        )

        # DAC applies weight norm internally; strip it and re-apply it through the
        # helper below, which prefers the parametrization-based variant on recent
        # PyTorch versions where `nn.utils.weight_norm` is deprecated.
        self.remove_weight_norm()
        self.apply_weight_norm()

    def encode(
        self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
    ):
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
            bandwidth (`float`, *optional*):
                Not used; kept to match the interface of HF Encodec.
            n_quantizers (`int`, *optional*):
                Number of quantizers to use. If `None` (default), all quantizers are used.
            sample_rate (`int`, *optional*):
                Sample rate of the input signal.

        Returns:
            A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
            factors for each chunk when `normalize` is `True`. Each frame is a tuple `(codebook, scale)`, with
            `codebook` of shape `[batch_size, num_codebooks, frames]`. The scale is not used here.
        """
        _, channels, input_length = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")

        audio_data = self.model.preprocess(input_values, sample_rate)

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # TODO: for now, no chunk length
        chunk_length = None  # self.config.chunk_length
        if chunk_length is None:
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.config.chunk_stride

        if padding_mask is None:
            padding_mask = torch.ones_like(input_values).bool()

        encoded_frames = []
        scales = []

        step = chunk_length - stride
        if (input_length % stride) - step != 0:
            raise ValueError(
                "The input length is not properly padded for batched chunked encoding. Make sure to pad the input correctly."
            )

        for offset in range(0, input_length - step, stride):
            mask = padding_mask[..., offset : offset + chunk_length].bool()
            frame = audio_data[:, :, offset : offset + chunk_length]

            # DAC does not rescale chunks; keep a `None` scale for Encodec API parity.
            scale = None

            _, encoded_frame, _, _, _ = self.model.encode(frame, n_quantizers=n_quantizers)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        encoded_frames = torch.stack(encoded_frames)

        if not return_dict:
            return (encoded_frames, scales)

        return EncodecEncoderOutput(encoded_frames, scales)
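
    # Usage sketch for `encode`, kept as a comment so it never runs at import
    # time. The `DACConfig` field values below are illustrative assumptions,
    # as is the 44.1 kHz mono input (DAC's native sample rate):
    #
    #     config = DACConfig(num_codebooks=9, latent_dim=1024, codebook_size=1024)
    #     model = DACModel(config).eval()
    #     waveform = torch.randn(1, 1, 44100)  # (batch, channels, samples)
    #     with torch.no_grad():
    #         enc = model.encode(waveform, sample_rate=44100)
    #     enc.audio_codes  # (num_chunks, batch, num_codebooks, frames); num_chunks is 1 for now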

    def decode(
        self,
        audio_codes,
        audio_scales,
        padding_mask=None,
        return_dict=None,
    ):
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.FloatTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
                Not used; kept to match the interface of HF Encodec.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`. Not used yet; kept to match the interface of HF Encodec.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # `or` would silently override an explicit `return_dict=False`, so test against `None` instead.
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # TODO: for now, no chunk length
        if len(audio_codes) != 1:
            raise ValueError(f"Expected one frame, got {len(audio_codes)}")

        audio_values = self.model.quantizer.from_codes(audio_codes.squeeze(0))[0]
        audio_values = self.model.decode(audio_values)

        if not return_dict:
            return (audio_values,)
        return EncodecDecoderOutput(audio_values)
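
    # Usage sketch for `decode` (comment only), continuing the `encode` example
    # above: round-trip the codes back to audio. The reconstruction can be
    # slightly longer than the input and may be trimmed to the original length.
    #
    #     with torch.no_grad():
    #         dec = model.decode(enc.audio_codes, enc.audio_scales)
    #     dec.audio_values  # reconstructed waveform of shape (batch, 1, ~samples)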

    def forward(self, tensor):
        raise NotImplementedError("`DACModel.forward` is not implemented yet")

    def apply_weight_norm(self):
        # Prefer the parametrization-based weight norm when available (PyTorch >= 2.1),
        # since `nn.utils.weight_norm` is deprecated there.
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        def _apply_weight_norm(module):
            if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
                weight_norm(module)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        def _remove_weight_norm(module):
            if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
                nn.utils.remove_weight_norm(module)

        self.apply(_remove_weight_norm)
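

# Minimal encode/decode smoke test: a hedged sketch, not part of the published
# model code. It assumes `descript-audio-codec` and `transformers` are installed
# and that `DACConfig` accepts the field names read in `__init__` above. Because
# of the relative import at the top of this file, run it as a module from its
# parent package (`python -m <package>.modeling_dac`), not as a standalone script.
if __name__ == "__main__":
    config = DACConfig(num_codebooks=9, latent_dim=1024, codebook_size=1024)
    model = DACModel(config).eval()

    # One second of random mono audio at DAC's native 44.1 kHz sample rate.
    waveform = torch.randn(1, 1, 44100)
    with torch.no_grad():
        enc = model.encode(waveform, sample_rate=44100)
        dec = model.decode(enc.audio_codes, enc.audio_scales)

    print("codes shape:", tuple(enc.audio_codes.shape))
    print("audio shape:", tuple(dec.audio_values.shape))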