|
|
|
|
|
|
|
|
|
|
|
from .em import EM, EmptyClusterResolveError |
|
|
|
|
|
class PQ(EM):
    """
    Quantizes the layer weights W with the standard Product Quantization
    technique. This learns a codebook of codewords or centroids of size
    block_size from W. For further reference on using PQ to quantize
    neural networks, see "And the Bit Goes Down: Revisiting the Quantization
    of Neural Networks", Stock et al., ICLR 2020.

    PQ is performed in two steps:
    (1) The matrix W (weights of a fully-connected or convolutional layer)
        is reshaped to (block_size, -1).
            - If W is fully-connected (2D, out_features x in_features),
              each row is split along the in_features axis into
              consecutive blocks of size block_size.
            - If W is convolutional (4D, out_channels x in_channels x
              k_h x k_w), each filter is flattened and split into
              consecutive blocks of size block_size.
    (2) We apply the standard EM/k-means algorithm to the resulting
        reshaped matrix.

    Args:
        - W: weight matrix to quantize, of size
          (out_features x in_features) for a fully-connected layer or
          (out_channels x in_channels x k_h x k_w) for a convolution
        - block_size: size of the blocks (subvectors)
        - n_centroids: number of centroids
        - n_iter: number of k-means iterations
        - eps: for cluster reassignment when an empty cluster is found
        - max_tentatives: for cluster reassignment when an empty cluster
          is found
        - verbose: print information after each iteration

    Remarks:
        - block_size must be compatible with the shape of W: it has to
          divide in_features (2D case) or in_channels * k_h * k_w
          (4D case).
    """

    def __init__(
        self,
        W,
        block_size,
        n_centroids=256,
        n_iter=20,
        eps=1e-6,
        max_tentatives=30,
        verbose=True,
    ):
        # block_size must be set before _reshape, which reads it.
        self.block_size = block_size
        W_reshaped = self._reshape(W)
        super(PQ, self).__init__(
            W_reshaped,
            n_centroids=n_centroids,
            n_iter=n_iter,
            eps=eps,
            max_tentatives=max_tentatives,
            verbose=verbose,
        )

    def _reshape(self, W):
        """
        Reshapes the matrix W as explained in step (1) of the class
        docstring and records the original dimensions on the instance
        (out_features / in_features for 2D, out_channels / in_channels /
        k_h / k_w for 4D) so that decode() can restore the layout.

        Returns a tensor of size (block_size, n_blocks * n_out) where
        n_out is the first dimension of W.

        Raises:
            - AssertionError if block_size does not divide the flattened
              input dimension of W.
            - NotImplementedError if W is neither 2D nor 4D.
        """
        n_dims = len(W.size())

        # NOTE: validation is kept as `assert` so that callers relying on
        # AssertionError keep working; be aware it is skipped under -O.
        if n_dims == 2:
            # fully connected: the weight has size out_features x in_features
            self.out_features, self.in_features = W.size()
            assert (
                self.in_features % self.block_size == 0
            ), "Linear: in_features must be a multiple of block_size"
            n_out = self.out_features
        elif n_dims == 4:
            # convolutional: every filter is flattened before being split
            self.out_channels, self.in_channels, self.k_h, self.k_w = W.size()
            assert (
                self.in_channels * self.k_h * self.k_w
            ) % self.block_size == 0, (
                "Conv2d: in_channels * k_h * k_w must be a multiple of block_size"
            )
            n_out = self.out_channels
        else:
            raise NotImplementedError(W.size())

        # (n_out, n_blocks, block_size) -> (block_size, n_blocks, n_out)
        # -> (block_size, n_blocks * n_out): one subvector per column.
        return (
            W.reshape(n_out, -1, self.block_size)
            .permute(2, 1, 0)
            .flatten(1, 2)
        )

    def encode(self):
        """
        Initializes the centroids, then performs up to self.n_iter EM
        steps. Iteration stops early, without raising, as soon as a step
        raises EmptyClusterResolveError.
        """
        self.initialize_centroids()
        for i in range(self.n_iter):
            try:
                self.step(i)
            except EmptyClusterResolveError:
                # Best effort: keep the state reached so far.
                break

    def decode(self):
        """
        Returns the encoded full weight matrix, in the original layout of
        W. Must be called after the encode function.

        Relies on self.centroids and self.assignments, which are not set
        in this class (presumably by the EM base class during encode).
        """
        # One centroid per subvector, in the column order of _reshape.
        quantized_blocks = self.centroids[self.assignments]

        # k_h is only recorded by _reshape for 4D (convolutional) weights.
        if "k_h" not in self.__dict__:
            # fully connected case
            return (
                quantized_blocks.reshape(-1, self.out_features, self.block_size)
                .permute(1, 0, 2)
                .flatten(1, 2)
            )

        # convolutional case
        return (
            quantized_blocks.reshape(-1, self.out_channels, self.block_size)
            .permute(1, 0, 2)
            .reshape(self.out_channels, self.in_channels, self.k_h, self.k_w)
        )
|
|