File size: 7,124 Bytes
8c212a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# python3.7
"""Contains the class for recording the running stats.

Here, running stats refers to the statistical information in the running
process, such as loss values, learning rates, running time, etc.
"""

from .misc import format_time

__all__ = ['SingleStats', 'RunningStats']


class SingleStats(object):
    """A class to record the stats corresponding to a particular variable.

    This class is log-friendly and supports customized log format, including:

    (1) Numerical log format, such as `.3f`, `.1e`, `05d`, and `>10s`.
    (2) Customized log name (name of the stats to show in the log).
    (3) Additional string (e.g., measure unit) as the tail of log message.

    Furthermore, this class also supports logging the stats with different
    strategies, including:

    (1) CURRENT: The current value will be logged.
    (2) AVERAGE: The averaged value (from the beginning) will be logged.
    (3) SUM: The cumulative value (from the beginning) will be logged.
    """

    def __init__(self,
                 name,
                 log_format='.3f',
                 log_name=None,
                 log_tail=None,
                 log_strategy='AVERAGE'):
        """Initializes the stats with log format.

        Args:
            name: Name of the stats. Should be a string without spaces.
            log_format: The numerical log format. Use `time` to log time
                duration. (default: `.3f`)
            log_name: The name shown in the log. `None` means to directly use
                the stats name. (default: None)
            log_tail: The trailing log message. (default: None)
            log_strategy: Strategy to log this stats. `CURRENT`, `AVERAGE`, and
                `SUM` are supported (case-insensitive). (default: `AVERAGE`)

        Raises:
            ValueError: If the input `log_strategy` is not supported.
        """
        log_strategy = log_strategy.upper()
        if log_strategy not in ['CURRENT', 'AVERAGE', 'SUM']:
            # Report the local value: `self._log_strategy` is not assigned yet,
            # so going through the `log_strategy` property here would raise
            # `AttributeError` and mask the intended `ValueError`.
            raise ValueError(f'Invalid log strategy `{log_strategy}`!')

        self._name = name
        self._log_format = log_format
        self._log_name = log_name or name
        self._log_tail = log_tail or ''
        self._log_strategy = log_strategy

        # Stats Data.
        self.val = 0  # Current value.
        self.sum = 0  # Cumulative value.
        self.avg = 0  # Averaged value.
        self.cnt = 0  # Count number.

    @property
    def name(self):
        """Gets the name of the stats."""
        return self._name

    @property
    def log_format(self):
        """Gets the numerical log format of the stats."""
        return self._log_format

    @property
    def log_name(self):
        """Gets the log name of the stats."""
        return self._log_name

    @property
    def log_tail(self):
        """Gets the trailing log message of the stats."""
        return self._log_tail

    @property
    def log_strategy(self):
        """Gets the log strategy of the stats (always upper-case)."""
        return self._log_strategy

    def clear(self):
        """Clears the stats data, resetting all counters to zero."""
        self.val = 0
        self.sum = 0
        self.avg = 0
        self.cnt = 0

    def update(self, value):
        """Updates the stats data with a newly observed value.

        Args:
            value: The new value. It becomes the current value and is
                accumulated into the sum, from which the average is recomputed.
        """
        self.val = value
        self.cnt = self.cnt + 1
        self.sum = self.sum + value
        self.avg = self.sum / self.cnt

    def get_log_value(self):
        """Gets value for logging according to the log strategy.

        Returns:
            The current, averaged, or cumulative value, depending on whether
                the log strategy is `CURRENT`, `AVERAGE`, or `SUM`.

        Raises:
            NotImplementedError: If the log strategy is none of the above.
        """
        if self.log_strategy == 'CURRENT':
            return self.val
        if self.log_strategy == 'AVERAGE':
            return self.avg
        if self.log_strategy == 'SUM':
            return self.sum
        raise NotImplementedError(f'Log strategy `{self.log_strategy}` is not '
                                  f'implemented!')

    def __str__(self):
        """Gets log message, in the form `<log_name>: <value><log_tail>`."""
        if self.log_format == 'time':
            # Time durations are rendered by the `format_time` helper instead
            # of a numerical format spec.
            value_str = f'{format_time(self.get_log_value())}'
        else:
            value_str = f'{self.get_log_value():{self.log_format}}'
        return f'{self.log_name}: {value_str}{self.log_tail}'

