# (web-scrape residue removed: file size 4,292 bytes, revision d5175d3)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .em import EM, EmptyClusterResolveError
class PQ(EM):
    """
    Quantizes the layer weights W with the standard Product Quantization
    technique. This learns a codebook of codewords (centroids), each of size
    block_size, from W. For further reference on using PQ to quantize
    neural networks, see "And the Bit Goes Down: Revisiting the Quantization
    of Neural Networks", Stock et al., ICLR 2020.

    PQ is performed in two steps:
    (1) The matrix W (weights of a fully-connected or convolutional layer)
        is reshaped to (block_size, -1).
            - If W is fully-connected (2D), each row (of length in_features)
              is split into contiguous blocks of size block_size.
            - If W is convolutional (4D), each filter, flattened over
              in_channels x k_h x k_w, is split into contiguous blocks of
              size block_size.
    (2) We apply the standard EM/k-means algorithm to the resulting reshaped
        matrix, treating each column (one block) as a sample.

    Args:
        - W: weight tensor to quantize, of size (out_features x in_features)
          for a fully-connected layer, or
          (out_channels x in_channels x k_h x k_w) for a convolutional layer
        - block_size: size of the blocks (subvectors)
        - n_centroids: number of centroids (codebook size)
        - n_iter: number of k-means iterations
        - eps: passed to EM; used for cluster reassignment when an empty
          cluster is found
        - max_tentatives: passed to EM; maximum number of attempts at cluster
          reassignment when an empty cluster is found
        - verbose: print information after each iteration

    Remarks:
        - block_size must divide in_features (2D) or
          in_channels * k_h * k_w (4D); otherwise an AssertionError is raised.
        - Tensors of any other rank raise NotImplementedError.
    """

    def __init__(
        self,
        W,
        block_size,
        n_centroids=256,
        n_iter=20,
        eps=1e-6,
        max_tentatives=30,
        verbose=True,
    ):
        # block_size must be set before _reshape, which reads it.
        self.block_size = block_size
        W_reshaped = self._reshape(W)
        super(PQ, self).__init__(
            W_reshaped,
            n_centroids=n_centroids,
            n_iter=n_iter,
            eps=eps,
            max_tentatives=max_tentatives,
            verbose=verbose,
        )

    def _reshape(self, W):
        """
        Reshapes the tensor W as explained in step (1).

        Also records the original dimensions on self (out_features /
        in_features for 2D, out_channels / in_channels / k_h / k_w for 4D)
        so that decode() can invert the transformation.

        Returns:
            a tensor of shape (block_size, n_blocks * n_out), where n_out is
            out_features (2D) or out_channels (4D).

        Raises:
            AssertionError: if block_size does not divide the per-output
                dimension.
            NotImplementedError: if W is neither 2D nor 4D.
        """
        n_dims = len(W.size())

        # fully connected: by convention the weight has size out_features x in_features
        if n_dims == 2:
            self.out_features, self.in_features = W.size()
            assert self.in_features % self.block_size == 0, (
                "Linear: in_features ({}) must be a multiple of "
                "block_size ({})".format(self.in_features, self.block_size)
            )
            return self._split_into_blocks(W, self.out_features)

        # convolutional: each filter is flattened over (in_channels, k_h, k_w)
        # and then split into blocks
        if n_dims == 4:
            self.out_channels, self.in_channels, self.k_h, self.k_w = W.size()
            filter_size = self.in_channels * self.k_h * self.k_w
            assert filter_size % self.block_size == 0, (
                "Conv2d: in_channels * k_h * k_w ({}) must be a multiple of "
                "block_size ({})".format(filter_size, self.block_size)
            )
            return self._split_into_blocks(W, self.out_channels)

        # not implemented
        raise NotImplementedError(W.size())

    def _split_into_blocks(self, W, n_out):
        """
        Splits each of the n_out leading slices of W into blocks of
        self.block_size and lays them out as columns.

        Returns a tensor of shape (block_size, n_blocks * n_out) whose column
        b * n_out + o holds block b of output unit o.
        """
        return (
            W.reshape(n_out, -1, self.block_size)  # (n_out, n_blocks, bs)
            .permute(2, 1, 0)  # (bs, n_blocks, n_out)
            .flatten(1, 2)  # (bs, n_blocks * n_out)
        )

    def encode(self):
        """
        Initializes the centroids, then performs up to self.n_iter EM steps.

        If a step raises EmptyClusterResolveError, iteration stops early and
        the state left by the previous steps is kept.
        """
        self.initialize_centroids()
        for i in range(self.n_iter):
            try:
                self.step(i)
            except EmptyClusterResolveError:
                # Deliberate best effort: keep the last successful state.
                break

    def decode(self):
        """
        Returns the encoded full weight tensor, with the same shape as the W
        passed to the constructor. Must be called after the encode function.
        """
        # fully connected case: only the 4D branch of _reshape sets k_h
        if "k_h" not in self.__dict__:
            return self._merge_blocks(self.out_features)
        # convolutional case
        return self._merge_blocks(self.out_channels).reshape(
            self.out_channels, self.in_channels, self.k_h, self.k_w
        )

    def _merge_blocks(self, n_out):
        """
        Inverse of _split_into_blocks applied to the quantized blocks: looks
        up each assigned centroid and reassembles them into a tensor of shape
        (n_out, n_blocks * block_size).
        """
        return (
            self.centroids[self.assignments]  # (n_blocks * n_out, bs)
            .reshape(-1, n_out, self.block_size)  # (n_blocks, n_out, bs)
            .permute(1, 0, 2)  # (n_out, n_blocks, bs)
            .flatten(1, 2)  # (n_out, n_blocks * bs)
        )
|