#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Translate pre-processed data with a trained model.
"""
import ast
import logging
import math
import os
import sys
from argparse import Namespace
from itertools import chain

import numpy as np
import torch
from omegaconf import DictConfig

from fairseq import checkpoint_utils, options, scoring, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter


def main(cfg: DictConfig):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    assert cfg.common_eval.path is not None, "--path required for generation!"
    assert (
        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
    ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"

    if cfg.common_eval.results_path is not None:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(
            cfg.common_eval.results_path,
            "generate-{}.txt".format(cfg.dataset.gen_subset),
        )
        with open(output_path, "w", buffering=1, encoding="utf-8") as h:
            return _main(cfg, h)
    else:
        return _main(cfg, sys.stdout)


def get_symbols_to_strip_from_output(generator):
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    else:
        return {generator.eos}


def _main(cfg: DictConfig, output_file):
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Load dataset splits
    task = tasks.setup_task(cfg.task)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # loading the dataset should happen after the checkpoint has been loaded so we
    # can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
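
    # If --lm-path is given, load an external language model; its scores are
    # combined with the translation model's during decoding via the generator
    # (passed through extra_gen_cls_kwargs below, weighted by --lm-weight).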
    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [cfg.generation.lm_path], arg_overrides=overrides, task=None
            )
        except:
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) == 1
    else:
        lms = [None]

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
    generator = task.build_generator(
        models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
    )

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)
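
    # decode_fn reverses BPE/subword segmentation first, then runs the detokenizer;
    # it is applied below to sources, references, and hypotheses before printing.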
    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(cfg.scoring, tgt_dict)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
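
    # Output line prefixes written below: S- source, T- target reference,
    # H- hypothesis with its score, D- detokenized hypothesis, P- per-token
    # positional scores, A- alignments, I- decoding steps, E- iterative-refinement
    # history. Hypothesis and token scores are converted to base 2 before printing.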
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample["id"].tolist()):
            has_target = sample["target"] is not None

            # Remove padding
            if "src_tokens" in sample["net_input"]:
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                )
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = (
                    utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
                )

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
                    sample_id
                )
                target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
                    sample_id
                )
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens,
                        cfg.common_eval.post_process,
                        escape_unk=True,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                            generator
                        ),
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not cfg.common_eval.quiet:
                if src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str), file=output_file)
                if has_target:
                    print("T-{}\t{}".format(sample_id, target_str), file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                if not cfg.common_eval.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print(
                        "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
                        file=output_file,
                    )
                    # detokenized hypothesis
                    print(
                        "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
                        file=output_file,
                    )
                    print(
                        "P-{}\t{}".format(
                            sample_id,
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    # convert from base e to base 2
                                    hypo["positional_scores"]
                                    .div_(math.log(2))
                                    .tolist(),
                                )
                            ),
                        ),
                        file=output_file,
                    )

                    if cfg.generation.print_alignment == "hard":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [
                                        "{}-{}".format(src_idx, tgt_idx)
                                        for src_idx, tgt_idx in alignment
                                    ]
                                ),
                            ),
                            file=output_file,
                        )
                    if cfg.generation.print_alignment == "soft":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join(
                                    [",".join(src_probs) for src_probs in alignment]
                                ),
                            ),
                            file=output_file,
                        )

                    if cfg.generation.print_step:
                        print(
                            "I-{}\t{}".format(sample_id, hypo["steps"]),
                            file=output_file,
                        )

                    if cfg.generation.retain_iter_history:
                        for step, h in enumerate(hypo["history"]):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h["tokens"].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print(
                                "E-{}_{}\t{}".format(sample_id, step, h_str),
                                file=output_file,
                            )

                # Score only the top hypothesis
                if has_target and j == 0:
                    if (
                        align_dict is not None
                        or cfg.common_eval.post_process is not None
                    ):
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True
                        )
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True
                        )
                    if hasattr(scorer, "add_string"):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (
            sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
        )
| logger.info("NOTE: hypothesis and token scores are output in base 2") | |
| logger.info( | |
| "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( | |
| num_sentences, | |
| gen_timer.n, | |
| gen_timer.sum, | |
| num_sentences / gen_timer.sum, | |
| 1.0 / gen_timer.avg, | |
| ) | |
| ) | |
| if has_target: | |
| if cfg.bpe and not cfg.generation.sacrebleu: | |
| if cfg.common_eval.post_process: | |
| logger.warning( | |
| "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization" | |
| ) | |
| else: | |
| logger.warning( | |
| "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization" | |
| ) | |
| # use print to be consistent with other main outputs: S-, H-, T-, D- and so on | |
| print( | |
| "Generate {} with beam={}: {}".format( | |
| cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string() | |
| ), | |
| file=output_file, | |
| ) | |
| return scorer | |


def cli_main():
    parser = options.get_generation_parser()
    # TODO: replace this workaround with refactoring of `AudioPretraining`
    parser.add_argument(
        "--arch",
        "-a",
        metavar="ARCH",
        default="wav2vec2",
        help="Model architecture. For constructing tasks that rely on "
        "model args (e.g. `AudioPretraining`)",
    )
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()