import torch

from dac.model import DAC
from torch import nn
from transformers import PreTrainedModel
from transformers.models.encodec.modeling_encodec import EncodecDecoderOutput, EncodecEncoderOutput

from .configuration_dac import DACConfig


# model doesn't support batching yet
class DACModel(PreTrainedModel):
    config_class = DACConfig
    main_input_name = "input_values"

    def __init__(self, config):
        super().__init__(config)
        self.model = DAC(
            n_codebooks=config.num_codebooks,
            latent_dim=config.latent_dim,
            codebook_size=config.codebook_size,
        )

        # strip the weight norm registered by DAC, then re-register it through
        # the torch API preferred in `apply_weight_norm` below
        self.remove_weight_norm()
        self.apply_weight_norm()

    def encode(
        self, input_values, padding_mask=None, bandwidth=None, return_dict=None, n_quantizers=None, sample_rate=None
    ):
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
            bandwidth (`float`, *optional*):
                Not used, kept to have the same interface as HF encodec.
            n_quantizers (`int`, *optional*):
                Number of quantizers to use. If `None`, all quantizers are used.
            sample_rate (`int`, *optional*):
                Sampling rate of the input signal.

        Returns:
            A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
            factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
            `codebook` of shape `[batch_size, num_codebooks, frames]`. `scale` is not used here.
        """
        _, channels, input_length = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")

        audio_data = self.model.preprocess(input_values, sample_rate)

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # TODO: for now, no chunk length
        chunk_length = None  # self.config.chunk_length
        if chunk_length is None:
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.config.chunk_stride

        if padding_mask is None:
            padding_mask = torch.ones_like(input_values).bool()

        encoded_frames = []
        scales = []

        step = chunk_length - stride
        if (input_length % stride) - step != 0:
            raise ValueError(
                "The input length is not properly padded for batched chunked decoding. Make sure to pad the input correctly."
            )

        for offset in range(0, input_length - step, stride):
            mask = padding_mask[..., offset : offset + chunk_length].bool()
            frame = audio_data[:, :, offset : offset + chunk_length]

            scale = None

            _, encoded_frame, _, _, _ = self.model.encode(frame, n_quantizers=n_quantizers)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        encoded_frames = torch.stack(encoded_frames)

        if not return_dict:
            return (encoded_frames, scales)

        return EncodecEncoderOutput(encoded_frames, scales)
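
    # Illustrative usage sketch (not part of the original file; assumes a mono
    # 44.1 kHz waveform and a default `DACConfig`):
    #
    #   model = DACModel(DACConfig())
    #   waveform = torch.randn(1, 1, 44100)  # (batch_size, channels, sequence_length)
    #   enc = model.encode(waveform, sample_rate=44100)
    #   enc.audio_codes                      # (1, batch_size, num_codebooks, frames)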

    def decode(
        self,
        audio_codes,
        audio_scales,
        padding_mask=None,
        return_dict=None,
    ):
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.FloatTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
                Not used, kept to have the same interface as HF encodec.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`. Not used yet, kept to have the same interface as HF
                encodec.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # use `is not None` so that an explicit `return_dict=False` is not overridden by the config
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # TODO: for now, no chunk length
        if len(audio_codes) != 1:
            raise ValueError(f"Expected one frame, got {len(audio_codes)}")

        audio_values = self.model.quantizer.from_codes(audio_codes.squeeze(0))[0]
        audio_values = self.model.decode(audio_values)

        if not return_dict:
            return (audio_values,)
        return EncodecDecoderOutput(audio_values)
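
    # Illustrative round-trip sketch (not part of the original file): decoding the
    # single-frame output of `encode` back to a waveform.
    #
    #   dec = model.decode(enc.audio_codes, enc.audio_scales)
    #   dec.audio_values  # may come back slightly longer than the input; trim the tail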

    def forward(self, tensor):
        raise NotImplementedError("`DACModel.forward` is not implemented yet")

    def apply_weight_norm(self):
        # prefer the parametrization API when this torch version provides it,
        # since `nn.utils.weight_norm` is deprecated in recent torch releases
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        def _apply_weight_norm(module):
            if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
                weight_norm(module)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        def _remove_weight_norm(module):
            if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
                nn.utils.remove_weight_norm(module)

        self.apply(_remove_weight_norm)
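

# Illustrative sketch (an assumption, not from the original file): weight norm is
# typically stripped before exporting or freezing weights for inference, e.g.
#
#   model = DACModel(DACConfig())
#   model.remove_weight_norm()  # folds the (g, v) parametrization back into `weight`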