|
|
|
"""Contains the class for recording the running stats. |
|
|
|
Here, running stats refers to the statistical information in the running |
|
process, such as loss values, learning rates, running time, etc. |
|
""" |
|
|
|
from .misc import format_time |
|
|
|
__all__ = ['SingleStats', 'RunningStats'] |
|
|
|
|
|
class SingleStats(object):
    """A class to record the stats corresponding to a particular variable.

    This class is log-friendly and supports customized log format, including:

    (1) Numerical log format, such as `.3f`, `.1e`, `05d`, and `>10s`.
    (2) Customized log name (name of the stats to show in the log).
    (3) Additional string (e.g., measure unit) as the tail of log message.

    Furthermore, this class also supports logging the stats with different
    strategies, including:

    (1) CURRENT: The current value will be logged.
    (2) AVERAGE: The averaged value (from the beginning) will be logged.
    (3) SUM: The cumulative value (from the beginning) will be logged.

    Attributes:
        val: The value passed to the most recent `update()` call.
        sum: The sum of all values passed to `update()` since the last clear.
        avg: `sum / cnt`, refreshed on every `update()` call.
        cnt: The number of `update()` calls since the last clear.
    """

    # Strategies accepted by `__init__()` (matched after upper-casing).
    _LOG_STRATEGIES = ('CURRENT', 'AVERAGE', 'SUM')

    def __init__(self,
                 name,
                 log_format='.3f',
                 log_name=None,
                 log_tail=None,
                 log_strategy='AVERAGE'):
        """Initializes the stats with log format.

        Args:
            name: Name of the stats. Should be a string without spaces.
            log_format: The numerical log format. Use `time` to log time
                duration. (default: `.3f`)
            log_name: The name shown in the log. `None` means to directly use
                the stats name. (default: None)
            log_tail: The tailing log message. (default: None)
            log_strategy: Strategy to log this stats. `CURRENT`, `AVERAGE`, and
                `SUM` are supported (case-insensitive). (default: `AVERAGE`)

        Raises:
            ValueError: If the input `log_strategy` is not supported.
        """
        log_strategy = log_strategy.upper()
        if log_strategy not in self._LOG_STRATEGIES:
            # Report the local value: `self.log_strategy` is not assigned yet
            # at this point, so reading it would raise `AttributeError` and
            # hide the intended `ValueError`.
            raise ValueError(f'Invalid log strategy `{log_strategy}`!')

        self._name = name
        self._log_format = log_format
        self._log_name = log_name or name
        self._log_tail = log_tail or ''
        self._log_strategy = log_strategy

        self.val = 0
        self.sum = 0
        self.avg = 0
        self.cnt = 0

    @property
    def name(self):
        """Gets the name of the stats."""
        return self._name

    @property
    def log_format(self):
        """Gets the numerical log format of the stats."""
        return self._log_format

    @property
    def log_name(self):
        """Gets the log name of the stats."""
        return self._log_name

    @property
    def log_tail(self):
        """Gets the tailing log message of the stats."""
        return self._log_tail

    @property
    def log_strategy(self):
        """Gets the log strategy of the stats (stored in upper case)."""
        return self._log_strategy

    def clear(self):
        """Clears the stats data, resetting all counters to zero."""
        self.val = 0
        self.sum = 0
        self.avg = 0
        self.cnt = 0

    def update(self, value):
        """Updates the stats data with a new observation.

        Args:
            value: The new value. It becomes the current value and is
                accumulated into the running sum and average.
        """
        self.val = value
        self.cnt = self.cnt + 1
        self.sum = self.sum + value
        self.avg = self.sum / self.cnt

    def get_log_value(self):
        """Gets value for logging according to the log strategy.

        Returns:
            `val`, `avg`, or `sum` for the `CURRENT`, `AVERAGE`, or `SUM`
            strategy respectively.

        Raises:
            NotImplementedError: If the stored strategy is none of the above.
        """
        if self.log_strategy == 'CURRENT':
            return self.val
        if self.log_strategy == 'AVERAGE':
            return self.avg
        if self.log_strategy == 'SUM':
            return self.sum
        raise NotImplementedError(f'Log strategy `{self.log_strategy}` is not '
                                  f'implemented!')

    def __str__(self):
        """Gets log message in the form `<log_name>: <value><log_tail>`."""
        if self.log_format == 'time':
            # `time` is not a format spec; delegate to the time formatter.
            value_str = f'{format_time(self.get_log_value())}'
        else:
            value_str = f'{self.get_log_value():{self.log_format}}'
        return f'{self.log_name}: {value_str}{self.log_tail}'
