#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Evaluate the perplexity of a trained language model.
"""

import logging
import math
import os
import sys
from argparse import Namespace
from typing import Iterable, List, Optional

import torch

import fairseq
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter
from fairseq.sequence_scorer import SequenceScorer
from omegaconf import DictConfig


logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.eval_lm")


def eval_lm(
    models: List[fairseq.models.FairseqModel],
    source_dictionary: fairseq.data.Dictionary,
    batch_iterator: Iterable,
    post_process: Optional[str] = None,
    output_word_probs: bool = False,
    output_word_stats: bool = False,
    target_dictionary: Optional[fairseq.data.Dictionary] = None,
    softmax_batch: int = 0,
    remove_bos_token: bool = False,
    device: Optional[torch.device] = None,
):
| """ | |
| Args: | |
| models (List[~fairseq.models.FairseqModel]): list of models to | |
| evaluate. Models are essentially `nn.Module` instances, but | |
| must be compatible with fairseq's `SequenceScorer`. | |
| source_dictionary (~fairseq.data.Dictionary): dictionary for | |
| applying any relevant post processing or outputing word | |
| probs/stats. | |
| batch_iterator (Iterable): yield batches of data | |
| post_process (Optional[str]): post-process text by removing BPE, | |
| letter segmentation, etc. Valid options can be found in | |
| fairseq.data.utils.post_process, although not all options | |
| are implemented here. | |
| output_word_probs (Optional[bool]): output words and their | |
| predicted log probabilities | |
| output_word_stats (Optional[bool]): output word statistics such | |
| as word count and average probability | |
| target_dictionary (Optional[~fairseq.data.Dictionary]): output | |
| dictionary (defaults to *source_dictionary*) | |
| softmax_batch (Optional[bool]): if BxT is more than this, will | |
| batch the softmax over vocab to this amount of tokens, in | |
| order to fit into GPU memory | |
| remove_bos_token (Optional[bool]): if True, confirm that the | |
| first token is the beginning-of-sentence symbol (according | |
| to the relevant dictionary) and remove it from the output | |
| device (Optional[torch.device]): device to use for evaluation | |
| (defaults to device of first model parameter) | |
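
    Example (illustrative sketch only; the checkpoint path and split name are
    placeholders, and defaults are used for most dataloader options)::

        models, _, task = checkpoint_utils.load_model_ensemble_and_task(
            ["/path/to/checkpoint.pt"]
        )
        task.load_dataset("valid")
        results = eval_lm(
            models=models,
            source_dictionary=task.source_dictionary,
            batch_iterator=task.eval_lm_dataloader(dataset=task.dataset("valid")),
            target_dictionary=task.target_dictionary,
        )
        print(results["perplexity"])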
| """ | |
    if target_dictionary is None:
        target_dictionary = source_dictionary
    if device is None:
        device = next(models[0].parameters()).device

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(target_dictionary, softmax_batch)

    score_sum = 0.0
    count = 0
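
    # When scoring BPE-segmented text, collect every dictionary entry that ends
    # with the continuation marker so that per-subword scores can later be
    # merged into whole-word scores.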
    if post_process is not None:
        if post_process in {"subword_nmt", "@@ "}:
            bpe_cont = post_process.rstrip()
            bpe_toks = {
                i
                for i in range(len(source_dictionary))
                if source_dictionary[i].endswith(bpe_cont)
            }
        else:
            raise NotImplementedError(
                f"--post-process={post_process} is not implemented"
            )
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()
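
    # Score each batch with the SequenceScorer and accumulate per-token
    # log-probabilities.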
    for sample in batch_iterator:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample, device=device)

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            tgt_len = tokens.numel()
            pos_scores = hypo["positional_scores"].float()

            if remove_bos_token:
                assert hypo["tokens"][0].item() == target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
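            # Fold the score of each continuation subword into the token that
            # follows it, so that each word contributes a single summed score.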
            if bpe_toks is not None:
                for i in range(tgt_len - 1):
                    if tokens[i].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[i + 1] += pos_scores[i]
                        pos_scores[i] = 0
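
            # Drop any positions whose scores are +/-inf so they do not corrupt
            # the running log-probability sum.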
            inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
            if inf_scores.any():
                logger.info(
                    "skipping tokens with inf scores: %s",
                    target_dictionary.string(tokens[inf_scores.nonzero()]),
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks
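
            # Optionally reassemble subwords into whole words and record
            # per-word log-probabilities and aggregate statistics.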
            if output_word_probs or output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for i in range(len(tokens)):
                    w_ind = tokens[i].item()
                    w += source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[i].item()))

                        next_prob = None
                        ind = i + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[i].item(), next_prob
                        )
                        is_bpe = False
                        w = ""
                if output_word_probs:
                    logger.info(
                        str(int(sample_id))
                        + " "
                        + (
                            "\t".join(
                                "{} [{:.2f}]".format(x[0], x[1]) for x in word_prob
                            )
                        )
                    )

    avg_nll_loss = (
        -score_sum / count / math.log(2) if count > 0 else 0
    )  # convert to base 2
    logger.info(
        "Evaluated {:,} tokens in {:.1f}s ({:.2f} tokens/s)".format(
            gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0
        )
    )

    if output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)

    return {
        "loss": avg_nll_loss,
        "perplexity": 2 ** avg_nll_loss,
    }


class WordStat(object):
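    """Accumulates count, log-probability, and next-word log-probability
    statistics for a single word (or BPE-joined word) seen during evaluation."""
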
    def __init__(self, word, is_bpe):
        self.word = word
        self.is_bpe = is_bpe
        self.log_prob = 0
        self.next_word_prob = 0
        self.count = 0
        self.missing_next_words = 0

    def add(self, log_prob, next_word_prob):
        """Increments counters for the sum of log probs of the current word and
        the next word (given context ending at the current word). Since the next
        word might be at the end of the example, or might not be counted because
        it is not a word-final subword unit, also keeps track of how many such
        cases we have seen."""
        if next_word_prob is not None:
            self.next_word_prob += next_word_prob
        else:
            self.missing_next_words += 1
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
        return "{}\t{}\t{}\t{}\t{}\t{}".format(
            self.word,
            self.count,
            self.log_prob,
            self.is_bpe,
            self.next_word_prob,
            self.count - self.missing_next_words,
        )


def main(cfg: DictConfig, **unused_kwargs):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    logger.info(cfg)

    if cfg.eval_lm.context_window > 0:
        # reduce tokens per sample by the required context window size
        cfg.task.tokens_per_sample -= cfg.eval_lm.context_window

    # Initialize the task using the current *cfg*
    task = tasks.setup_task(cfg.task)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=eval(cfg.common_eval.model_overrides),
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
        task=task,
    )

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu
    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    # Optimize ensemble for generation and set the source and dest dicts on the model
    # (required by scorer)
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    assert len(models) > 0

    logger.info(
        "num. model params: {:,}".format(sum(p.numel() for p in models[0].parameters()))
    )

    # Load dataset splits
    task.load_dataset(cfg.dataset.gen_subset)
    dataset = task.dataset(cfg.dataset.gen_subset)
    logger.info(
        "{} {} {:,} examples".format(
            cfg.task.data, cfg.dataset.gen_subset, len(dataset)
        )
    )
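
    # Build the evaluation dataloader; with distributed evaluation the data is
    # sharded across workers via num_shards / shard_id below.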
    itr = task.eval_lm_dataloader(
        dataset=dataset,
        max_tokens=cfg.dataset.max_tokens or 36000,
        batch_size=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]
        ),
        num_shards=max(
            cfg.dataset.num_shards,
            cfg.distributed_training.distributed_world_size,
        ),
        shard_id=max(
            cfg.dataset.shard_id,
            cfg.distributed_training.distributed_rank,
        ),
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
        context_window=cfg.eval_lm.context_window,
    )
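
    # Wrap the iterator in a progress bar for periodic logging.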
    itr = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    results = eval_lm(
        models=models,
        source_dictionary=task.source_dictionary,
        batch_iterator=itr,
        post_process=cfg.common_eval.post_process,
        output_word_probs=cfg.eval_lm.output_word_probs,
        output_word_stats=cfg.eval_lm.output_word_stats,
        target_dictionary=task.target_dictionary,
        softmax_batch=cfg.eval_lm.softmax_batch,
        remove_bos_token=getattr(cfg.task, "add_bos_token", False),
    )

    logger.info(
        "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
            results["loss"], results["perplexity"]
        )
    )

    return results


def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)


if __name__ == "__main__":
    cli_main()