KingNish committed (verified)
Commit 0fa0976 · 1 Parent(s): c5a0b98

Upload ./RepCodec/repcodec/tokenize.py with huggingface_hub

Files changed (1):
  1. RepCodec/repcodec/tokenize.py +212 -0
RepCodec/repcodec/tokenize.py ADDED
@@ -0,0 +1,212 @@
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
from pathlib import Path
from typing import Tuple, List, Optional

import numpy as np
import torch
import yaml
from tqdm import tqdm

from repcodec.RepCodec import RepCodec

# known checkpoint names and their representation dimensions
ALL_MODELS = {
    "data2vec_base_l6": 768,
    "data2vec_large_l18": 1024,
    "hubert_base_l9": 768,
    "hubert_large_l18": 1024,
    "whisper_medium_l24": 1024,
    "whisper_large_l32": 1280
}


def parse_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "in_dir",
        type=str,
        help="directory of representations to be tokenized."
    )
    parser.add_argument(
        "--model",
        required=True,
        type=str,
        help="path of the RepCodec model."
    )
    parser.add_argument(
        "--tsv_path",
        required=True,
        type=str,
        help="path of the tsv file."
    )
    parser.add_argument(
        "--model_config_path",
        default=None,
        type=str,
        help="please provide this training config if you are using a model you trained yourself."
    )
    parser.add_argument(
        "--n_shard",
        required=False,
        type=int,
        default=1,
        help="number of shards of representations."
    )
    parser.add_argument(
        "--use_gpu",
        default=False,
        action="store_true",
        help="whether to use gpu for inference."
    )
    parser.add_argument(
        "--batch_size",
        default=1,
        type=int,
        help="number of utterances for each mini batch."
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default=".",
        help="the directory to save the output."
    )
    return parser.parse_args()


def load_model(model_path: str, config_path: Optional[str] = None):
    if config_path is None:
        # infer the config from the checkpoint file name (must be one of ALL_MODELS)
        name = os.path.basename(model_path).strip(".pkl")
        assert name in ALL_MODELS.keys(), f"Cannot find configs for {model_path}. " \
                                          f"Please provide the config file you used for training."
        config = os.path.join(os.path.dirname(__file__), "configs", f"repcodec_dim{ALL_MODELS[name]}.yaml")
        with open(config) as fp:
            conf = yaml.load(fp, Loader=yaml.FullLoader)
    else:
        with open(config_path) as fp:
            conf = yaml.load(fp, Loader=yaml.FullLoader)["model_params"]

    model = RepCodec(**conf)
    model.load_state_dict(torch.load(model_path, map_location="cpu")["model"]["repcodec"])
    model.quantizer.initial()
    model.eval()
    return model


def load_shard(in_dir: Path, rank: int, n_shard: int) -> Tuple[np.ndarray, List[int]]:
    # shard files are expected to be named {rank}_{n_shard}.npy / {rank}_{n_shard}.len
    feat_path = in_dir / f"{rank}_{n_shard}.npy"
    len_path = in_dir / f"{rank}_{n_shard}.len"

    with open(len_path) as fp:
        lengths = [int(line.strip()) for line in fp]

    return np.load(feat_path.as_posix(), mmap_mode="r"), lengths


def pad_data(data: List[np.ndarray]) -> List[np.ndarray]:
    max_len = max([d.shape[0] for d in data])
    data = [
        np.pad(d, [(0, max_len - d.shape[0]), (0, 0)], "constant", constant_values=0.0)
        for d in data
    ]
    return data


def make_batch_data(data: np.ndarray, shard_lengths: List[int], batch_size: int):
    batch_data = []
    batch_lens = []
    offsets = np.cumsum([0] + shard_lengths)
    assert len(data) == offsets[-1], f"{len(data)} {offsets[-1]}"

    # from longest to shortest
    for i in range(len(shard_lengths)):
        if batch_size > len(batch_data):
            batch_data.append(data[offsets[i]: offsets[i + 1]])
            batch_lens.append(shard_lengths[i])
        else:
            yield {
                "data": torch.tensor(np.stack(pad_data(batch_data)), dtype=torch.float),  # (bsz, seq len, hidden dim)
                "lengths": batch_lens
            }
            batch_data = [data[offsets[i]: offsets[i + 1]]]
            batch_lens = [shard_lengths[i]]
    if len(batch_data) > 0:
        yield {
            "data": torch.tensor(np.stack(pad_data(batch_data)), dtype=torch.float),
            "lengths": batch_lens
        }


def tokenize_batch(model: RepCodec, batch: dict, device: str) -> List[List[int]]:
    with torch.no_grad():
        data = batch["data"].transpose(1, 2).to(device)  # (bsz, hidden dim, seq len)
        x = model.encoder(data)
        z = model.projector(x)
        _, idx = model.quantizer.codebook.forward_index(z.transpose(2, 1))

    # when bsz=1: (1, seq len)
    if idx.dim() == 2:
        return idx.cpu().data.numpy().tolist()
    # when bsz>1: (1, bsz, seq len)
    tokens = idx.cpu().data.numpy().tolist()[0]
    res = []
    batch_lens = batch["lengths"]
    for i in range(len(tokens)):
        n_tokens = batch_lens[i]
        res.append(tokens[i][:n_tokens])  # drop padded positions
    return res


def load_tsv(path: str):
    with open(path) as fp:
        root = fp.readline().strip()
        names = []
        for line in fp:
            names.append(line.strip().split("\t")[0])
    return root, names


def cli():
    args = parse_args()
    device = "cuda" if args.use_gpu else "cpu"

    model = load_model(model_path=args.model, config_path=args.model_config_path)
    model.to(device)

    in_dir = Path(args.in_dir)
    n_shard = args.n_shard
    batch_size = args.batch_size

    root_dir, file_names = load_tsv(args.tsv_path)

    output_dir = args.out_dir
    os.makedirs(output_dir, exist_ok=True)

    processed_cnt = 0
    pbar = tqdm(total=len(file_names))
    with open(os.path.join(output_dir, "tokens"), mode="w+") as fp:
        fp.write(f"{root_dir}\n")

        for rank in range(n_shard):
            shard_data, shard_lengths = load_shard(in_dir, rank, n_shard)
            for batch in make_batch_data(shard_data, shard_lengths, batch_size=batch_size):
                batch_tokens = tokenize_batch(model, batch, device)

                for tokens in batch_tokens:
                    fp.write(f"{file_names[processed_cnt]}\t{' '.join(map(str, tokens))}\n")
                    processed_cnt += 1

                pbar.update(len(batch_tokens))
    assert processed_cnt == len(file_names), f"# lines of tsv do not match # of representations!"

    pbar.close()
    print("Tokenize successfully!")


if __name__ == '__main__':
    cli()
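
For reference, a minimal sketch of driving these functions from Python instead of the CLI, mirroring what cli() does for a single shard. The checkpoint name and feature directory are placeholders, not part of the uploaded file: it assumes a hubert_base_l9.pkl checkpoint (so the config is inferred from ALL_MODELS) and a directory holding one shard dumped as 0_1.npy / 0_1.len, which is the naming load_shard expects.

# Hypothetical usage sketch; paths below are placeholder assumptions.
from pathlib import Path

from repcodec.tokenize import load_model, load_shard, make_batch_data, tokenize_batch

device = "cpu"
model = load_model(model_path="hubert_base_l9.pkl")  # config inferred from the checkpoint name
model.to(device)

in_dir = Path("features")  # assumed to contain 0_1.npy and 0_1.len (rank 0 of 1 shard)
shard_data, shard_lengths = load_shard(in_dir, rank=0, n_shard=1)

for batch in make_batch_data(shard_data, shard_lengths, batch_size=4):
    for tokens in tokenize_batch(model, batch, device=device):
        print(" ".join(map(str, tokens)))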