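# NOTE(review): the three lines below look like web-page status text captured
# together with the code, not program text - confirm and remove.
# Spaces:
# Build error
# Build error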
import argparse
import json
import pdb
import time
from collections import OrderedDict
from typing import Any, Dict, Iterable, List

import numpy as np
import torch
import torch.nn.functional as F
from scipy.spatial.distance import cosine
class TWCClustering:
    """Greedy clustering of embeddings by thresholded cosine similarity.

    ``cluster`` builds the pairwise cosine-similarity matrix, derives a cut-off
    of ``mean + threshold * std`` from the matrix values, and walks the items
    in order. Every item that has not been assigned yet becomes a pivot; the
    items at or after it whose similarity to the pivot reaches the cut-off are
    gathered, and the gathered item with the largest summed similarity to the
    others is reported as the centre of that cluster.
    """

    def __init__(self):
        print("In Zscore Clustering")

    def compute_matrix(self, embeddings) -> np.ndarray:
        """Return the pairwise cosine-similarity matrix of ``embeddings``.

        Args:
            embeddings: Array-like of vectors, one per row (converted with
                ``np.array``).

        Returns:
            ``np.inner`` of the L2-normalised rows with themselves. All-zero
            rows are left as zeros, so their similarity to every row
            (themselves included) is 0 instead of NaN.
        """
        print("Computing similarity matrix ...)")
        embeddings = np.array(embeddings)
        start = time.time()
        vec_a = embeddings.T
        norms = np.linalg.norm(vec_a, axis=0)  # one norm per original row
        # A zero vector has no direction. Dividing by its zero norm would put
        # NaN into the matrix and poison the mean/std computed in ``cluster``.
        # Adding the boolean mask turns a 0 norm into 1 and leaves every other
        # norm (and the dtype) untouched.
        norms = norms + (norms == 0)
        vec_a = (vec_a / norms).T
        similarity_matrix = np.inner(vec_a, vec_a)
        end = time.time()
        time_val = (end - start) * 1000
        print(f"Similarity matrix computation complete. Time taken:{(time_val/(1000*60)):.2f} minutes")
        return similarity_matrix

    def get_terms_above_threshold(
        self, matrix, embeddings, pivot_index: int, threshold: float
    ) -> List[Dict[str, int]]:
        """Collect the items at or after the pivot that are similar enough to it.

        Args:
            matrix: Similarity matrix indexable as ``matrix[i][j]``.
            embeddings: Only its length is used, as the upper index bound.
            pivot_index: Row to scan; the scan starts at this column.
            threshold: Minimum similarity (inclusive) for an item to be kept.

        Returns:
            A list of ``{"index": j}`` dicts in ascending ``j``.
        """
        total = len(embeddings)
        if pivot_index >= total:
            # Nothing to scan; also avoids touching a row that may not exist.
            return []
        pivot_row = matrix[pivot_index]
        return [
            {"index": candidate}
            for candidate in range(pivot_index, total)
            if pivot_row[candidate] >= threshold
        ]

    def update_picked_dict(self, picked_dict: Dict[Any, int], in_dict: Iterable[Any]) -> None:
        """Mark every key of ``in_dict`` as picked (value ``1``) in ``picked_dict``."""
        picked_dict.update(dict.fromkeys(in_dict, 1))

    def find_pivot_subgraph(
        self, pivot_index: int, arr: List[Dict[str, int]], matrix, threshold: float
    ) -> Dict[str, Any]:
        """Pick the best-connected candidate of ``arr`` as the cluster centre.

        For each candidate, the similarities to all candidates (itself
        included) that are not below ``threshold`` are summed. The candidate
        with the strictly largest sum wins; on ties the earliest one is kept.

        Args:
            pivot_index: Index of the item that seeded this cluster.
            arr: Candidates as ``{"index": j}`` dicts.
            matrix: Similarity matrix indexable as ``matrix[i][j]``.
            threshold: Similarities below this value are ignored.

        Returns:
            ``{"pivot_index": centre, "orig_index": pivot_index, "neighs": d}``
            where ``d`` is an ``OrderedDict`` mapping the centre's neighbours
            to their similarity, sorted from most to least similar. When no
            candidate scores above 0, the centre is ``pivot_index`` and
            ``neighs`` is empty.
        """
        candidates = [entry["index"] for entry in arr]
        center_index = pivot_index
        center_score = 0
        center_dict = {}
        for node_i in candidates:
            row = matrix[node_i]
            running_score = 0
            neighbours = {}
            for node_j in candidates:
                similarity = row[node_j]
                if similarity < threshold:
                    continue
                running_score += similarity
                neighbours[node_j] = similarity
            if running_score > center_score:
                center_index = node_i
                center_dict = neighbours
                center_score = running_score
        sorted_d = OrderedDict(sorted(center_dict.items(), key=lambda kv: kv[1], reverse=True))
        return {"pivot_index": center_index, "orig_index": pivot_index, "neighs": sorted_d}

    def cluster(self, output_file, texts, embeddings, threshold: float = 1.5) -> Dict[str, Any]:
        """Cluster ``embeddings`` and return the clusters with matrix statistics.

        Args:
            output_file: Not used by this method.
            texts: Not used by this method.
            embeddings: Array-like of vectors, one per item.
            threshold: Number of standard deviations above the mean similarity
                that a pair must reach to be considered related.

        Returns:
            ``{"clusters": [...], "info": {"mean", "std", "zscores"}}`` where
            each cluster is the dict produced by ``find_pivot_subgraph`` and
            ``zscores`` lists ``mean + k*std`` (rounded to 2 places) for
            ``k = 0, 1, ...`` while the value stays below 1.
        """
        matrix = self.compute_matrix(embeddings)
        mean = np.mean(matrix)
        std = np.std(matrix)

        zscores = []
        if std > 0:
            inc = 0
            value = mean
            while value < 1:
                zscores.append(round(value, 2))
                inc += 1
                value = mean + inc * std
        elif mean < 1:
            # With zero spread the series never advances, so stepping would
            # loop forever; record the single reachable value instead.
            zscores.append(round(mean, 2))
        print("In clustering:", round(std, 2), zscores)

        cluster_dict = {}
        cluster_dict["clusters"] = []
        picked_dict = {}
        cutoff = mean + threshold * std  # same for every pivot, so compute once
        for i in range(len(embeddings)):
            if i in picked_dict:
                continue
            arr = self.get_terms_above_threshold(matrix, embeddings, i, cutoff)
            cluster_info = self.find_pivot_subgraph(i, arr, matrix, cutoff)
            self.update_picked_dict(picked_dict, cluster_info["neighs"])
            cluster_dict["clusters"].append(cluster_info)
        cluster_dict["info"] = {"mean": mean, "std": std, "zscores": zscores}
        return cluster_dict