Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from __future__ import absolute_import, division, print_function, unicode_literals | |
| import argparse | |
| import concurrent.futures | |
| import json | |
| import multiprocessing | |
| import os | |
| from collections import namedtuple | |
| from itertools import chain | |
| import sentencepiece as spm | |
| from fairseq.data import Dictionary | |
| MILLISECONDS_TO_SECONDS = 0.001 | |
| def process_sample(aud_path, lable, utt_id, sp, tgt_dict): | |
| import torchaudio | |
| input = {} | |
| output = {} | |
| si, ei = torchaudio.info(aud_path) | |
| input["length_ms"] = int( | |
| si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS | |
| ) | |
| input["path"] = aud_path | |
| token = " ".join(sp.EncodeAsPieces(lable)) | |
| ids = tgt_dict.encode_line(token, append_eos=False) | |
| output["text"] = lable | |
| output["token"] = token | |
| output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids])) | |
| return {utt_id: {"input": input, "output": output}} | |
| def main(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "--audio-dirs", | |
| nargs="+", | |
| default=["-"], | |
| required=True, | |
| help="input directories with audio files", | |
| ) | |
| parser.add_argument( | |
| "--labels", | |
| required=True, | |
| help="aggregated input labels with format <ID LABEL> per line", | |
| type=argparse.FileType("r", encoding="UTF-8"), | |
| ) | |
| parser.add_argument( | |
| "--spm-model", | |
| required=True, | |
| help="sentencepiece model to use for encoding", | |
| type=argparse.FileType("r", encoding="UTF-8"), | |
| ) | |
| parser.add_argument( | |
| "--dictionary", | |
| required=True, | |
| help="file to load fairseq dictionary from", | |
| type=argparse.FileType("r", encoding="UTF-8"), | |
| ) | |
| parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav") | |
| parser.add_argument( | |
| "--output", | |
| required=True, | |
| type=argparse.FileType("w"), | |
| help="path to save json output", | |
| ) | |
| args = parser.parse_args() | |
| sp = spm.SentencePieceProcessor() | |
| sp.Load(args.spm_model.name) | |
| tgt_dict = Dictionary.load(args.dictionary) | |
| labels = {} | |
| for line in args.labels: | |
| (utt_id, label) = line.split(" ", 1) | |
| labels[utt_id] = label | |
| if len(labels) == 0: | |
| raise Exception("No labels found in ", args.labels_path) | |
| Sample = namedtuple("Sample", "aud_path utt_id") | |
| samples = [] | |
| for path, _, files in chain.from_iterable( | |
| os.walk(path) for path in args.audio_dirs | |
| ): | |
| for f in files: | |
| if f.endswith(args.audio_format): | |
| if len(os.path.splitext(f)) != 2: | |
| raise Exception("Expect <utt_id.extension> file name. Got: ", f) | |
| utt_id = os.path.splitext(f)[0] | |
| if utt_id not in labels: | |
| continue | |
| samples.append(Sample(os.path.join(path, f), utt_id)) | |
| utts = {} | |
| num_cpu = multiprocessing.cpu_count() | |
| with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor: | |
| future_to_sample = { | |
| executor.submit( | |
| process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict | |
| ): s | |
| for s in samples | |
| } | |
| for future in concurrent.futures.as_completed(future_to_sample): | |
| try: | |
| data = future.result() | |
| except Exception as exc: | |
| print("generated an exception: ", exc) | |
| else: | |
| utts.update(data) | |
| json.dump({"utts": utts}, args.output, indent=4) | |
| if __name__ == "__main__": | |
| main() | |