File size: 3,829 Bytes
cf2a15a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras integration for TensorBoard hparams.

Use `tensorboard.plugins.hparams.api` to access this module's contents.
"""


import os

import tensorflow as tf

from tensorboard.plugins.hparams import api_pb2
from tensorboard.plugins.hparams import summary
from tensorboard.plugins.hparams import summary_v2


class Callback(tf.keras.callbacks.Callback):
    """Callback for logging hyperparameters to TensorBoard.

    At the start of training it writes an hparams summary for the given
    hyperparameter values. At the end of training it writes a session-end
    summary with status `STATUS_SUCCESS`. Each instance may be used for
    only one training session.

    NOTE: This callback only works in TensorFlow eager mode.
    """

    def __init__(self, writer, hparams, trial_id=None):
        """Create a callback for logging hyperparameters to TensorBoard.

        As with the standard `tf.keras.callbacks.TensorBoard` class, each
        callback object is valid for only one call to `model.fit`.

        Args:
          writer: The `SummaryWriter` object to which hparams should be
            written, or a logdir (as a `str` or `os.PathLike`, e.g. a
            `pathlib.Path`) to be passed to `tf.summary.create_file_writer`
            to create such a writer.
          hparams: A `dict` mapping hyperparameters to the values used in
            this session. Keys should be the names of `HParam` objects used
            in an experiment, or the `HParam` objects themselves. Values
            should be Python `bool`, `int`, `float`, or `string` values,
            depending on the type of the hyperparameter.
          trial_id: An optional `str` ID for the set of hyperparameter
            values used in this trial. Defaults to a hash of the
            hyperparameters.

        Raises:
          ValueError: If two entries in `hparams` share the same
            hyperparameter name.
          TypeError: If `writer` is `None`.
        """
        # Defer creating the actual summary until we write it, so that the
        # timestamp is correct. But create a "dry-run" first to fail fast in
        # case the `hparams` are invalid.
        self._hparams = dict(hparams)
        self._trial_id = trial_id
        summary_v2.hparams_pb(self._hparams, trial_id=self._trial_id)
        if writer is None:
            raise TypeError(
                "writer must be a `SummaryWriter` or `str`, not None"
            )
        if isinstance(writer, (str, os.PathLike)):
            # `os.fspath` turns a path-like logdir (e.g. `pathlib.Path`)
            # into a plain string; a `str` is returned unchanged. Without
            # this, a path object would be mistaken for a writer and only
            # fail later, at `on_train_begin`.
            self._writer = tf.compat.v2.summary.create_file_writer(
                os.fspath(writer)
            )
        else:
            self._writer = writer

    def _get_writer(self):
        """Return the summary writer, checking the callback is still usable.

        Raises:
          RuntimeError: If this callback already finished a training
            session (`on_train_end` clears the writer), or if TensorFlow
            is not executing eagerly.
        """
        if self._writer is None:
            raise RuntimeError(
                "hparams Keras callback cannot be reused across training sessions"
            )
        if not tf.executing_eagerly():
            raise RuntimeError(
                "hparams Keras callback only supported in TensorFlow eager mode"
            )
        return self._writer

    def on_train_begin(self, logs=None):
        """Write the hparams summary for this trial.

        Raises:
          RuntimeError: Under the same conditions as `_get_writer`.
        """
        del logs  # unused
        with self._get_writer().as_default():
            summary_v2.hparams(self._hparams, trial_id=self._trial_id)

    def on_train_end(self, logs=None):
        """Write a successful session-end summary and retire the writer.

        Raises:
          RuntimeError: Under the same conditions as `_get_writer`.
        """
        del logs  # unused
        writer = self._get_writer()
        with writer.as_default():
            pb = summary.session_end_pb(api_pb2.STATUS_SUCCESS)
            raw_pb = pb.SerializeToString()
            tf.compat.v2.summary.experimental.write_raw_pb(raw_pb, step=0)
        # Push the session-end record out now instead of leaving it queued
        # until the writer's periodic flush or its garbage collection (a
        # caller-supplied writer may stay alive long after training ends).
        writer.flush()
        # Drop the reference so any further use of this callback fails fast
        # in `_get_writer` rather than writing into a finished session.
        self._writer = None