|
from fnmatch import fnmatch |
|
from typing import Any, Dict, List, Optional, Union |
|
import torch |
|
from dataclasses import dataclass |
|
|
|
from optimum.quanto.quantize import _quantize_submodule |
|
from optimum.quanto.tensor import Optimizer, qtype, qtypes |
|
from torchao.quantization.quant_api import ( |
|
quantize_ as torchao_quantize_, |
|
Float8WeightOnlyConfig, |
|
UIntXWeightOnlyConfig |
|
) |
|
|
|
|
|
|
|
# Class names that mark a module as already quantized: `quantize` compares
# `m.__class__.__name__` against this list and skips any module found in it.
Q_MODULES = [
    'QLinear',
    'QConv2d',
    'QEmbedding',
    'QBatchNorm2d',
    'QLayerNorm',
    'QConvTranspose2d',
    'QEmbeddingBag',
]
|
|
|
# Registry of the weight-type names handled through torchao, mapped to the
# torchao weight-only config object used for each of them.
# Every "uintN" name is paired with the torch dtype of the same name.
torchao_qtypes = {
    dtype_name: UIntXWeightOnlyConfig(getattr(torch, dtype_name))
    for dtype_name in ("uint2", "uint3", "uint4", "uint5", "uint6", "uint7", "uint8")
}
# float8 uses its own config class and takes no dtype argument.
torchao_qtypes["float8"] = Float8WeightOnlyConfig()
|
|
|
class aotype:
    """Pair a torchao weight-type name with its registered config.

    `quantize` checks `isinstance(weights, aotype)` to decide whether a module
    is handed to torchao (using `config`) instead of the quanto code path.

    Args:
        name (`str`): a key of `torchao_qtypes`; any other value raises
            `KeyError` from the registry lookup.
    """

    def __init__(self, name: str) -> None:
        # Name the type was requested under.
        self.name = name
        # Config object registered for that name.
        registered_config = torchao_qtypes[name]
        self.config = registered_config
|
|
|
def get_qtype(qtype: Union[str, qtype, aotype]) -> Union[qtype, aotype]:
    """Resolve a quantization type name into a type object.

    Args:
        qtype: either a type name or an already-resolved type object.

    Returns:
        - the input itself when it is not a string (nothing to resolve);
        - an `aotype` when the name is a key of `torchao_qtypes`;
        - otherwise the entry of `qtypes` stored under that name.

    Raises:
        KeyError: presumably, when the name is in neither registry (raised by
            the `qtypes[...]` lookup, assuming `qtypes` is a plain mapping).
    """
    # Only names need resolving. Returning other objects first also avoids
    # hashing arbitrary (possibly unhashable) inputs in the membership test.
    if not isinstance(qtype, str):
        return qtype
    # torchao names are checked first, so they win over the `qtypes` registry.
    if qtype in torchao_qtypes:
        return aotype(qtype)
    return qtypes[qtype]
|
|
|
def quantize(
    model: torch.nn.Module,
    weights: Optional[Union[str, qtype, aotype]] = None,
    activations: Optional[Union[str, qtype]] = None,
    optimizer: Optional[Optimizer] = None,
    include: Optional[Union[str, List[str]]] = None,
    exclude: Optional[Union[str, List[str]]] = None,
):
    """Quantize the specified model submodules

    Recursively quantize the submodules of the specified parent model.

    Only modules that have quantized counterparts will be quantized.

    Modules whose class name is listed in `Q_MODULES` are considered already
    quantized and are skipped.

    If include patterns are specified, the submodule name must match one of them.

    If exclude patterns are specified, the submodule must not match one of them.

    Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
    https://docs.python.org/3/library/fnmatch.html for more details.

    Note: quantization happens in-place and modifies the original model and its descendants.

    Args:
        model (`torch.nn.Module`): the model whose submodules will be quantized.
        weights (`Optional[Union[str, qtype, aotype]]`): the qtype for weights quantization.
            A string is resolved with `get_qtype` before the submodules are visited,
            so an unknown name fails in that lookup. An `aotype` routes the
            submodules to torchao; anything else goes through `_quantize_submodule`.
        activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
            Only forwarded to `_quantize_submodule`; not used on the torchao path.
        optimizer (`Optional[Optimizer]`): forwarded unchanged to `_quantize_submodule`;
            not used on the torchao path.
        include (`Optional[Union[str, List[str]]]`):
            Patterns constituting the allowlist. If provided, module names must match at
            least one pattern from the allowlist.
        exclude (`Optional[Union[str, List[str]]]`):
            Patterns constituting the denylist. If provided, module names must not match
            any patterns from the denylist.

    Raises:
        Exception: any error raised while quantizing a submodule is printed together
            with the submodule name and then re-raised unchanged.
    """
    # The backend is picked with `isinstance(weights, aotype)` below, so a name
    # such as "uint4" has to be resolved first or it would never reach torchao.
    if isinstance(weights, str):
        weights = get_qtype(weights)
    # A single pattern is accepted as shorthand for a one-element list.
    if isinstance(include, str):
        include = [include]
    if isinstance(exclude, str):
        exclude = [exclude]
    for name, m in model.named_modules():
        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
            continue
        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
            continue
        # Already-quantized modules are left untouched.
        if m.__class__.__name__ in Q_MODULES:
            continue
        try:
            if isinstance(weights, aotype):
                # NOTE(review): torchao's quantize_ presumably walks the module it is
                # given, so calling it on a container (including the root, named "")
                # may also quantize descendants that include/exclude would have
                # skipped - confirm against the torchao documentation.
                torchao_quantize_(m, weights.config)
            else:
                _quantize_submodule(
                    model,
                    name,
                    m,
                    weights=weights,
                    activations=activations,
                    optimizer=optimizer,
                )
        except Exception as e:
            # Report which submodule failed, then propagate the original error.
            print(f"Failed to quantize {name}: {e}")
            raise