File size: 7,124 Bytes
8c212a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
# python3.7
"""Contains the class for recording the running stats.
Here, running stats refers to the statistical information in the running
process, such as loss values, learning rates, running time, etc.
"""
from .misc import format_time
__all__ = ['SingleStats', 'RunningStats']
class SingleStats(object):
    """A class to record the stats corresponding to a particular variable.

    This class is log-friendly and supports customized log format, including:

    (1) Numerical log format, such as `.3f`, `.1e`, `05d`, and `>10s`.
    (2) Customized log name (name of the stats to show in the log).
    (3) Additional string (e.g., measure unit) as the tail of log message.

    Furthermore, this class also supports logging the stats with different
    strategies, including:

    (1) CURRENT: The current value will be logged.
    (2) AVERAGE: The averaged value (from the beginning) will be logged.
    (3) SUM: The cumulative value (from the beginning) will be logged.
    """

    # Log strategies accepted by `__init__()` (compared after upper-casing).
    _LOG_STRATEGIES = ('CURRENT', 'AVERAGE', 'SUM')

    def __init__(self,
                 name,
                 log_format='.3f',
                 log_name=None,
                 log_tail=None,
                 log_strategy='AVERAGE'):
        """Initializes the stats with log format.

        Args:
            name: Name of the stats. Should be a string without spaces.
            log_format: The numerical log format. Use `time` to log time
                duration. (default: `.3f`)
            log_name: The name shown in the log. `None` means to directly use
                the stats name. (default: None)
            log_tail: The tailing log message. (default: None)
            log_strategy: Strategy to log this stats. `CURRENT`, `AVERAGE`, and
                `SUM` are supported (case-insensitive). (default: `AVERAGE`)

        Raises:
            ValueError: If the input `log_strategy` is not supported.
        """
        log_strategy = log_strategy.upper()
        if log_strategy not in self._LOG_STRATEGIES:
            # Report the local value: `self._log_strategy` is not assigned
            # yet, so going through the `log_strategy` property here would
            # raise `AttributeError` instead of the documented `ValueError`.
            raise ValueError(f'Invalid log strategy `{log_strategy}`!')

        self._name = name
        self._log_format = log_format
        self._log_name = log_name or name
        self._log_tail = log_tail or ''
        self._log_strategy = log_strategy

        # Stats Data.
        self.val = 0  # Current value.
        self.sum = 0  # Cumulative value.
        self.avg = 0  # Averaged value.
        self.cnt = 0  # Count number.

    @property
    def name(self):
        """Gets the name of the stats."""
        return self._name

    @property
    def log_format(self):
        """Gets the numerical log format of the stats."""
        return self._log_format

    @property
    def log_name(self):
        """Gets the log name of the stats."""
        return self._log_name

    @property
    def log_tail(self):
        """Gets the tailing log message of the stats."""
        return self._log_tail

    @property
    def log_strategy(self):
        """Gets the log strategy of the stats."""
        return self._log_strategy

    def clear(self):
        """Clears the stats data, resetting all values and the count to 0."""
        self.val = 0
        self.sum = 0
        self.avg = 0
        self.cnt = 0

    def update(self, value):
        """Updates the stats data with a newly observed value.

        Args:
            value: The new value. It becomes the current value and is
                accumulated into the sum, the count and the average.
        """
        self.val = value
        self.cnt = self.cnt + 1
        self.sum = self.sum + value
        self.avg = self.sum / self.cnt

    def get_log_value(self):
        """Gets value for logging according to the log strategy.

        Returns:
            The current value, the averaged value, or the cumulative value,
                for strategy `CURRENT`, `AVERAGE`, and `SUM` respectively.

        Raises:
            NotImplementedError: If the log strategy is not one of the above.
        """
        if self.log_strategy == 'CURRENT':
            return self.val
        if self.log_strategy == 'AVERAGE':
            return self.avg
        if self.log_strategy == 'SUM':
            return self.sum
        raise NotImplementedError(f'Log strategy `{self.log_strategy}` is not '
                                  f'implemented!')

    def __str__(self):
        """Gets log message, in the form `<log_name>: <value><log_tail>`."""
        if self.log_format == 'time':
            # `time` is not a numerical format spec; it delegates to the
            # `format_time()` helper imported at the top of this module.
            value_str = f'{format_time(self.get_log_value())}'
        else:
            value_str = f'{self.get_log_value():{self.log_format}}'
        return f'{self.log_name}: {value_str}{self.log_tail}'
