# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from concurrent import futures
from fnmatch import fnmatch
from functools import partial
import io
import math
from multiprocessing import cpu_count
import typing as tp
import zlib

import torch


class BaseQuantizer:
    @dataclass
    class _QuantizedParam:
        name: str
        param: torch.nn.Parameter
        module: torch.nn.Module
        # If a Parameter is used multiple times, `other` can be used
        # to share state between the different Quantizers.
        other: tp.Optional[tp.Any]

    def __init__(self, model: torch.nn.Module, min_size: float = 0.01, float16: bool = False,
                 exclude: tp.Optional[tp.List[str]] = [], detect_bound: bool = True):
        self.model = model
        self.min_size = min_size
        self.float16 = float16
        self.exclude = exclude
        self.detect_bound = detect_bound
        self._quantized = False
        # Forward hooks keep the model consistent: weights are re-quantized for
        # evaluation forwards and restored to float during training forwards.
        self._pre_handle = self.model.register_forward_pre_hook(self._forward_pre_hook)
        self._post_handle = self.model.register_forward_hook(self._forward_hook)

        self._quantized_state = None
        self._qparams = []
        self._float16 = []
        self._others = []
        self._rnns = []

        self._saved = []

        self._find_params()

    def _find_params(self):
        # `min_size` is expressed in MB; convert it to a number of float32 elements.
        min_params = self.min_size * 2**20 // 4
        previous = {}
        for module_name, module in self.model.named_modules():
            if isinstance(module, torch.nn.RNNBase):
                self._rnns.append(module)
            for name, param in list(module.named_parameters(recurse=False)):
                full_name = f"{module_name}.{name}"
                matched = False
                for pattern in self.exclude:
                    if fnmatch(full_name, pattern) or fnmatch(name, pattern):
                        matched = True
                        break

                if param.numel() <= min_params or matched:
                    # Small or excluded parameters are kept unquantized.
                    if id(param) in previous:
                        continue
                    if self.detect_bound:
                        previous[id(param)] = None
                    if self.float16:
                        self._float16.append(param)
                    else:
                        self._others.append(param)
                else:
                    qparam = self._register_param(name, param, module, previous.get(id(param)))
                    if self.detect_bound:
                        previous[id(param)] = qparam
                    self._qparams.append(qparam)

    def _register_param(self, name, param, module, other):
        return self.__class__._QuantizedParam(name, param, module, other)

    def _forward_pre_hook(self, module, input):
        if self.model.training:
            # Training: drop any cached quantized state and work on float weights.
            self._quantized_state = None
            if self._quantized:
                self.unquantize()
            if self._pre_forward_train():
                self._fix_rnns()
        else:
            # Evaluation: make sure the quantized weights are in place.
            self.quantize()

    def _forward_hook(self, module, input, output):
        if self.model.training:
            if self._post_forward_train():
                self._fix_rnns(flatten=False)  # Hacky, the next forward will flatten.

    def quantize(self, save=True):
        """
        Immediately apply quantization to the model parameters.
        If `save` is True, save a copy of the unquantized parameters which can be
        restored with `unquantize()`.
        """
        if self._quantized:
            return
        if save:
            self._saved = [qp.param.data.to('cpu', copy=True)
                           for qp in self._qparams if qp.other is None]
        self.restore_quantized_state(self.get_quantized_state())
        self._quantized = True
        self._fix_rnns()

    def unquantize(self):
        """
        Revert a previous call to `quantize()`.
        """
        if not self._quantized:
            raise RuntimeError("Can only be called on a quantized model.")
        if not self._saved:
            raise RuntimeError("Nothing to restore.")
        for qparam in self._qparams:
            if qparam.other is None:
                qparam.param.data[:] = self._saved.pop(0)
        assert len(self._saved) == 0
        self._quantized = False
        self._fix_rnns()

    def _pre_forward_train(self) -> bool:
        """
        Called once before each forward for continuous quantization.
        Should return True if parameters were changed.
        """
        return False

    def _post_forward_train(self) -> bool:
        """
        Called once after each forward (for instance, to restore state).
        Should return True if parameters were changed.
        """
        return False

    def _fix_rnns(self, flatten=True):
        """
        To be called after quantization has happened, to fix RNNs.
        """
        for rnn in self._rnns:
            rnn._flat_weights = [
                (lambda wn: getattr(rnn, wn) if hasattr(rnn, wn) else None)(wn)
                for wn in rnn._flat_weights_names]
            if flatten:
                rnn.flatten_parameters()

    def get_quantized_state(self):
        """
        Returns sufficient quantized information to rebuild the model state.

        .. note::
            To achieve maximum compression, you should compress this with
            gzip or similar, as quantized weights are not optimally coded!
        """
        if self._quantized_state is None:
            self._quantized_state = self._get_quantized_state()
        return self._quantized_state

    def _get_quantized_state(self):
        """
        Actual implementation for `get_quantized_state`.
        """
        float16_params = []
        for p in self._float16:
            q = p.data.half()
            float16_params.append(q)

        return {
            "quantized": [self._quantize_param(qparam) for qparam in self._qparams
                          if qparam.other is None],
            "float16": float16_params,
            "others": [p.data.clone() for p in self._others],
        }

    def _quantize_param(self, qparam: _QuantizedParam) -> tp.Any:
        """
        To be overridden.
        """
        raise NotImplementedError()

    def _unquantize_param(self, qparam: _QuantizedParam, quantized: tp.Any) -> torch.Tensor:
        """
        To be overridden.
        """
        raise NotImplementedError()

    def restore_quantized_state(self, state) -> None:
        """
        Restore the state of the model from the quantized state.
        """
        for p, q in zip(self._float16, state["float16"]):
            p.data[:] = q.to(p)

        for p, q in zip(self._others, state["others"]):
            p.data[:] = q

        remaining = list(state["quantized"])
        for qparam in self._qparams:
            if qparam.other is not None:
                # Only unquantize the first appearance of an nn.Parameter.
                continue
            quantized = remaining.pop(0)
            qparam.param.data[:] = self._unquantize_param(qparam, quantized)
        self._fix_rnns()

    def detach(self) -> None:
        """
        Detach from the model, removing the hooks and anything else.
        """
        self._pre_handle.remove()
        self._post_handle.remove()

    def model_size(self) -> torch.Tensor:
        """
        Returns an estimate of the quantized model size.
        """
        total = torch.tensor(0.)
        for p in self._float16:
            total += 16 * p.numel()
        for p in self._others:
            total += 32 * p.numel()
        return total / 2**20 / 8    # bits to MegaBytes

    def true_model_size(self) -> float:
        """
        Return the true quantized model size, in MB, without extra
        compression.
        """
        return self.model_size().item()

    def compressed_model_size(self, compress_level=-1, num_workers=8) -> float:
        """
        Return the compressed quantized model size, in MB.

        Args:
            compress_level (int): compression level used with zlib,
                see `zlib.compress` for details.
            num_workers (int): split the serialized byte representation into that
                many chunks, compressed in parallel.
        """
        out = io.BytesIO()
        torch.save(self.get_quantized_state(), out)
        ms = _parallel_compress_len(out.getvalue(), compress_level, num_workers)
        return ms / 2 ** 20

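# ---------------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: a minimal concrete quantizer
# built on BaseQuantizer. It only relies on the abstract hooks defined above
# (`_quantize_param` / `_unquantize_param`) and stores every quantized parameter as
# a float16 copy, which is enough to exercise quantize()/unquantize(), the exported
# state and the size estimates. The class name is hypothetical.
class _ExampleHalfQuantizer(BaseQuantizer):
    def _quantize_param(self, qparam: BaseQuantizer._QuantizedParam) -> tp.Any:
        # Keep a float16 copy as the "quantized" representation.
        return qparam.param.data.half()

    def _unquantize_param(self, qparam: BaseQuantizer._QuantizedParam,
                          quantized: tp.Any) -> torch.Tensor:
        # Cast back to the parameter's original dtype when restoring.
        return quantized.to(qparam.param.data.dtype)

    def model_size(self) -> torch.Tensor:
        # The base implementation only counts float16/other parameters; add
        # 16 bits per quantized weight so the estimate covers the whole model.
        total = super().model_size()
        for qparam in self._qparams:
            if qparam.other is None:
                total += 16 * qparam.param.numel() / 2**20 / 8
        return total
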
def _compress_len(data, compress_level):
    # Length in bytes of the zlib-compressed chunk.
    return len(zlib.compress(data, level=compress_level))


def _parallel_compress_len(data, compress_level, num_workers):
    # Split `data` into roughly equal chunks and compress them in parallel,
    # returning the total compressed length.
    num_workers = min(cpu_count(), num_workers)
    chunk_size = int(math.ceil(len(data) / num_workers))
    chunks = [data[offset:offset + chunk_size] for offset in range(0, len(data), chunk_size)]
    with futures.ProcessPoolExecutor(num_workers) as pool:
        return sum(pool.map(partial(_compress_len, compress_level=compress_level), chunks))
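
# Hedged usage sketch (illustrative only; the model, layer sizes and worker count
# are arbitrary choices, not values prescribed by this file). Guarded by __main__
# because compressed_model_size() spawns a process pool.
if __name__ == "__main__":
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024),
        torch.nn.ReLU(),
        torch.nn.Linear(1024, 16),
    )
    quantizer = _ExampleHalfQuantizer(model)

    quantizer.quantize()    # swap in the float16 versions, saving the originals
    print(f"estimated size: {quantizer.true_model_size():.3f} MB")
    print(f"compressed size: {quantizer.compressed_model_size(num_workers=2):.3f} MB")
    quantizer.unquantize()  # restore the saved full-precision parameters
    quantizer.detach()      # remove the forward hooks once done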