aliabd
full working demo
d5175d3
raw
history blame
7.33 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import random
from collections import Counter
import torch
class EM:
    """
    EM algorithm used to quantize the columns of W to minimize
    ||W - W_hat||^2

    Args:
        - W: weight matrix of size (in_features x out_features)
        - n_centroids: number of centroids (size of codebook)
        - n_iter: number of k-means iterations (stored on the instance; it is
          not read by any method of this class, the caller drives step())
        - eps: scale of the random perturbation applied when a centroid is
          split to fill an empty cluster
        - max_tentatives: retry budget for cluster reassignment when an empty
          cluster is found
        - verbose: log the objective after each iteration

    Remarks:
        - If one cluster is empty, the most populated cluster is split into
          two clusters
        - All the relevant dimensions are specified in the code
    """

    def __init__(
        self, W, n_centroids=256, n_iter=20, eps=1e-6, max_tentatives=30, verbose=True
    ):
        self.W = W
        self.n_centroids = n_centroids
        self.n_iter = n_iter
        self.eps = eps
        self.max_tentatives = max_tentatives
        self.verbose = verbose
        # Placeholders until initialize_centroids() / step() / load() fill them.
        self.centroids = torch.Tensor()
        self.assignments = torch.Tensor()
        # One objective value is appended per call to step().
        self.objective = []

    def initialize_centroids(self):
        """
        Initializes the centroids by sampling random columns from W.

        Columns are drawn with replacement (torch.randint), so the same column
        may be picked for several centroids.
        """
        # Unpacking also enforces that W is two-dimensional.
        _, out_features = self.W.size()
        indices = torch.randint(low=0, high=out_features, size=(self.n_centroids,))
        self.centroids = self.W[:, indices].t()  # (n_centroids x in_features)

    def step(self, i):
        """
        There are two standard steps for each iteration: expectation (E) and
        minimization (M). The E-step (assignment) is performed with an exhaustive
        search and the M-step (centroid computation) is performed with
        the exact solution.

        Args:
            - i: step number (only used in the log message)

        Raises:
            - EmptyClusterResolveError: if empty clusters could not be resolved

        Remarks:
            - The E-step heavily uses PyTorch broadcasting to speed up computations
              and reduce the memory overhead
        """
        # assignments (E-step)
        distances = self.compute_distances()  # (n_centroids x out_features)
        self.assignments = torch.argmin(distances, dim=0)  # (out_features)
        n_empty_clusters = self.resolve_empty_clusters()

        # centroids (M-step): each centroid becomes the mean of its columns
        for k in range(self.n_centroids):
            W_k = self.W[:, self.assignments == k]  # (in_features x size_of_cluster_k)
            self.centroids[k] = W_k.mean(dim=1)  # (in_features)

        # book-keeping: norm of the reconstruction error (W_hat - W)
        obj = (self.centroids[self.assignments].t() - self.W).norm(p=2).item()
        self.objective.append(obj)
        if self.verbose:
            logging.info(
                f"Iteration: {i},\t"
                f"objective: {obj:.6f},\t"
                f"resolved empty clusters: {n_empty_clusters}"
            )

    def resolve_empty_clusters(self):
        """
        If one cluster is empty, the most populated cluster is split into
        two clusters by shifting the respective centroids. This is done
        iteratively for a fixed number of tentatives.

        Returns:
            - the number of empty clusters found before any splitting

        Raises:
            - EmptyClusterResolveError: if empty clusters remain once the
              retry budget (max_tentatives) is exhausted
        """
        # empty clusters
        counts = Counter(self.assignments.tolist())
        empty_clusters = set(range(self.n_centroids)) - set(counts.keys())
        n_empty_clusters = len(empty_clusters)

        tentatives = 0
        while len(empty_clusters) > 0:
            # given an empty cluster, find most populated cluster and split it into two
            k = random.choice(list(empty_clusters))
            m = counts.most_common(1)[0][0]
            e = torch.randn_like(self.centroids[m]) * self.eps
            self.centroids[k] = self.centroids[m].clone()
            self.centroids[k] += e
            self.centroids[m] -= e

            # recompute assignments
            distances = self.compute_distances()  # (n_centroids x out_features)
            self.assignments = torch.argmin(distances, dim=0)  # (out_features)

            # check for empty clusters
            counts = Counter(self.assignments.tolist())
            empty_clusters = set(range(self.n_centroids)) - set(counts.keys())

            # Give up only if something is still empty: an attempt that just
            # resolved every cluster must not be reported as a failure.
            if empty_clusters and tentatives == self.max_tentatives:
                logging.info(
                    f"Could not resolve all empty clusters, {len(empty_clusters)} remaining"
                )
                raise EmptyClusterResolveError
            tentatives += 1

        return n_empty_clusters

    def compute_distances(self):
        """
        For every centroid m, computes

        ||M - m[None, :]||_2

        Returns:
            - tensor of size (n_centroids x out_features) holding the L2
              distance between every centroid and every column of W

        Raises:
            - RuntimeError: if the computation still fails once the centroids
              cannot be split into smaller chunks

        Remarks:
            - We rely on PyTorch's broadcasting to speed up computations
              and reduce the memory overhead
            - Without chunking, the sizes in the broadcasting are modified as:
              (n_centroids x in_features x out_features) -> (n_centroids x out_features)
            - On RuntimeError the number of centroid chunks is doubled and the
              computation retried, so that smaller intermediate tensors are
              materialized at once
        """
        nb_centroids_chunks = 1

        while True:
            try:
                return torch.cat(
                    [
                        (self.W[None, :, :] - centroids_c[:, :, None]).norm(p=2, dim=1)
                        for centroids_c in self.centroids.chunk(
                            nb_centroids_chunks, dim=0
                        )
                    ],
                    dim=0,
                )
            except RuntimeError:
                # Once there are at least as many chunks as centroids, each
                # chunk already holds a single centroid; more chunks cannot
                # help, so surface the error instead of retrying forever
                # (e.g. on a shape mismatch rather than a memory error).
                if nb_centroids_chunks >= self.centroids.size(0):
                    raise
                nb_centroids_chunks *= 2

    def assign(self):
        """
        Assigns each column of W to its closest centroid, i.e. performs the
        same E-step as step() but without resolving empty clusters and
        without updating the centroids.

        Remarks:
            - The centroids must have been set beforehand (initialize_centroids(),
              step() or load()); they are an empty tensor right after construction
        """
        distances = self.compute_distances()  # (n_centroids x out_features)
        self.assignments = torch.argmin(distances, dim=0)  # (out_features)

    def save(self, path, layer):
        """
        Saves centroids, assignments and the objective history.

        Args:
            - path: folder used to save centroids and assignments
            - layer: prefix of the file names
              ({layer}_centroids.pth, {layer}_assignments.pth, {layer}_objective.pth)
        """
        torch.save(self.centroids, os.path.join(path, "{}_centroids.pth".format(layer)))
        torch.save(
            self.assignments, os.path.join(path, "{}_assignments.pth".format(layer))
        )
        torch.save(self.objective, os.path.join(path, "{}_objective.pth".format(layer)))

    def load(self, path, layer):
        """
        Loads centroids, assignments and the objective history from a given path

        Args:
            - path: folder use to load centroids and assignments
            - layer: prefix of the file names, as passed to save()
        """
        # NOTE(review): torch.load is pickle-based; only load files from a
        # trusted source.
        self.centroids = torch.load(
            os.path.join(path, "{}_centroids.pth".format(layer))
        )
        self.assignments = torch.load(
            os.path.join(path, "{}_assignments.pth".format(layer))
        )
        self.objective = torch.load(
            os.path.join(path, "{}_objective.pth".format(layer))
        )
class EmptyClusterResolveError(Exception):
    """Raised when empty clusters remain after the allowed number of tentatives."""