class RunningStats(object):
    """A class to record all the running stats.

    Basically, this class contains a dictionary of SingleStats.

    Example:

    running_stats = RunningStats()
    running_stats.add('loss', log_format='.3f', log_strategy='AVERAGE')
    running_stats.add('time', log_format='time', log_name='Iter Time',
                      log_strategy='CURRENT')
    running_stats.log_order = ['time', 'loss']
    running_stats.update({'loss': 0.46, 'time': 12})
    running_stats.update({'time': 14.5, 'loss': 0.33})
    print(running_stats)
    """

    def __init__(self, log_delimiter=', '):
        """Initializes the running stats with the log delimiter.

        Args:
            log_delimiter: This delimiter is used to connect the log messages
                from different stats. (default: `, `)
        """
        self._log_delimiter = log_delimiter
        self.stats_pool = dict()  # The stats pool.
        self.log_order = None  # Order of the stats to log.

    @property
    def log_delimiter(self):
        """Gets the log delimiter between different stats."""
        return self._log_delimiter

    def add(self, name, **kwargs):
        """Adds a new SingleStats to the dictionary.

        If a stats with the same name already exists, it is kept untouched and
        the call is a no-op.

        Additional arguments include:

        log_format: The numerical log format. Use `time` to log time duration.
            (default: `.3f`)
        log_name: The name shown in the log. `None` means to directly use the
            stats name. (default: None)
        log_tail: The trailing log message. (default: None)
        log_strategy: Strategy to log this stats. `CURRENT`, `AVERAGE`, and
            `SUM` are supported. (default: `AVERAGE`)
        """
        if name in self.stats_pool:
            return
        self.stats_pool[name] = SingleStats(name, **kwargs)

    def clear(self, exclude_list=None):
        """Clears the stats data (if needed).

        Args:
            exclude_list: A list of stats names whose data will not be cleared.
        """
        exclude_list = set(exclude_list or [])
        for name, stats in self.stats_pool.items():
            if name not in exclude_list:
                stats.clear()

    def update(self, kwargs):
        """Updates the stats data by name.

        Args:
            kwargs: A dictionary mapping stats names to their new values. Stats
                that have not been added yet are created with default settings.
        """
        for name, value in kwargs.items():
            if name not in self.stats_pool:
                self.add(name)
            self.stats_pool[name].update(value)

    def __getattr__(self, name):
        """Gets a particular SingleStats by name.

        This is only invoked when the regular attribute lookup fails.

        Raises:
            AttributeError: If no stats with the given name exists.
        """
        # Go through `__dict__` rather than `self.stats_pool`: on an instance
        # whose `__init__` has not run (e.g., while being copied or unpickled),
        # `self.stats_pool` would re-enter `__getattr__` and recurse forever.
        stats_pool = self.__dict__.get('stats_pool')
        if stats_pool is not None and name in stats_pool:
            return stats_pool[name]
        raise AttributeError(f'`{self.__class__.__name__}` object has no '
                             f'attribute `{name}`!')

    def __str__(self):
        """Gets log message.

        The stats are logged following `log_order` if it is set. Otherwise, all
        stats are logged in the order they were added.

        Raises:
            KeyError: If `log_order` contains a name that is not in the pool.
        """
        # Use a local variable instead of writing the default back to
        # `self.log_order`, so that stats added later still show up in the log.
        log_order = self.log_order or list(self.stats_pool)
        log_strings = [str(self.stats_pool[name]) for name in log_order]
        return self.log_delimiter.join(log_strings)