#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Run inference for pre-processed data with a trained model.
"""

import logging
import math
import os

import sentencepiece as spm
import torch

from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.utils import import_user_module

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def add_asr_eval_argument(parser):
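    """Add ASR-specific evaluation arguments to a fairseq generation parser."""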
parser.add_argument("--kspmodel", default=None, help="sentence piece model") | |
parser.add_argument( | |
"--wfstlm", default=None, help="wfstlm on dictonary output units" | |
) | |
parser.add_argument( | |
"--rnnt_decoding_type", | |
default="greedy", | |
help="wfstlm on dictonary\ | |
output units", | |
) | |
parser.add_argument( | |
"--lm-weight", | |
"--lm_weight", | |
type=float, | |
default=0.2, | |
help="weight for lm while interpolating with neural score", | |
) | |
parser.add_argument( | |
"--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level" | |
) | |
parser.add_argument( | |
"--w2l-decoder", choices=["viterbi", "kenlm"], help="use a w2l decoder" | |
) | |
parser.add_argument("--lexicon", help="lexicon for w2l decoder") | |
parser.add_argument("--kenlm-model", help="kenlm model for w2l decoder") | |
parser.add_argument("--beam-threshold", type=float, default=25.0) | |
parser.add_argument("--word-score", type=float, default=1.0) | |
parser.add_argument("--unk-weight", type=float, default=-math.inf) | |
parser.add_argument("--sil-weight", type=float, default=0.0) | |
return parser | |
def check_args(args): | |
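    """Validate the argument combinations required for ASR generation."""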
    assert args.path is not None, "--path required for generation!"
    assert args.results_path is not None, "--results_path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"


def get_dataset_itr(args, task):
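    """Build a (possibly sharded) batch iterator over the evaluation subset."""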
    return task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=(1000000.0, 1000000.0),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)


def process_predictions(
    args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
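    """Decode the n-best hypotheses and write them, with the reference, to the result files."""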
    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
        hyp_words = sp.DecodePieces(hyp_pieces.split())
        print(
            "{} ({}-{})".format(hyp_pieces, speaker, id), file=res_files["hypo.units"]
        )
        print("{} ({}-{})".format(hyp_words, speaker, id), file=res_files["hypo.words"])
        tgt_pieces = tgt_dict.string(target_tokens)
        tgt_words = sp.DecodePieces(tgt_pieces.split())
        print("{} ({}-{})".format(tgt_pieces, speaker, id), file=res_files["ref.units"])
        print("{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"])
        # only score top hypothesis
        if not args.quiet:
            logger.debug("HYPO:" + hyp_words)
            logger.debug("TARGET:" + tgt_words)
            logger.debug("___________________")


def prepare_result_files(args):
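    """Open one line-buffered output file per hypothesis/reference variant."""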
    def get_res_file(file_prefix):
        path = os.path.join(
            args.results_path,
            "{}-{}-{}.txt".format(
                file_prefix, os.path.basename(args.path), args.gen_subset
            ),
        )
        return open(path, "w", buffering=1)

    return {
        "hypo.words": get_res_file("hypo.word"),
        "hypo.units": get_res_file("hypo.units"),
        "ref.words": get_res_file("ref.word"),
        "ref.units": get_res_file("ref.units"),
    }


def load_models_and_criterions(filenames, arg_overrides=None, task=None):
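    """Load an ensemble of models (and their criterions) from checkpoint files."""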
    models = []
    criterions = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)

        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=True)
        models.append(model)

        criterion = task.build_criterion(args)
        if "criterion" in state:
            criterion.load_state_dict(state["criterion"], strict=True)
        criterions.append(criterion)
    return models, criterions, args


def optimize_models(args, use_cuda, models):
    """Optimize ensemble for generation."""
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()


def main(args):
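    """Run inference over the evaluation subset and report throughput."""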
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    logger.info(
        "| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))
        )
    )

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    models, criterions, _model_args = load_models_and_criterions(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),  # noqa
        task=task,
    )
    optimize_models(args, use_cuda, models)

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        trans = criterions[0].asg.trans.data
        args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0

    if not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, "spm.model"))

    res_files = prepare_result_files(args)

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = task.dataset(args.gen_subset).speakers[int(sample_id)]
                id = task.dataset(args.gen_subset).ids[int(sample_id)]
                target_tokens = (
                    utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
                )
                # Process top predictions
                process_predictions(
                    args, hypos[i], sp, tgt_dict, target_tokens, res_files, speaker, id
                )

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += sample["nsentences"]

    logger.info(
        "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
        "sentences/s, {:.2f} tokens/s)".format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        )
    )
    logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))


def cli_main():
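    """Parse arguments with fairseq's generation parser and run inference."""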
    parser = options.get_generation_parser()
    parser = add_asr_eval_argument(parser)
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
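
# Example invocation (a sketch: the paths, the script name, and the task name
# are placeholders/assumptions; the flags shown are either defined above or come
# from fairseq's standard generation parser):
#
#   python infer.py /path/to/data \
#       --task speech_recognition --gen-subset test \
#       --path /path/to/checkpoint.pt --results-path /path/to/results \
#       --max-tokens 25000 --beam 20 --nbest 1
#
# The data directory is expected to contain the pre-processed evaluation split
# plus an "spm.model" SentencePiece model, as loaded in main() above.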