"""SoundStream-based audio tokenizer.

Wraps a SoundStream codec checkpoint so audio can be turned into discrete
codebook indices (``AudioTokenizer.encode``) and back into a waveform
(``AudioTokenizer.decode``). Running this module as a script tokenizes a
single audio file and prints the result.
"""
import argparse
import logging
import os
import sys

import torch
from omegaconf import OmegaConf

from ..abs_tokenizer import AbsTokenizer
from .models.soundstream import SoundStream


class AudioTokenizer(AbsTokenizer):
    """SoundStream tokenizer configured for a bandwidth of 1.5 kbps.

    With ``bw=1.5`` the codec uses 3 codebooks. Audio is encoded at 50 frames
    per second, so ``self.freq`` (= ``n_codebook * 50``) tokens are produced
    per second of audio. Each code index lies in ``[0, dim_codebook - 1]``,
    i.e. ``[0, 1023]``.

    NOTE(review): ``self.bw`` is not passed to the encoder in this class; the
    number of codebooks actually emitted is presumably fixed by the checkpoint
    and its config -- confirm.
    """

    def __init__(self, ckpt_path, device=torch.device('cuda')):
        """Build the codec from ``ckpt_path`` and the ``config.yaml`` beside it.

        Args:
            ckpt_path: Path to the codec checkpoint. A ``config.yaml`` file must
                exist in the same directory.
            device: Device the codec model is moved to after loading.

        Raises:
            ValueError: If ``config.yaml`` is not found next to ``ckpt_path``.
        """
        super().__init__()

        # GPU is only for offline tokenization.
        # So, when distributed training is launched, this should still be on CPU.
        self.device = device

        config_path = os.path.join(os.path.dirname(ckpt_path), 'config.yaml')
        if not os.path.isfile(config_path):
            raise ValueError(f"{config_path} file does not exist.")
        config = OmegaConf.load(config_path)

        self.ckpt_path = ckpt_path
        logging.info("using config %s and model %s", config_path, self.ckpt_path)
        self.soundstream = self.build_codec_model(config)

        # Tokenizer properties.
        self.sr = 16000  # sample rate in Hz (presumably of the codec input)
        self.dim_codebook = 1024  # entries per codebook -> indices in [0, 1023]
        self.n_codebook = 3
        self.bw = 1.5  # bw=1.5 ---> 3 codebooks
        self.freq = self.n_codebook * 50  # tokens per second at 50 frames/s
        # Equals the total number of ids across all codebooks, i.e. one past the
        # largest id if codebooks are laid out consecutively; presumably used
        # as a dedicated mask token -- confirm with callers.
        self.mask_id = self.dim_codebook * self.n_codebook

    def build_codec_model(self, config):
        """Instantiate the generator named in ``config`` and load its weights.

        ``config.generator.name`` is evaluated to find the model class (e.g.
        ``SoundStream``), which is constructed with ``config.generator.config``
        and loaded from the ``'codec_model'`` entry of the checkpoint.

        NOTE(security): ``eval`` on a config value and the pickle-based
        ``torch.load`` can both execute arbitrary code. Only use configs and
        checkpoints from trusted sources.

        Returns:
            The loaded model, moved to ``self.device``.
        """
        model = eval(config.generator.name)(**config.generator.config)
        # Load on CPU first so loading works regardless of the saving device.
        parameter_dict = torch.load(self.ckpt_path, map_location='cpu')
        model.load_state_dict(parameter_dict['codec_model'])
        return model.to(self.device)

    @torch.no_grad()
    def encode(self, wav):
        """Encode a waveform into codebook indices.

        Args:
            wav: Waveform tensor, presumably of shape ``(1, len)``; a channel
                axis is inserted at dim 1 to give ``(1, 1, len)`` -- confirm.
                ``wav`` is NOT moved to ``self.device`` here, so the caller
                must place it on the same device as the codec.

        Returns:
            The output of ``self.soundstream.encoder``, presumably shaped
            ``[n_codebook, 1, n_frames]`` -- confirm against SoundStream.
        """
        wav = wav.unsqueeze(1)  # (1, len) -> (1, 1, len)
        return self.soundstream.encoder(wav)

    @torch.no_grad()
    def decode(self, audio):
        """Decode codebook indices back into a waveform.

        A warning is printed to stderr if the decoded signal exceeds 0.99 in
        absolute value. The output is returned unchanged, not rescaled.
        """
        out = self.soundstream.decoder(audio)
        check_clipping(out, rescale=False)
        return out


def check_clipping(wav, rescale):
    """Print a warning to stderr if ``wav`` peaks above 0.99 in magnitude.

    Args:
        wav: Waveform tensor to inspect. Empty tensors are ignored.
        rescale: If true, the check is skipped entirely. Nothing is rescaled
            by this function.
    """
    # ``max()`` raises on an empty tensor, and there is nothing to clip anyway.
    if rescale or wav.numel() == 0:
        return
    mx = wav.abs().max()
    limit = 0.99
    if mx > limit:
        print(
            f"Clipping!! max scale {mx}, limit is {limit}. "
            "To avoid clipping, use the `-r` option to rescale the output.",
            file=sys.stderr)


def _main():
    """Tokenize one audio file given on the command line and print the codes."""
    parser = argparse.ArgumentParser(
        description="Tokenize an audio file with a SoundStream checkpoint.")
    parser.add_argument(
        'ckpt_path',
        help="codec checkpoint; a config.yaml must sit in the same directory")
    parser.add_argument('wav', help="path of the audio file to tokenize")
    parser.add_argument(
        '--device', default='cuda:0',
        help="device for the codec model (default: cuda:0)")
    args = parser.parse_args()

    tokenizer = AudioTokenizer(args.ckpt_path, device=torch.device(args.device))
    # ``tokenize`` is not defined in this file; presumably provided by
    # AbsTokenizer -- confirm.
    codec = tokenizer.tokenize(args.wav)
    print(codec)


if __name__ == '__main__':
    _main()