# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torch
import numpy as np

from examples.textless_nlp.gslm.unit2speech.tacotron2.text import (
    EOS_TOK,
    SOS_TOK,
    code_to_sequence,
    text_to_sequence,
)
from examples.textless_nlp.gslm.unit2speech.tacotron2.utils import (
    load_code_dict,
)


class TacotronInputDataset:
    """Converts raw text or discrete-unit strings into Tacotron2 input tensors."""

    def __init__(self, hparams, append_str=""):
        self.is_text = getattr(hparams, "text_or_code", "text") == "text"
        if not self.is_text:
            # Code inputs need a dictionary mapping unit ids to vocabulary indices.
            self.code_dict = load_code_dict(hparams.code_dict)
            self.code_key = hparams.code_key
        self.add_sos = hparams.add_sos
        self.add_eos = hparams.add_eos
        self.collapse_code = hparams.collapse_code
        self.append_str = append_str

    def process_code(self, inp_str):
        # Tokenize a space-separated string of unit ids, optionally wrap it in
        # SOS/EOS markers, then map it to a sequence of vocabulary indices.
        inp_toks = inp_str.split()
        if self.add_sos:
            inp_toks = [SOS_TOK] + inp_toks
        if self.add_eos:
            inp_toks = inp_toks + [EOS_TOK]
        return code_to_sequence(inp_toks, self.code_dict, self.collapse_code)

    def process_text(self, inp_str):
        return text_to_sequence(inp_str, ["english_cleaners"])

    def get_tensor(self, inp_str):
        # uid, txt, inp_str = self._get_data(idx)
        inp_str = inp_str + self.append_str
        if self.is_text:
            inp_toks = self.process_text(inp_str)
        else:
            inp_toks = self.process_code(inp_str)
        return torch.from_numpy(np.array(inp_toks)).long()

    def __len__(self):
        # NOTE: self.data is never assigned in __init__, so calling len() on this
        # object raises AttributeError; callers use get_tensor() directly.
        return len(self.data)
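

# --------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, illustrative
# example of driving TacotronInputDataset. The Namespace fields mirror the
# hparams attributes read in __init__; the concrete values and the
# "/path/to/code_dict" placeholder are assumptions for demonstration only.
if __name__ == "__main__":
    from argparse import Namespace

    # Text input: requires no external files.
    text_hparams = Namespace(
        text_or_code="text", add_sos=False, add_eos=False, collapse_code=False
    )
    text_dataset = TacotronInputDataset(text_hparams)
    print(text_dataset.get_tensor("Hello world.").shape)

    # Discrete-unit input: requires a real code dictionary file on disk.
    code_hparams = Namespace(
        text_or_code="code",
        code_dict="/path/to/code_dict",  # placeholder path
        code_key="units",
        add_sos=False,
        add_eos=False,
        collapse_code=True,  # merge consecutive duplicate units
    )
    # code_dataset = TacotronInputDataset(code_hparams)
    # print(code_dataset.get_tensor("12 12 7 7 41 3"))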