# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import gc
import os
import random
import shutil

import numpy as np
import torch
import tqdm

from examples.textless_nlp.gslm.speech2unit.pretrained.cpc_feature_reader import (
    CpcFeatureReader,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.hubert_feature_reader import (
    HubertFeatureReader,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.logmel_feature_reader import (
    LogMelFeatureReader,
)
from examples.textless_nlp.gslm.speech2unit.pretrained.w2v2_feature_reader import (
    Wav2VecFeatureReader,
)


def get_feature_reader(feature_type):
    """Return the feature reader class registered for ``feature_type``.

    Args:
        feature_type: One of ``"logmel"``, ``"hubert"``, ``"w2v2"`` or
            ``"cpc"``.

    Returns:
        The reader class itself (not an instance).

    Raises:
        NotImplementedError: If ``feature_type`` is not one of the names
            listed above.
    """
    if feature_type == "logmel":
        return LogMelFeatureReader
    elif feature_type == "hubert":
        return HubertFeatureReader
    elif feature_type == "w2v2":
        return Wav2VecFeatureReader
    elif feature_type == "cpc":
        return CpcFeatureReader
    else:
        raise NotImplementedError(f"{feature_type} is not supported.")


def get_feature_iterator(
    feature_type, checkpoint_path, layer, manifest_path, sample_pct
):
    """Build a lazy per-file feature generator from a manifest.

    The first line of the manifest is used as the root directory; on every
    following non-empty line the text before the first tab is joined onto
    that root to form a file path.

    Args:
        feature_type: Name accepted by :func:`get_feature_reader`.
        checkpoint_path: Forwarded to the reader constructor.
        layer: Forwarded to the reader constructor.
        manifest_path: Path of the manifest file to read.
        sample_pct: When below ``1.0``, a random subset of
            ``int(sample_pct * number_of_files)`` files is used; otherwise
            every listed file is used.

    Returns:
        A ``(iterate, num_files)`` tuple. ``iterate`` is a zero-argument
        generator function yielding one NumPy array per file, and
        ``num_files`` is the number of files it will visit.
    """
    feature_reader_cls = get_feature_reader(feature_type)
    with open(manifest_path, "r") as fp:
        lines = fp.read().split("\n")

    root = lines.pop(0).strip()
    file_path_list = [
        os.path.join(root, line.split("\t")[0]) for line in lines if line
    ]
    if sample_pct < 1.0:
        file_path_list = random.sample(
            file_path_list, int(sample_pct * len(file_path_list))
        )
    num_files = len(file_path_list)

    # The reader is built eagerly here, before any feature is requested.
    reader = feature_reader_cls(checkpoint_path=checkpoint_path, layer=layer)

    def iterate():
        for file_path in file_path_list:
            feats = reader.get_feats(file_path)
            # presumably a torch tensor; moved to host memory before yielding
            yield feats.cpu().numpy()

    return iterate, num_files


def get_features(
    feature_type, checkpoint_path, layer, manifest_path, sample_pct, flatten
):
    """Extract features for every file selected from the manifest.

    Args:
        feature_type: Name accepted by :func:`get_feature_reader`.
        checkpoint_path: Forwarded to the reader constructor.
        layer: Forwarded to the reader constructor.
        manifest_path: Path of the manifest file to read.
        sample_pct: Fraction of files to sample (see
            :func:`get_feature_iterator`).
        flatten: If true, the per-file arrays are concatenated along the
            first axis into a single array; otherwise they are returned as
            a list with one array per file.

    Returns:
        A single ``numpy.ndarray`` when ``flatten`` is true, else a list of
        arrays.

    Raises:
        ValueError: From ``np.concatenate`` when ``flatten`` is true and no
            file was selected.
    """
    generator, num_files = get_feature_iterator(
        feature_type=feature_type,
        checkpoint_path=checkpoint_path,
        layer=layer,
        manifest_path=manifest_path,
        sample_pct=sample_pct,
    )
    iterator = generator()

    features_list = []
    for features in tqdm.tqdm(iterator, total=num_files):
        features_list.append(features)

    # Explicit clean up: the generator closure is what keeps the reader
    # alive, so drop it before collecting and releasing cached GPU memory.
    del iterator
    del generator
    gc.collect()
    torch.cuda.empty_cache()

    if flatten:
        return np.concatenate(features_list)

    return features_list


def _to_saveable_array(features_batch):
    """Convert extracted features into an array ``np.save`` can write."""
    if isinstance(features_batch, np.ndarray):
        return features_batch

    # Equal shapes (or no items at all) convert to a regular array as-is.
    if len({np.shape(feats) for feats in features_batch}) <= 1:
        return np.asanyarray(features_batch)

    # Per-file arrays of differing shapes cannot form a regular array, and
    # recent NumPy releases raise instead of implicitly building an object
    # array, so build the 1-D object array explicitly.
    ragged = np.empty(len(features_batch), dtype=object)
    for index, feats in enumerate(features_batch):
        ragged[index] = feats
    return ragged


def get_and_dump_features(
    feature_type,
    checkpoint_path,
    layer,
    manifest_path,
    sample_pct,
    flatten,
    out_features_path,
):
    """Extract features, save them with ``np.save`` and return them.

    A copy of the manifest is placed next to the saved features, under the
    manifest's own base name.

    Args:
        feature_type: Name accepted by :func:`get_feature_reader`.
        checkpoint_path: Forwarded to the reader constructor.
        layer: Forwarded to the reader constructor.
        manifest_path: Path of the manifest file to read.
        sample_pct: Fraction of files to sample (see
            :func:`get_feature_iterator`).
        flatten: Whether to concatenate per-file features (see
            :func:`get_features`).
        out_features_path: Destination passed to ``np.save``. Its parent
            directory is created if needed; a bare file name writes into
            the current working directory.

    Returns:
        The features exactly as returned by :func:`get_features`.
    """
    # Feature extraction
    features_batch = get_features(
        feature_type=feature_type,
        checkpoint_path=checkpoint_path,
        layer=layer,
        manifest_path=manifest_path,
        sample_pct=sample_pct,
        flatten=flatten,
    )

    # Save features
    out_dir_path = os.path.dirname(out_features_path)
    if out_dir_path:
        # os.makedirs("") raises, so skip it when the output path has no
        # directory component.
        os.makedirs(out_dir_path, exist_ok=True)

    manifest_copy_path = os.path.join(
        out_dir_path, os.path.basename(manifest_path)
    )
    try:
        shutil.copyfile(manifest_path, manifest_copy_path)
    except shutil.SameFileError:
        # The manifest already lives in the output directory; nothing to copy.
        pass

    # NOTE: a ragged (non-flattened) batch is stored as an object array, which
    # np.save pickles; loading it back requires np.load(..., allow_pickle=True).
    np.save(out_features_path, _to_saveable_array(features_batch))

    return features_batch