|
|
|
|
|
class RunningStats(object):
    """A class to record all the running stats.

    Basically, this class contains a dictionary of SingleStats.

    Each recorded stats can also be fetched as an attribute by its name, e.g.,
    `running_stats.loss`.

    Example:

    running_stats = RunningStats()
    running_stats.add('loss', log_format='.3f', log_strategy='AVERAGE')
    running_stats.add('time', log_format='time', log_name='Iter Time',
                      log_strategy='CURRENT')
    running_stats.log_order = ['time', 'loss']
    running_stats.update({'loss': 0.46, 'time': 12})
    running_stats.update({'time': 14.5, 'loss': 0.33})
    print(running_stats)

    Attributes:
        stats_pool: Dictionary mapping stats name to its SingleStats.
        log_order: Optional list of stats names giving the order (and the
            subset) of stats in the log message. If `None` or empty, all
            stats are logged in the order they were added.
    """

    def __init__(self, log_delimiter=', '):
        """Initializes the running stats with the log delimiter.

        Args:
            log_delimiter: This delimiter is used to connect the log messages
                from different stats. (default: `, `)
        """
        self._log_delimiter = log_delimiter
        self.stats_pool = dict()
        self.log_order = None

    @property
    def log_delimiter(self):
        """Gets the log delimiter between different stats."""
        return self._log_delimiter

    def add(self, name, **kwargs):
        """Adds a new SingleStats to the dictionary.

        If a stats with the same name already exists, it is kept untouched and
        this call does nothing.

        Additional arguments include:

        log_format: The numerical log format. Use `time` to log time duration.
            (default: `.3f`)
        log_name: The name shown in the log. `None` means to directly use the
            stats name. (default: None)
        log_tail: The tailing log message. (default: None)
        log_strategy: Strategy to log this stats. `CURRENT`, `AVERAGE`, and
            `SUM` are supported. (default: `AVERAGE`)
        """
        if name in self.stats_pool:
            return
        self.stats_pool[name] = SingleStats(name, **kwargs)

    def clear(self, exclude_list=None):
        """Clears the stats data (if needed).

        Args:
            exclude_list: A list of stats names whose data will not be cleared.
        """
        exclude_list = set(exclude_list or [])
        for name, stats in self.stats_pool.items():
            if name not in exclude_list:
                stats.clear()

    def update(self, kwargs):
        """Updates the stats data by name.

        Args:
            kwargs: A dictionary mapping stats name to the new value. Names
                that have not been added yet are added with default settings
                before being updated.
        """
        for name, value in kwargs.items():
            if name not in self.stats_pool:
                self.add(name)
            self.stats_pool[name].update(value)

    def __getattr__(self, name):
        """Gets a particular SingleStats by name.

        This hook only runs when the normal attribute lookup fails.

        Raises:
            AttributeError: If `name` is neither a recorded stats nor an
                instance attribute.
        """
        # Read the pool through `__dict__` instead of `self.stats_pool`: on an
        # instance that is not initialized yet (e.g. while being copied or
        # unpickled) the latter would re-enter this method forever.
        stats_pool = self.__dict__.get('stats_pool')
        if stats_pool is not None and name in stats_pool:
            return stats_pool[name]
        # Kept for explicit `obj.__getattr__(name)` calls.
        if name in self.__dict__:
            return self.__dict__[name]
        raise AttributeError(f'`{self.__class__.__name__}` object has no '
                             f'attribute `{name}`!')

    def __str__(self):
        """Gets log message.

        The messages of the stats listed in `log_order` (or of all stats, in
        insertion order, if `log_order` is `None` or empty) are joined with
        the log delimiter.

        Raises:
            KeyError: If `log_order` contains a name that is not recorded.
        """
        # Resolve the order locally rather than writing it back to
        # `self.log_order`; otherwise the order would be frozen at the first
        # print and stats added afterwards would never be logged.
        log_order = self.log_order or list(self.stats_pool)
        log_strings = [str(self.stats_pool[name]) for name in log_order]
        return self.log_delimiter.join(log_strings)
|
|