# (Hugging Face Spaces badge residue removed: "Spaces: Running on Zero")
"""Command-line for audio compression.""" | |
import os | |
import torch | |
from omegaconf import OmegaConf | |
import logging | |
from ..abs_tokenizer import AbsTokenizer | |
from .models.soundstream import SoundStream | |
import sys | |
class AudioTokenizer(AbsTokenizer):
    """SoundStream-based audio tokenizer with a fixed bandwidth of 1.5 kbps.

    Encodes 16 kHz audio at 50 frames per second with 3 codebooks of
    1024 entries each, so the value of every code is in [0, 1023].
    """

    def __init__(self,
                 ckpt_path,
                 device=torch.device('cuda'),
                 ):
        """Load the codec model and set codec properties.

        Args:
            ckpt_path: path to the checkpoint file; a ``config.yaml`` is
                expected to live in the same directory.
            device: device the codec runs on.

        Raises:
            ValueError: if the sibling ``config.yaml`` does not exist.
        """
        super().__init__()
        # GPU is only for offline tokenization.
        # So, when distributed training is launched, this should still be on CPU.
        self.device = device

        config_path = os.path.join(os.path.dirname(ckpt_path), 'config.yaml')
        if not os.path.isfile(config_path):
            raise ValueError(f"{config_path} file does not exist.")
        config = OmegaConf.load(config_path)
        self.ckpt_path = ckpt_path
        logging.info(f"using config {config_path} and model {self.ckpt_path}")
        self.soundstream = self.build_codec_model(config)

        # Codec properties.
        self.sr = 16000                   # sample rate the codec expects
        self.dim_codebook = 1024          # entries per codebook
        self.n_codebook = 3
        self.bw = 1.5                     # bandwidth in kbps; bw=1.5 ---> 3 codebooks
        self.freq = self.n_codebook * 50  # codes emitted per second of audio
        self.mask_id = self.dim_codebook * self.n_codebook  # id reserved for masking

    def build_codec_model(self, config):
        """Instantiate the generator named in ``config`` and load its weights.

        NOTE(review): the original used ``eval(config.generator.name)``, which
        executes arbitrary code from the config file. Replaced with an explicit
        registry; only ``SoundStream`` was importable here, so valid configs
        behave identically.
        """
        generator_classes = {'SoundStream': SoundStream}
        try:
            model_cls = generator_classes[config.generator.name]
        except KeyError:
            raise ValueError(f"Unsupported generator: {config.generator.name}")
        model = model_cls(**config.generator.config)
        parameter_dict = torch.load(self.ckpt_path, map_location='cpu')
        model.load_state_dict(parameter_dict['codec_model'])  # load model
        model = model.to(self.device)
        return model

    def encode(self, wav):
        """Encode a waveform into discrete codes.

        Args:
            wav: waveform tensor; a channel dim is inserted at dim 1,
                giving (1, 1, len) — assumes wav is (1, len); TODO confirm.

        Returns:
            Encoder output; per the original comment, [n_codebook, 1, n_frames].
        """
        wav = wav.unsqueeze(1)  # .to(self.device) # (1,1,len)
        compressed = self.soundstream.encoder(wav)  # [n_codebook, 1, n_frames]
        return compressed

    def decode(self, audio):
        """Decode discrete codes back to a waveform, warning on clipping."""
        out = self.soundstream.decoder(audio)
        check_clipping(out, rescale=False)
        return out
def check_clipping(wav, rescale):
    """Print a warning to stderr when the waveform would clip.

    The check is skipped entirely when ``rescale`` is set, since rescaled
    output cannot clip. Returns None in all cases.
    """
    if rescale:
        return
    limit = 0.99
    peak = wav.abs().max()
    if peak > limit:
        message = (
            f"Clipping!! max scale {peak}, limit is {limit}. "
            "To avoid clipping, use the `-r` option to rescale the output."
        )
        print(message, file=sys.stderr)
if __name__ == '__main__':
    # NOTE(review): the original called AudioTokenizer(device=...) without the
    # required positional `ckpt_path`, which always raises TypeError. Take the
    # checkpoint path from the command line instead. The trailing `.cuda()`
    # was dropped: the model is already placed on `device` in __init__, and
    # whether AudioTokenizer itself is an nn.Module is not visible here.
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} <ckpt_path>", file=sys.stderr)
        sys.exit(1)
    tokenizer = AudioTokenizer(sys.argv[1], device=torch.device('cuda:0'))
    wav = '/home/v-dongyang/data/FSD/mnt/fast/nobackup/scratch4weeks/xm00178/WavCaps/data/waveforms/FreeSound_flac/537271.flac'
    # `tokenize` is presumably provided by AbsTokenizer -- confirm against base class.
    codec = tokenizer.tokenize(wav)
    print(codec)
    # wav = tokenizer.detokenize(codec)
    # import torchaudio
    # torchaudio.save('desing.wav', wav, 16000, bits_per_sample=16, encoding='PCM_S')
    # print(wav.shape)