---
base_model:
- descript/dac_44khz
library_name: transformers.js
---

ONNX weights for https://huggingface.co/descript/dac_44khz.

## Inference sample code
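The snippet below assumes `encoder_model.onnx` and `decoder_model.onnx` are available locally. If you are pulling them from a repository on the Hub, a sketch like the following fetches them first (the repo id here is a placeholder for wherever the files are hosted):

```py
from huggingface_hub import hf_hub_download

# Placeholder repo id: point this at the repository hosting the ONNX files
encoder_path = hf_hub_download(repo_id="<onnx-repo-id>", filename="encoder_model.onnx")
decoder_path = hf_hub_download(repo_id="<onnx-repo-id>", filename="decoder_model.onnx")
```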
```py
import numpy as np
import onnxruntime as ort

encoder_session = ort.InferenceSession("encoder_model.onnx")
decoder_session = ort.InferenceSession("decoder_model.onnx")

# Dummy audio batch with shape (batch_size, num_channels, sequence_length)
dummy_encoder_inputs = np.random.randn(1, 1, 44100).astype(np.float32)

# Encode the waveform into discrete audio codes
encoder_inputs = {encoder_session.get_inputs()[0].name: dummy_encoder_inputs}
encoder_outputs = encoder_session.run(None, encoder_inputs)[0]

# Decode the audio codes back into a waveform
decoder_inputs = {decoder_session.get_inputs()[0].name: encoder_outputs}
decoder_outputs = decoder_session.run(None, decoder_inputs)[0]

# Print the results
print("Encoder Output Shape:", encoder_outputs.shape)  # (batch_size, n_codebooks, time_steps)
print("Decoder Output Shape:", decoder_outputs.shape)  # (batch_size, num_channels, sequence_length)
```
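To run the encoder on real audio instead of random data, the model's feature extractor from `transformers` can prepare the input. A minimal sketch, reusing `encoder_session` from above and assuming `librosa` is installed and `speech.wav` stands in for your own recording:

```py
import librosa
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("descript/dac_44khz")
# Resample to the model's expected rate (44.1 kHz) on load
audio, _ = librosa.load("speech.wav", sr=feature_extractor.sampling_rate, mono=True)

# Pads and reshapes the waveform into (batch_size, num_channels, sequence_length)
inputs = feature_extractor(raw_audio=audio, sampling_rate=feature_extractor.sampling_rate, return_tensors="np")
encoder_inputs = {encoder_session.get_inputs()[0].name: inputs["input_values"]}
encoder_outputs = encoder_session.run(None, encoder_inputs)[0]
```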

## Conversion code
```py
import torch
import torch.nn as nn
from transformers import DacModel


class DacEncoder(nn.Module):
    """Wraps DacModel.encode so ONNX export sees a plain tensor-to-tensor forward."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_values):
        return self.model.encode(input_values).audio_codes


class DacDecoder(nn.Module):
    """Maps discrete audio codes back to a waveform via the quantizer and decoder."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, audio_codes):
        quantized_representation = self.model.quantizer.from_codes(audio_codes)[0]
        return self.model.decoder(quantized_representation)


model = DacModel.from_pretrained("descript/dac_44khz")
encoder = DacEncoder(model)
decoder = DacDecoder(model)

# Export encoder
dummy_encoder_inputs = torch.randn((4, 1, 12340))
torch.onnx.export(
    encoder,
    dummy_encoder_inputs,
    "encoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=['input_values'],
    output_names=['audio_codes'],
    dynamic_axes={
        'input_values': {0: 'batch_size', 1: 'num_channels', 2: 'sequence_length'},
        'audio_codes': {0: 'batch_size', 2: 'time_steps'},
    },
)

# Export decoder
dummy_decoder_inputs = torch.randint(model.config.codebook_size, (4, model.config.n_codebooks, 100))
torch.onnx.export(
    decoder,
    dummy_decoder_inputs,
    "decoder_model.onnx",
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=['audio_codes'],
    output_names=['audio_values'],
    dynamic_axes={
        'audio_codes': {0: 'batch_size', 2: 'time_steps'},
        'audio_values': {0: 'batch_size', 1: 'num_channels', 2: 'sequence_length'},
    },
)
```
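After exporting, it is worth sanity-checking the ONNX graphs against the PyTorch wrappers. A minimal sketch, reusing `encoder`, `decoder`, and the dummy inputs from the script above and assuming `onnxruntime` is installed:

```py
import numpy as np
import onnxruntime as ort

# Compare the ONNX encoder against the PyTorch wrapper on the export input
encoder_session = ort.InferenceSession("encoder_model.onnx")
onnx_codes = encoder_session.run(None, {"input_values": dummy_encoder_inputs.numpy()})[0]
with torch.no_grad():
    torch_codes = encoder(dummy_encoder_inputs).numpy()
# Codes are discrete indices, so they should normally agree exactly;
# tiny float differences could in principle flip a borderline code.
print("Encoder codes match:", np.array_equal(onnx_codes, torch_codes))

# Compare the decoder outputs within a small numeric tolerance
decoder_session = ort.InferenceSession("decoder_model.onnx")
onnx_audio = decoder_session.run(None, {"audio_codes": dummy_decoder_inputs.numpy()})[0]
with torch.no_grad():
    torch_audio = decoder(dummy_decoder_inputs).numpy()
print("Decoder max abs diff:", np.abs(onnx_audio - torch_audio).max())
```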