Spaces:
Build error
Build error
File size: 9,171 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import numpy as np
class HistoryBuffer:
    """Unified storage format for different log types.

    ``HistoryBuffer`` records the history of a logged scalar together with a
    per-entry count, and exposes statistics (min, max, mean, current, plus any
    custom method registered through :meth:`register_statistics`) over the
    most recent entries.

    Examples:
        >>> history_buffer = HistoryBuffer()
        >>> # Update history_buffer.
        >>> history_buffer.update(1)
        >>> history_buffer.update(2)
        >>> history_buffer.min()  # minimum of (1, 2)
        1
        >>> history_buffer.max()  # maximum of (1, 2)
        2
        >>> history_buffer.mean()  # mean of (1, 2)
        1.5
        >>> history_buffer.statistics('mean')  # access method by string.
        1.5

    Args:
        log_history (Sequence): History logs. Defaults to [].
        count_history (Sequence): Counts of history logs. Defaults to [].
        max_length (int): The max length of history logs. Defaults to 1000000.
    """

    # Registry of statistics methods, keyed by method name. It lives on the
    # class, so it is shared by every instance.
    _statistics_methods: dict = dict()

    def __init__(self,
                 log_history: Sequence = [],
                 count_history: Sequence = [],
                 max_length: int = 1000000):
        # NOTE: the list defaults are only read (``np.array`` copies them),
        # never mutated, so sharing them between calls is harmless.
        self.max_length = max_length
        self._set_default_statistics()
        assert len(log_history) == len(count_history), \
            'The lengths of log_history and count_history should be equal'
        if len(log_history) > max_length:
            warnings.warn(f'The length of history buffer({len(log_history)}) '
                          f'exceeds the max_length({max_length}), the first '
                          'few elements will be ignored.')
            # Keep only the newest ``max_length`` entries.
            log_history = log_history[-max_length:]
            count_history = count_history[-max_length:]
        self._log_history = np.array(log_history)
        self._count_history = np.array(count_history)

    def _set_default_statistics(self) -> None:
        """Register default statistic methods: min, max, current and mean.

        ``setdefault`` is used so that an entry already present under one of
        these names is left untouched.
        """
        self._statistics_methods.setdefault('min', HistoryBuffer.min)
        self._statistics_methods.setdefault('max', HistoryBuffer.max)
        self._statistics_methods.setdefault('current', HistoryBuffer.current)
        self._statistics_methods.setdefault('mean', HistoryBuffer.mean)

    def _resolve_window_size(self, window_size: Optional[int]) -> int:
        """Validate ``window_size``; ``None`` means the whole history."""
        if window_size is None:
            return len(self._log_history)
        assert isinstance(window_size, int), \
            'The type of window size should be int, but got ' \
            f'{type(window_size)}'
        return window_size

    def update(self, log_val: Union[int, float], count: int = 1) -> None:
        """Append one value and its count to the history.

        If the length of the buffer exceeds ``self.max_length``, the oldest
        elements are dropped from the buffer.

        Args:
            log_val (int or float): The value of log.
            count (int): The accumulation times of log, defaults to 1.
                ``count`` will be used in smooth statistics.

        Raises:
            TypeError: If ``log_val`` or ``count`` is not an ``int`` or
                ``float`` instance.
        """
        if (not isinstance(log_val, (int, float))
                or not isinstance(count, (int, float))):
            raise TypeError(f'log_val must be int or float but got '
                            f'{type(log_val)}, count must be int but got '
                            f'{type(count)}')
        # ``np.append`` returns a new array, so the attributes are rebound.
        self._log_history = np.append(self._log_history, log_val)
        self._count_history = np.append(self._count_history, count)
        if len(self._log_history) > self.max_length:
            self._log_history = self._log_history[-self.max_length:]
            self._count_history = self._count_history[-self.max_length:]

    @property
    def data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Get the ``_log_history`` and ``_count_history``.

        Returns:
            Tuple[np.ndarray, np.ndarray]: History logs and the counts of
            the history logs.
        """
        return self._log_history, self._count_history

    @classmethod
    def register_statistics(cls, method: Callable) -> Callable:
        """Register custom statistics method to ``_statistics_methods``.

        The registered method can be called by ``history_buffer.statistics``
        with corresponding method name and arguments. It is stored under
        ``method.__name__`` and receives the buffer as its first argument.

        Examples:
            >>> @HistoryBuffer.register_statistics
            ... def weighted_mean(self, window_size, weight):
            ...     assert len(weight) == window_size
            ...     return (self._log_history[-window_size:] *
            ...             np.array(weight)).sum() / \
            ...         self._count_history[-window_size:].sum()
            >>> log_buffer = HistoryBuffer([1, 2], [1, 1])
            >>> log_buffer.statistics('weighted_mean', 2, [2, 1])
            2.0

        Args:
            method (Callable): Custom statistics method.

        Returns:
            Callable: Original custom statistics method.

        Raises:
            AssertionError: If a method with the same name is already
                registered.
        """
        method_name = method.__name__
        assert method_name not in cls._statistics_methods, \
            'method_name cannot be registered twice!'
        cls._statistics_methods[method_name] = method
        return method

    def statistics(self, method_name: str, *arg, **kwargs) -> Any:
        """Access statistics method by name.

        Args:
            method_name (str): Name of method.
            *arg: Positional arguments forwarded to the method.
            **kwargs: Keyword arguments forwarded to the method.

        Returns:
            Any: Depends on corresponding method.

        Raises:
            KeyError: If ``method_name`` has not been registered.
        """
        if method_name not in self._statistics_methods:
            raise KeyError(f'{method_name} has not been registered in '
                           'HistoryBuffer._statistics_methods')
        method = self._statistics_methods[method_name]
        # Registered entries are plain functions, so ``self`` is passed
        # explicitly.
        return method(self, *arg, **kwargs)

    def mean(self, window_size: Optional[int] = None) -> np.ndarray:
        """Return the mean of the latest ``window_size`` values in log
        histories.

        The result is the sum of the windowed logs divided by the sum of the
        windowed counts. If ``window_size is None`` or
        ``window_size > len(self._log_history)``, return the global mean
        value of history logs.

        Args:
            window_size (int, optional): Size of statistics window.

        Returns:
            np.ndarray: Mean value within the window.
        """
        window_size = self._resolve_window_size(window_size)
        logs_sum = self._log_history[-window_size:].sum()
        counts_sum = self._count_history[-window_size:].sum()
        return logs_sum / counts_sum

    def max(self, window_size: Optional[int] = None) -> np.ndarray:
        """Return the maximum value of the latest ``window_size`` values in
        log histories.

        If ``window_size is None`` or ``window_size > len(self._log_history)``,
        return the global maximum value of history logs.

        Args:
            window_size (int, optional): Size of statistics window.

        Returns:
            np.ndarray: The maximum value within the window.
        """
        window_size = self._resolve_window_size(window_size)
        return self._log_history[-window_size:].max()

    def min(self, window_size: Optional[int] = None) -> np.ndarray:
        """Return the minimum value of the latest ``window_size`` values in
        log histories.

        If ``window_size is None`` or ``window_size > len(self._log_history)``,
        return the global minimum value of history logs.

        Args:
            window_size (int, optional): Size of statistics window.

        Returns:
            np.ndarray: The minimum value within the window.
        """
        window_size = self._resolve_window_size(window_size)
        return self._log_history[-window_size:].min()

    def current(self) -> np.ndarray:
        """Return the most recently updated value in log histories.

        Returns:
            np.ndarray: Most recently updated value in log histories.

        Raises:
            ValueError: If the buffer is empty.
        """
        if len(self._log_history) == 0:
            raise ValueError('HistoryBuffer._log_history is an empty array! '
                             'please call update first')
        return self._log_history[-1]

    def __getstate__(self) -> dict:
        """Build a state dict from which ``_statistics_methods`` can be
        resumed.

        The state is a shallow copy of the instance ``__dict__`` plus a
        snapshot of the statistics registry under ``'statistics_methods'``.
        The instance itself is not modified.

        Returns:
            dict: State dict including statistics_methods.
        """
        state = self.__dict__.copy()
        state['statistics_methods'] = dict(self._statistics_methods)
        return state

    def __setstate__(self, state):
        """Try to load ``_statistics_methods`` from state.

        Methods found under ``'statistics_methods'`` are merged into the
        class-level registry; the remaining entries become instance
        attributes. The ``state`` argument is not modified.

        Args:
            state (dict): State dict.
        """
        # Work on a copy so the caller's dict keeps its registry entry.
        state = dict(state)
        statistics_methods = state.pop('statistics_methods', {})
        self._set_default_statistics()
        self._statistics_methods.update(statistics_methods)
        self.__dict__.update(state)
|