class RunningStats(object):
    """A class to record all the running stats.

    Basically, this class contains a dictionary of SingleStats.

    Example:

    running_stats = RunningStats()
    running_stats.add('loss', log_format='.3f', log_strategy='AVERAGE')
    running_stats.add('time', log_format='time', log_name='Iter Time',
                      log_strategy='CURRENT')
    running_stats.log_order = ['time', 'loss']
    running_stats.update({'loss': 0.46, 'time': 12})
    running_stats.update({'time': 14.5, 'loss': 0.33})
    print(running_stats)
    """

    def __init__(self, log_delimiter=', '):
        """Initializes the running stats with the log delimiter.

        Args:
            log_delimiter: This delimiter is used to connect the log messages
                from different stats. (default: `, `)
        """
        self._log_delimiter = log_delimiter
        self.stats_pool = dict()  # The stats pool.
        self.log_order = None  # Order of the stats to log.

    @property
    def log_delimiter(self):
        """Gets the log delimiter between different stats."""
        return self._log_delimiter

    def add(self, name, **kwargs):
        """Adds a new SingleStats to the dictionary.

        If a stats with the same name already exists, it is kept untouched
        and the call is a no-op.

        Additional arguments include:

        log_format: The numerical log format. Use `time` to log time duration.
            (default: `.3f`)
        log_name: The name shown in the log. `None` means to directly use the
            stats name. (default: None)
        log_tail: The tailing log message. (default: None)
        log_strategy: Strategy to log this stats. `CURRENT`, `AVERAGE`, and
            `SUM` are supported. (default: `AVERAGE`)
        """
        if name in self.stats_pool:
            return
        self.stats_pool[name] = SingleStats(name, **kwargs)

    def clear(self, exclude_list=None):
        """Clears the stats data (if needed).

        Args:
            exclude_list: A list of stats names whose data will not be cleared.
        """
        exclude_list = set(exclude_list or [])
        for name, stats in self.stats_pool.items():
            if name not in exclude_list:
                stats.clear()

    def update(self, kwargs):
        """Updates the stats data by name.

        Args:
            kwargs: A dictionary mapping stats names to their new values. A
                name not registered yet is first added with default settings.
        """
        for name, value in kwargs.items():
            if name not in self.stats_pool:
                self.add(name)
            self.stats_pool[name].update(value)

    def __getattr__(self, name):
        """Gets a particular SingleStats by name.

        This is only invoked after the normal attribute lookup has failed, so
        regular instance attributes never reach this method.

        Raises:
            AttributeError: If no stats with the given name is registered.
        """
        # Read the pool through `__dict__` instead of `self.stats_pool`: on an
        # instance whose `__init__()` has not run (e.g., while `copy` or
        # `pickle` probes the freshly created object for `__setstate__`), the
        # latter would re-enter this method and recurse without end.
        pool = self.__dict__.get('stats_pool')
        if pool is not None and name in pool:
            return pool[name]
        raise AttributeError(f'`{self.__class__.__name__}` object has no '
                             f'attribute `{name}`!')

    def __str__(self):
        """Gets log message.

        Stats are logged following `log_order` if it is set, otherwise
        following the order in which they were added to the pool.
        """
        # Resolve the default order locally rather than storing it back into
        # `self.log_order`; caching it would silently drop any stats added
        # after the first time the log message is built.
        log_order = self.log_order or list(self.stats_pool)
        return self.log_delimiter.join(
            str(self.stats_pool[name]) for name in log_order)
|