| from chainer.training.extension import Extension | |
class TensorboardLogger(Extension):
    """A tensorboard logger extension.

    On every call, writes the trainer's current observations as scalars to
    the given logger. Once per new epoch, it also lets the optional
    attention and CTC reporters write their plots.
    """

    default_name = "espnet_tensorboard_logger"

    def __init__(
        self, logger, att_reporter=None, ctc_reporter=None, entries=None, epoch=0
    ):
        """Init the extension.

        :param SummaryWriter logger: The logger to use; must provide
            ``add_scalar(tag, value, step)``
        :param PlotAttentionReporter att_reporter: The (optional) attention
            reporter; must provide ``log_attentions(logger, step)``
        :param ctc_reporter: The (optional) CTC reporter; must provide
            ``log_ctc_probs(logger, step)``
        :param entries: The entries to watch (``None`` logs every observation)
        :param int epoch: The starting epoch; reporters only log once the
            ``main`` iterator's epoch exceeds it
        """
        self._entries = entries
        self._att_reporter = att_reporter
        self._ctc_reporter = ctc_reporter
        self._logger = logger
        self._epoch = epoch

    def __call__(self, trainer):
        """Update the events file with the new values.

        :param trainer: The trainer
        """
        iteration = trainer.updater.iteration
        for k, v in trainer.observation.items():
            if self._entries is not None and k not in self._entries:
                continue
            if k is None or v is None:
                continue
            # cupy objects are converted with ``.get()`` before being handed
            # to the logger.
            if "cupy" in str(type(v)):
                v = v.get()
            if "cupy" in str(type(k)):
                k = k.get()
            self._logger.add_scalar(k, v, iteration)

        if self._att_reporter is None and self._ctc_reporter is None:
            return
        # Both reporters share a single epoch counter. Read the epoch once and
        # advance the counter once, so one reporter does not prevent the other
        # from logging. (Previously the attention reporter advanced the counter
        # first, which silently suppressed the CTC reporter.)
        epoch = trainer.updater.get_iterator("main").epoch
        if epoch <= self._epoch:
            return
        self._epoch = epoch
        if self._att_reporter is not None:
            self._att_reporter.log_attentions(self._logger, iteration)
        if self._ctc_reporter is not None:
            self._ctc_reporter.log_ctc_probs(self._logger, iteration)