File size: 3,962 Bytes
4cda277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
# 
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq)

import logging
import os
import sys

from feature_utils import get_path_iterator, dump_feature

# Root logger setup: timestamped records go to stdout. Verbosity is taken from
# the LOGLEVEL environment variable (upper-cased), falling back to INFO.
logging.basicConfig(
    stream=sys.stdout,
    level=os.getenv("LOGLEVEL", "INFO").upper(),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
# Logger used by this script for its own messages.
logger = logging.getLogger("dump_feature")


def main(
    model_type: str,
    tsv_path: str,
    ckpt_path: str,
    whisper_root: str,
    whisper_name: str,
    layer: int,
    nshard: int,
    rank: int,
    feat_dir: str,
    max_chunk: int,
    use_cpu: bool = False,
):
    """Build the feature reader for ``model_type`` and dump features for one shard.

    The reader class is imported lazily so that only the dependencies of the
    selected model need to be importable. The reader, together with the path
    iterator returned by ``get_path_iterator(tsv_path, nshard, rank)``, is then
    handed to ``dump_feature``.

    Args:
        model_type: One of ``"hubert"``, ``"data2vec"`` or ``"whisper"``.
        tsv_path: Path to the tsv file, forwarded to ``get_path_iterator``.
        ckpt_path: Path to the speech model checkpoint. Required (and must
            exist on disk) for ``"hubert"`` and ``"data2vec"``; not used for
            ``"whisper"``.
        whisper_root: Root directory passed to ``WhisperFeatureReader``.
            Required for ``"whisper"``; not used otherwise.
        whisper_name: Whisper model name passed to ``WhisperFeatureReader``.
            Required for ``"whisper"``; not used otherwise.
        layer: Layer index forwarded to the reader.
        nshard: Total number of shards.
        rank: Shard id handled by this call.
        feat_dir: Output directory forwarded to ``dump_feature``.
        max_chunk: Forwarded to the HuBERT / data2vec readers only; the
            whisper reader does not receive it.
        use_cpu: Run on ``"cpu"`` when true, otherwise on ``"cuda"``.

    Raises:
        ValueError: If ``model_type`` is unsupported, or an argument required
            by the selected model type is missing/empty.
        FileNotFoundError: If ``ckpt_path`` is required but does not exist.
    """
    # Validate with real exceptions rather than `assert`: assertions are
    # removed when Python runs with -O, which would let bad arguments through
    # to the (much less helpful) model-loading errors.
    if model_type in ("hubert", "data2vec"):
        if not ckpt_path:
            raise ValueError(
                f"ckpt_path must be provided for model type {model_type}"
            )
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(
                f"Checkpoint path does not exist: {ckpt_path}"
            )
    elif model_type == "whisper":
        if not (whisper_name and whisper_root):
            raise ValueError(
                "whisper_name and whisper_root must be provided for model type whisper"
            )
    else:
        raise ValueError(f"Unsupported model type {model_type}")

    device = "cpu" if use_cpu else "cuda"

    # Lazy, per-model imports: each reader module is only imported when that
    # model type is actually requested.
    if model_type == "hubert":
        from hubert_feature_reader import HubertFeatureReader
        reader = HubertFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
    elif model_type == "data2vec":
        from data2vec_feature_reader import Data2vecFeatureReader
        reader = Data2vecFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
    else:
        # Only "whisper" can reach this branch; anything else was rejected above.
        from whisper_feature_reader import WhisperFeatureReader
        reader = WhisperFeatureReader(whisper_root, whisper_name, layer, device=device)

    generator, num = get_path_iterator(tsv_path, nshard, rank)
    dump_feature(reader, generator, num, nshard, rank, feat_dir)


if __name__ == "__main__":
    import argparse

    # Command-line interface, declared as (flag, add_argument options) pairs.
    # The list order is the order in which the options are registered.
    cli_options = [
        (
            "--model_type",
            dict(
                required=True,
                type=str,
                choices=["data2vec", "hubert", "whisper"],
                help="the type of the speech encoder.",
            ),
        ),
        (
            "--tsv_path",
            dict(required=True, type=str, help="the path to the tsv file."),
        ),
        (
            "--ckpt_path",
            dict(
                required=False,
                type=str,
                default=None,
                help="path to the speech model. must provide for HuBERT and data2vec",
            ),
        ),
        (
            "--whisper_root",
            dict(
                required=False,
                type=str,
                default=None,
                help="root dir to download/store whisper model. must provide for whisper model.",
            ),
        ),
        (
            "--whisper_name",
            dict(
                required=False,
                type=str,
                default=None,
                help="name of whisper model. e.g., large-v2. must provide for whisper model.",
            ),
        ),
        (
            "--layer",
            dict(
                required=True,
                type=int,
                help="which layer of the model. this is 1-based.",
            ),
        ),
        (
            "--feat_dir",
            dict(
                required=True,
                type=str,
                help="the output dir to save the representations.",
            ),
        ),
        (
            "--nshard",
            dict(required=False, type=int, default=1, help="total number of shards."),
        ),
        (
            "--rank",
            dict(required=False, type=int, default=0, help="shard id of this process."),
        ),
        (
            "--max_chunk",
            dict(type=int, default=1600000, help="max number of frames of each batch."),
        ),
        (
            "--use_cpu",
            dict(
                default=False,
                action="store_true",
                help="whether use cpu instead of gpu.",
            ),
        ),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in cli_options:
        cli.add_argument(flag, **options)

    args = cli.parse_args()
    logger.info(args)

    # Every parsed option maps one-to-one onto a keyword parameter of main().
    main(**vars(args))