3v324v23's picture
Add code
a84a65c
raw
history blame
2.92 kB
#
"""Command-line for audio compression."""
import os
import torch
from omegaconf import OmegaConf
import logging
from ..abs_tokenizer import AbsTokenizer
from .models.soundstream import SoundStream
import sys
class AudioTokenizer(AbsTokenizer):
    """SoundStream-based audio tokenizer with a fixed bandwidth of 1.5 kbps.

    Audio is encoded at 50 frames per second with ``n_codebook = 3`` codebook
    indices per frame (``bw = 1.5`` selects 3 codebooks). Each index is in
    ``[0, 1023]`` (``dim_codebook = 1024``, i.e. 10 bits), giving
    50 * 3 * 10 bits/s = 1.5 kbps.
    """

    def __init__(self,
                 ckpt_path,
                 device=torch.device('cuda'),
                 ):
        """Load the SoundStream codec whose ``config.yaml`` sits next to ``ckpt_path``.

        Args:
            ckpt_path: path to the codec checkpoint. A ``config.yaml`` describing
                the generator must exist in the same directory.
            device: device the codec model is moved to after loading.

        Raises:
            ValueError: if ``config.yaml`` is not found next to ``ckpt_path``.
        """
        super(AudioTokenizer, self).__init__()
        # GPU is only for offline tokenization.
        # So, when distributed training is launched, this should still be on CPU.
        self.device = device
        config_path = os.path.join(os.path.dirname(ckpt_path), 'config.yaml')
        if not os.path.isfile(config_path):
            raise ValueError(f"{config_path} file does not exist.")
        config = OmegaConf.load(config_path)
        self.ckpt_path = ckpt_path
        logging.info(f"using config {config_path} and model {self.ckpt_path}")
        self.soundstream = self.build_codec_model(config)

        # Codec properties.
        self.sr = 16000           # sample rate used with this codec
        self.dim_codebook = 1024  # entries per codebook -> indices in [0, 1023]
        self.n_codebook = 3
        self.bw = 1.5             # bw=1.5 ---> 3 codebooks
        # 50 frames per second, n_codebook tokens per frame.
        self.freq = self.n_codebook * 50
        # One past the largest id if every codebook gets its own id range of
        # size dim_codebook (presumably the scheme tokenize() uses; TODO confirm).
        self.mask_id = self.dim_codebook * self.n_codebook

    def build_codec_model(self, config):
        """Instantiate the generator named in ``config`` and load its weights.

        The state dict stored under ``'codec_model'`` in ``self.ckpt_path`` is
        loaded on CPU. The model is then switched to eval mode and moved to
        ``self.device``.
        """
        # SECURITY: ``eval`` executes ``config.generator.name`` as Python code,
        # and ``torch.load`` may unpickle arbitrary objects (depending on the
        # torch version / ``weights_only``). Only use trusted config and
        # checkpoint files.
        model = eval(config.generator.name)(**config.generator.config)
        parameter_dict = torch.load(self.ckpt_path, map_location='cpu')
        model.load_state_dict(parameter_dict['codec_model'])  # load model
        # The codec is used only for inference (encode/decode run under
        # no_grad), so disable training-mode behaviour such as dropout and
        # batch-norm statistic updates.
        # NOTE(review): a later ``.train()`` call on a parent module would flip
        # this back.
        model.eval()
        model = model.to(self.device)
        return model

    @torch.no_grad()
    def encode(self, wav):
        """Encode a waveform into SoundStream codes.

        ``wav`` gains a channel axis via ``unsqueeze(1)``. Judging by the shape
        note below, the input is presumably (1, len), which becomes (1, 1, len);
        TODO confirm.

        The input is intentionally NOT moved to ``self.device``. The caller must
        place it on the same device as the codec. This is possibly because the
        module can be moved with ``.cuda()`` after construction, which would
        leave ``self.device`` stale.
        """
        wav = wav.unsqueeze(1)  # (1, 1, len)
        compressed = self.soundstream.encoder(wav)  # [n_codebook, 1, n_frames]
        return compressed

    @torch.no_grad()
    def decode(self, audio):
        """Decode codes back into a waveform, warning on stderr if it clips."""
        out = self.soundstream.decoder(audio)
        check_clipping(out, rescale=False)
        return out
def check_clipping(wav, rescale, *, limit=0.99):
    """Print a warning to stderr when the waveform's peak magnitude exceeds ``limit``.

    Args:
        wav: waveform tensor (must support ``numel()`` and ``.abs().max()``).
        rescale: if true, skip the check entirely (the output is rescaled).
        limit: peak absolute amplitude above which a warning is printed.
    """
    if rescale:
        return
    # ``max()`` raises on an empty tensor; an empty waveform cannot clip.
    if wav.numel() == 0:
        return
    # Convert to a Python float so the message shows a number, not a tensor repr.
    mx = float(wav.abs().max())
    if mx > limit:
        print(
            f"Clipping!! max scale {mx}, limit is {limit}. "
            "To avoid clipping, use the `-r` option to rescale the output.",
            file=sys.stderr)
if __name__ == '__main__':
    # Small CLI: tokenize one audio file and print its codec tokens.
    # The original invocation omitted the required ``ckpt_path`` and always
    # raised TypeError; the checkpoint path is now taken from the command line.
    import argparse

    parser = argparse.ArgumentParser(
        description="Tokenize an audio file with the SoundStream codec.")
    parser.add_argument(
        'ckpt_path',
        help="codec checkpoint; config.yaml must be in the same directory")
    parser.add_argument(
        'wav', nargs='?',
        default='/home/v-dongyang/data/FSD/mnt/fast/nobackup/scratch4weeks/xm00178/WavCaps/data/waveforms/FreeSound_flac/537271.flac',
        help="audio file to tokenize")
    args = parser.parse_args()

    tokenizer = AudioTokenizer(args.ckpt_path, device=torch.device('cuda:0')).cuda()
    codec = tokenizer.tokenize(args.wav)
    print(codec)