# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import random
from collections import Counter

import torch


class EM:
    """
    EM algorithm used to quantize the columns of W to minimize

                         ||W - W_hat||^2

    Args:
        - W: weight matrix of size (in_features x out_features)
        - n_centroids: number of centroids (size of codebook)
        - n_iter: number of k-means iterations (stored on the instance; no
          method of this class reads it, the caller drives the loop by
          calling step() repeatedly)
        - eps: scale of the random perturbation used to split a centroid
          when an empty cluster is found
        - max_tentatives: number of extra splitting attempts allowed when
          empty clusters are found (max_tentatives + 1 attempts in total)
        - verbose: log the objective after each iteration

    Remarks:
        - If one cluster is empty, the most populated cluster is split into
          two clusters
        - All the relevant dimensions are specified in the code
    """

    def __init__(
        self,
        W: torch.Tensor,
        n_centroids: int = 256,
        n_iter: int = 20,
        eps: float = 1e-6,
        max_tentatives: int = 30,
        verbose: bool = True,
    ) -> None:
        self.W = W
        self.n_centroids = n_centroids
        self.n_iter = n_iter
        self.eps = eps
        self.max_tentatives = max_tentatives
        self.verbose = verbose
        # Empty placeholders until initialize_centroids()/step()/load() run.
        self.centroids = torch.Tensor()
        self.assignments = torch.Tensor()
        # One objective value is appended per call to step().
        self.objective = []

    def initialize_centroids(self) -> None:
        """
        Initializes the centroids by sampling random columns from W
        (with replacement, so duplicated centroids are possible).
        """
        _, out_features = self.W.size()
        indices = torch.randint(
            low=0, high=out_features, size=(self.n_centroids,)
        )
        # Advanced indexing copies, so later in-place centroid updates do
        # not write through to W.
        self.centroids = self.W[:, indices].t()  # (n_centroids x in_features)

    def step(self, i) -> None:
        """
        There are two standard steps for each iteration: expectation (E) and
        minimization (M). The E-step (assignment) is performed with an
        exhaustive search and the M-step (centroid computation) is performed
        with the exact solution.

        Args:
            - i: step number (only used in the log message)

        Raises:
            - EmptyClusterResolveError: if empty clusters could not be
              resolved (see resolve_empty_clusters)

        Remarks:
            - The E-step heavily uses PyTorch broadcasting to speed up
              computations and reduce the memory overhead
        """
        # assignments (E-step)
        self.assign()
        n_empty_clusters = self.resolve_empty_clusters()

        # centroids (M-step): each centroid becomes the mean of its columns
        for k in range(self.n_centroids):
            W_k = self.W[:, self.assignments == k]  # (in_features x size_of_cluster_k)
            self.centroids[k] = W_k.mean(dim=1)  # (in_features)

        # book-keeping: norm of the reconstruction error over all of W
        obj = (self.centroids[self.assignments].t() - self.W).norm(p=2).item()
        self.objective.append(obj)
        if self.verbose:
            logging.info(
                f"Iteration: {i},\t"
                f"objective: {obj:.6f},\t"
                f"resolved empty clusters: {n_empty_clusters}"
            )

    def _find_empty_clusters(self):
        """
        Returns (counts, empty_clusters): the Counter of cluster sizes for the
        current assignments and the set of cluster ids that own no column.
        """
        # tolist() converts in one call instead of one .item() per element.
        counts = Counter(self.assignments.tolist())
        empty_clusters = set(range(self.n_centroids)) - set(counts)
        return counts, empty_clusters

    def resolve_empty_clusters(self) -> int:
        """
        If one cluster is empty, the most populated cluster is split into
        two clusters by shifting the respective centroids. This is done
        iteratively for a fixed number of tentatives.

        Returns:
            - the number of empty clusters found before any splitting

        Raises:
            - EmptyClusterResolveError: if empty clusters remain after
              max_tentatives + 1 splitting attempts
        """
        counts, empty_clusters = self._find_empty_clusters()
        n_empty_clusters = len(empty_clusters)

        tentatives = 0
        while empty_clusters:
            # given an empty cluster, find most populated cluster and split it into two
            k = random.choice(list(empty_clusters))
            m = counts.most_common(1)[0][0]
            e = torch.randn_like(self.centroids[m]) * self.eps
            self.centroids[k] = self.centroids[m].clone()
            self.centroids[k] += e
            self.centroids[m] -= e

            # recompute assignments and check for empty clusters again
            self.assign()
            counts, empty_clusters = self._find_empty_clusters()

            # Give up only if something is still empty: a split that succeeds
            # on the very last allowed attempt must not be reported as a failure.
            if empty_clusters and tentatives == self.max_tentatives:
                logging.info(
                    f"Could not resolve all empty clusters, {len(empty_clusters)} remaining"
                )
                raise EmptyClusterResolveError
            tentatives += 1

        return n_empty_clusters

    def compute_distances(self) -> torch.Tensor:
        """
        For every centroid m and every column w of W, computes ||w - m||_2.

        Returns:
            - distances of size (n_centroids x out_features)

        Raises:
            - RuntimeError: if the computation still fails once the centroids
              cannot be chunked any further (e.g. incompatible shapes, or not
              enough memory even for a single centroid per chunk)

        Remarks:
            - We rely on PyTorch's broadcasting to speed up computations
              and reduce the memory overhead
            - Without chunking, the sizes in the broadcasting are modified as:
              (n_centroids x in_features x out_features) -> (n_centroids x out_features)
            - The broadcasting computation is automatically chunked so that
              the tensors fit into the memory of the GPU
        """
        nb_centroids_chunks = 1

        while True:
            try:
                return torch.cat(
                    [
                        (self.W[None, :, :] - centroids_c[:, :, None]).norm(p=2, dim=1)
                        for centroids_c in self.centroids.chunk(
                            nb_centroids_chunks, dim=0
                        )
                    ],
                    dim=0,
                )
            except RuntimeError:
                # chunk() never yields more chunks than there are centroids,
                # so past that point retrying repeats the same computation:
                # surface the error instead of looping forever.
                n_rows = self.centroids.size(0) if self.centroids.dim() > 0 else 0
                if nb_centroids_chunks >= n_rows:
                    raise
                nb_centroids_chunks *= 2

    def assign(self) -> None:
        """
        Assigns each column of W to its closest centroid, i.e. performs the
        E-step of step() on its own. The result is stored in
        self.assignments; nothing is returned.

        Remarks:
            - The centroids must be populated first (by
              initialize_centroids()/step() or by load())
        """
        distances = self.compute_distances()  # (n_centroids x out_features)
        self.assignments = torch.argmin(distances, dim=0)  # (out_features)

    def save(self, path, layer) -> None:
        """
        Saves centroids, assignments and the objective history.

        Args:
            - path: folder used to save centroids and assignments
            - layer: prefix of the file names, which are
              "<layer>_centroids.pth", "<layer>_assignments.pth" and
              "<layer>_objective.pth"
        """
        torch.save(self.centroids, os.path.join(path, "{}_centroids.pth".format(layer)))
        torch.save(
            self.assignments, os.path.join(path, "{}_assignments.pth".format(layer))
        )
        torch.save(self.objective, os.path.join(path, "{}_objective.pth".format(layer)))

    def load(self, path, layer) -> None:
        """
        Loads centroids, assignments and the objective history from a given path.

        Args:
            - path: folder used to load centroids and assignments
            - layer: prefix of the file names (same convention as save())
        """
        # NOTE: torch.load relies on pickle; only load files from a trusted source.
        self.centroids = torch.load(
            os.path.join(path, "{}_centroids.pth".format(layer))
        )
        self.assignments = torch.load(
            os.path.join(path, "{}_assignments.pth".format(layer))
        )
        self.objective = torch.load(
            os.path.join(path, "{}_objective.pth".format(layer))
        )


class EmptyClusterResolveError(Exception):
    """Raised when empty clusters remain after all splitting attempts."""