KingNish's picture
Upload ./RepCodec/examples/dump_feature.py with huggingface_hub
4cda277 verified
raw
history blame
3.96 kB
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Based on fairseq (https://github.com/facebookresearch/fairseq)
import logging
import os
import sys
from feature_utils import get_path_iterator, dump_feature
# Route all log records to stdout; verbosity is taken from the LOGLEVEL
# environment variable (case-insensitive) and falls back to INFO.
logging.basicConfig(
    stream=sys.stdout,
    level=os.getenv("LOGLEVEL", "INFO").upper(),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
# Script-wide logger used by the entry point below.
logger = logging.getLogger("dump_feature")
def main(
        model_type: str,
        tsv_path: str,
        ckpt_path: str,
        whisper_root: str,
        whisper_name: str,
        layer: int,
        nshard: int,
        rank: int,
        feat_dir: str,
        max_chunk: int,
        use_cpu: bool = False
):
    """Build the feature reader for ``model_type`` and dump one shard of features.

    Args:
        model_type: One of ``"hubert"``, ``"data2vec"`` or ``"whisper"``.
        tsv_path: Path to the tsv file handed to ``get_path_iterator``.
        ckpt_path: Path to the speech model checkpoint. Must be given and
            exist on disk for ``"hubert"`` and ``"data2vec"``.
        whisper_root: Root dir for the whisper model. Required for ``"whisper"``.
        whisper_name: Name of the whisper model. Required for ``"whisper"``.
        layer: Layer index forwarded to the reader.
        nshard: Total number of shards.
        rank: Shard id of this process.
        feat_dir: Output directory forwarded to ``dump_feature``.
        max_chunk: Chunk size forwarded to the HuBERT / data2vec readers
            (the whisper reader does not receive it).
        use_cpu: Run on ``"cpu"`` instead of ``"cuda"``.

    Raises:
        ValueError: If ``model_type`` is unsupported, or an argument that is
            mandatory for the chosen model type is missing.
        FileNotFoundError: If ``ckpt_path`` does not exist for a model type
            that needs a checkpoint.
    """
    device = "cpu" if use_cpu else "cuda"

    # Validate with real exceptions rather than ``assert`` so the checks
    # survive ``python -O``. Reader modules are imported inside each branch,
    # so only the selected backend's module gets loaded.
    if model_type in ("hubert", "data2vec"):
        if not ckpt_path:
            raise ValueError(
                f"ckpt_path must be provided for model type {model_type}"
            )
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"ckpt_path does not exist: {ckpt_path}")

        if model_type == "hubert":
            from hubert_feature_reader import HubertFeatureReader
            reader = HubertFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
        else:
            from data2vec_feature_reader import Data2vecFeatureReader
            reader = Data2vecFeatureReader(ckpt_path, layer, device=device, max_chunk=max_chunk)
    elif model_type == "whisper":
        if not (whisper_name and whisper_root):
            raise ValueError(
                "whisper_root and whisper_name must be provided for model type whisper"
            )
        from whisper_feature_reader import WhisperFeatureReader
        reader = WhisperFeatureReader(whisper_root, whisper_name, layer, device=device)
    else:
        raise ValueError(f"Unsupported model type {model_type}")

    generator, num = get_path_iterator(tsv_path, nshard, rank)
    dump_feature(reader, generator, num, nshard, rank, feat_dir)
if __name__ == "__main__":
    import argparse

    # Command-line options as (flag, add_argument keywords), registered in
    # this order so the generated --help output keeps the same layout.
    cli_options = [
        ("--model_type", dict(
            required=True,
            type=str,
            choices=["data2vec", "hubert", "whisper"],
            help="the type of the speech encoder.",
        )),
        ("--tsv_path", dict(
            required=True,
            type=str,
            help="the path to the tsv file.",
        )),
        ("--ckpt_path", dict(
            required=False,
            type=str,
            default=None,
            help="path to the speech model. must provide for HuBERT and data2vec",
        )),
        ("--whisper_root", dict(
            required=False,
            type=str,
            default=None,
            help="root dir to download/store whisper model. must provide for whisper model.",
        )),
        ("--whisper_name", dict(
            required=False,
            type=str,
            default=None,
            help="name of whisper model. e.g., large-v2. must provide for whisper model.",
        )),
        ("--layer", dict(
            required=True,
            type=int,
            help="which layer of the model. this is 1-based.",
        )),
        ("--feat_dir", dict(
            required=True,
            type=str,
            help="the output dir to save the representations.",
        )),
        ("--nshard", dict(
            required=False,
            type=int,
            default=1,
            help="total number of shards.",
        )),
        ("--rank", dict(
            required=False,
            type=int,
            default=0,
            help="shard id of this process.",
        )),
        ("--max_chunk", dict(
            type=int,
            default=1600000,
            help="max number of frames of each batch.",
        )),
        ("--use_cpu", dict(
            default=False,
            action="store_true",
            help="whether use cpu instead of gpu.",
        )),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    # Parsed attribute names match main()'s parameter names one-to-one,
    # so the namespace can be splatted straight into the call.
    cli_args = parser.parse_args()
    logger.info(cli_args)
    main(**vars(cli_args))