Spaces:
Build error
Build error
| # Copyright 2021 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| from os.path import exists | |
| from os.path import join as pjoin | |
| import plotly.graph_objects as go | |
| import torch | |
| import transformers | |
| from datasets import load_from_disk | |
| from plotly.io import read_json | |
| from tqdm import tqdm | |
| from .dataset_utils import EMBEDDING_FIELD | |
def sentence_mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sentence, weighted by the mask.

    Args:
        model_output: indexable model output; element 0 holds the token
            embeddings
        attention_mask: per-token mask, broadcast over the embedding
            dimension and used as the weight of each token
    Returns:
        torch.Tensor: masked sum of the token embeddings over dimension 1,
            divided by the mask total (clamped to at least 1e-9 so a fully
            masked row does not divide by zero)
    """
    token_embeddings = model_output[0]
    # Repeat the mask along the embedding dimension so it lines up with
    # the token embeddings element-wise.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    masked_sum = (token_embeddings * mask).sum(1)
    mask_total = mask.sum(1).clamp(min=1e-9)
    return masked_sum / mask_total
class Embeddings:
    """Sentence embeddings of a text dataset and a cluster tree built on them.

    The embeddings dataset, the node list of the cluster tree and the tree
    figure are each written to a file under `cache_path`, and read back from
    there when `use_cache` is set and the file already exists.
    """

    def __init__(
        self,
        dstats=None,
        text_dset=None,
        text_field_name="text",
        cache_path="",
        use_cache=False,
    ):
        """Item embeddings and clustering

        Args:
            dstats: optional object providing `text_dset`, `our_text_field`
                and `cache_path`; when given, these override the three
                arguments below
            text_dset: dataset holding the text to embed
            text_field_name (string): name of the text column in `text_dset`
            cache_path (string): directory holding the cache files
            use_cache (bool): whether existing cache files may be loaded
        """
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_name = "sentence-transformers/all-mpnet-base-v2"
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        self.model = transformers.AutoModel.from_pretrained(self.model_name).to(
            self.device
        )
        self.text_dset = text_dset if dstats is None else dstats.text_dset
        self.text_field_name = (
            text_field_name if dstats is None else dstats.our_text_field
        )
        self.cache_path = cache_path if dstats is None else dstats.cache_path
        self.embeddings_dset_fid = pjoin(self.cache_path, "embeddings_dset")
        self.embeddings_dset = None
        self.node_list_fid = pjoin(self.cache_path, "node_list.th")
        self.node_list = None
        self.nid_map = None
        self.fig_tree_fid = pjoin(self.cache_path, "node_figure.json")
        self.fig_tree = None
        self.cached_clusters = {}
        self.use_cache = use_cache

    def compute_sentence_embeddings(self, sentences):
        """
        Takes a list of sentences and computes their embeddings
        using self.tokenizer and self.model (with output dimension D)
        followed by mean pooling of the token representations and normalization
        Args:
            sentences ([string]): list of N input sentences
        Returns:
            torch.Tensor: sentence embeddings, dimension NxD
        """
        batch = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )
        batch = {k: v.to(self.device) for k, v in batch.items()}
        with torch.no_grad():
            model_output = self.model(**batch)
            sentence_embeds = sentence_mean_pooling(
                model_output, batch["attention_mask"]
            )
            # Scale every embedding to unit norm so that dot products
            # between embeddings are cosine similarities.
            sentence_embeds /= sentence_embeds.norm(dim=-1, keepdim=True)
        return sentence_embeds

    def make_embeddings(self):
        """
        Batch computes the embeddings of the Dataset self.text_dset,
        using the field self.text_field_name as input.
        Returns:
            Dataset: HF dataset object with a single EMBEDDING_FIELD field
                corresponding to the embeddings (list of floats)
        """

        def batch_embed_sentences(sentences):
            return {
                EMBEDDING_FIELD: [
                    embed.tolist()
                    for embed in self.compute_sentence_embeddings(
                        sentences[self.text_field_name]
                    )
                ]
            }

        self.embeddings_dset = self.text_dset.map(
            batch_embed_sentences,
            batched=True,
            batch_size=32,
            remove_columns=[self.text_field_name],
        )
        return self.embeddings_dset

    def make_text_embeddings(self):
        """Load embeddings dataset from cache or compute it."""
        if self.use_cache and exists(self.embeddings_dset_fid):
            self.embeddings_dset = load_from_disk(self.embeddings_dset_fid)
        else:
            self.embeddings_dset = self.make_embeddings()
            self.embeddings_dset.save_to_disk(self.embeddings_dset_fid)

    def make_hierarchical_clustering(
        self,
        batch_size=1000,
        approx_neighbors=1000,
        min_cluster_size=10,
    ):
        """Load or compute the cluster tree and its figure.

        Sets self.node_list, self.nid_map (node "nid" -> position in
        self.node_list) and self.fig_tree. Anything that is computed rather
        than loaded is written to its cache file.

        Args:
            batch_size (int): forwarded to fast_cluster
            approx_neighbors (int): forwarded to fast_cluster
            min_cluster_size (int): forwarded to fast_cluster
        """
        if self.use_cache and exists(self.node_list_fid):
            self.node_list, self.nid_map = torch.load(self.node_list_fid)
        else:
            self.make_text_embeddings()
            embeddings = torch.Tensor(self.embeddings_dset[EMBEDDING_FIELD])
            self.node_list = fast_cluster(
                embeddings, batch_size, approx_neighbors, min_cluster_size
            )
            self.nid_map = {
                node["nid"]: position for position, node in enumerate(self.node_list)
            }
            torch.save((self.node_list, self.nid_map), self.node_list_fid)
        if self.use_cache and exists(self.fig_tree_fid):
            self.fig_tree = read_json(self.fig_tree_fid)
        else:
            self.fig_tree = make_tree_plot(
                self.node_list, self.nid_map, self.text_dset, self.text_field_name
            )
            self.fig_tree.write_json(self.fig_tree_fid)

    def _child_indices(self, node_index):
        """Return the self.node_list positions of a node's children.

        Children whose nid is missing from self.nid_map are skipped.
        """
        return [
            self.nid_map[nid]
            for nid in self.node_list[node_index]["children_ids"]
            if nid in self.nid_map
        ]

    def find_cluster_beam(self, sentence, beam_size=20):
        """
        This function finds the `beam_size` leaf clusters that are closest to the
        proposed sentence and returns the full path from the root to the cluster
        along with the dot product between the sentence embedding and the
        cluster centroid
        Args:
            sentence (string): input sentence for which to find clusters
            beam_size (int): this is a beam size algorithm to explore the tree
        Returns:
            [([int], float)]: list of (path_from_root, score) sorted by score
        """
        embed = self.compute_sentence_embeddings([sentence])[0].to("cpu")
        # Paths are lists of positions in self.node_list; position 0 is
        # where every path starts.
        active_paths = [([0], torch.dot(embed, self.node_list[0]["centroid"]).item())]
        finished_paths = []
        while active_paths:
            # Score every child of every active path, keep the best
            # `beam_size` extensions across the whole beam.
            candidates = [
                (
                    beam_id,
                    child_index,
                    torch.dot(embed, self.node_list[child_index]["centroid"]).item(),
                )
                for beam_id, (path, _) in enumerate(active_paths)
                for child_index in self._child_indices(path[-1])
            ]
            candidates.sort(key=lambda cand: cand[2], reverse=True)
            next_active = []
            for beam_id, child_index, score in candidates[:beam_size]:
                path = active_paths[beam_id][0] + [child_index]
                # A path is finished once its last node has no mapped child.
                if self._child_indices(child_index):
                    next_active.append((path, score))
                else:
                    finished_paths.append((path, score))
            active_paths = next_active
        return sorted(
            finished_paths,
            key=lambda x: x[-1],
            reverse=True,
        )[:beam_size]
def prepare_merges(embeddings, batch_size=1000, approx_neighbors=1000, low_thres=0.5):
    """
    Prepares an initial list of merges for hierarchical
    clustering. First compute the `approx_neighbors` nearest neighbors,
    then propose a merge for any two points that are closer than `low_thres`
    Note that if a point has more than `approx_neighbors` neighbors
    closer than `low_thres`, this approach will miss some of those merges
    Args:
        embeddings (torch.Tensor): Tensor of sentence embeddings - dimension NxD
        batch_size (int): compute nearest neighbors of `batch_size` points at a time
        approx_neighbors (int): only keep `approx_neighbors` nearest neighbors of a point
            (capped at N, since a point cannot have more neighbors than that)
        low_thres (float): only return merges where the dot product is greater than `low_thres`
    Returns:
        torch.LongTensor: proposed merges ([i, j] with i>j) - dimension: Mx2
        torch.Tensor: merge scores - dimension M
    """
    n_points = embeddings.shape[0]
    if n_points == 0:
        # Nothing to merge; also avoids dividing by a zero batch size below.
        return (
            torch.empty(0, 2, dtype=torch.long),
            torch.empty(0, dtype=embeddings.dtype),
        )
    # topk raises when k exceeds the number of columns.
    approx_neighbors = min(approx_neighbors, n_points)
    point_ids = torch.arange(n_points, device=embeddings.device)
    top_vals = []
    top_idxs = []
    n_batches = math.ceil(n_points / batch_size)
    for b in tqdm(range(n_batches)):
        # TODO: batch across second dimension
        start = b * batch_size
        cos_scores = torch.mm(embeddings[start : start + batch_size], embeddings.t())
        # Keep only pairs (i, j) with j < i: the point itself and all later
        # points get the score -1.
        row_ids = point_ids[start : start + cos_scores.shape[0]]
        cos_scores.masked_fill_(point_ids[None, :] >= row_ids[:, None], -1)
        top_val_large, top_idx_large = cos_scores.topk(
            k=approx_neighbors, dim=-1, largest=True
        )
        top_vals.append(top_val_large)
        top_idxs.append(top_idx_large)
        # If even the last kept neighbor is above the threshold, some
        # neighbors above the threshold may have been cut off.
        max_neighbor_dist = top_val_large[:, -1].max().item()
        if max_neighbor_dist > low_thres:
            print(
                f"WARNING: with the current set of neireast neighbor, the farthest is {max_neighbor_dist}"
            )
    top_val_all = torch.cat(top_vals, dim=0)
    top_idx_all = torch.cat(top_idxs, dim=0)
    keep = top_val_all > low_thres
    # Row index of every entry, with the same N x approx_neighbors shape as
    # the top-k results so the same mask selects from both.
    top_idx_pre = point_ids[:, None].expand(n_points, approx_neighbors)
    all_merges = torch.stack([top_idx_pre[keep], top_idx_all[keep]], dim=1)
    all_merge_scores = top_val_all[keep]
    return (all_merges, all_merge_scores)
def merge_nodes(nodes, current_thres, previous_thres, all_merges, all_merge_scores):
    """
    Apply every proposed merge whose score lies in
    (current_thres, previous_thres] to the roots of the two points involved.
    Two roots are either joined under a newly created node, or one is
    absorbed into the other when the other was created at `current_thres`.
    Args:
        nodes ([dict]): list of dicts representing the current set of nodes
        current_thres (float): merge all nodes closer than current_thres
        previous_thres (float): nodes closer than previous_thres are already merged
        all_merges (torch.LongTensor): proposed merges ([i, j] with i>j) - dimension: Mx2
        all_merge_scores (torch.Tensor): merge scores - dimension M
    Returns:
        [dict]: the same list object, extended with the newly created nodes
    """

    def root_of(index):
        # Follow parent links until a node without a parent is reached.
        top = nodes[index]
        while top["parent_id"] != -1:
            top = nodes[top["parent_id"]]
        return top

    in_range = (all_merge_scores > current_thres) & (
        all_merge_scores <= previous_thres
    )
    if not in_range.any().item():
        return nodes

    for a, b in all_merges[in_range].tolist():
        node_a = root_of(a)
        node_b = root_of(b)
        if node_a["nid"] == node_b["nid"]:
            # Both points already sit under the same root.
            continue
        if (node_a["depth"] + node_b["depth"]) > 0 and min(
            node_a["merge_threshold"], node_b["merge_threshold"]
        ) == current_thres:
            # One of the roots was created at this threshold: absorb the
            # root with the smaller nid into the one with the larger nid.
            if node_a["nid"] < node_b["nid"]:
                merge_from, merge_to = node_a, node_b
            else:
                merge_from, merge_to = node_b, node_a
            merge_to["depth"] = max(merge_to["depth"], merge_from["depth"])
            merge_to["weight"] += merge_from["weight"]
            if merge_from["depth"] > 0:
                merge_to["children_ids"] += merge_from["children_ids"]
            else:
                merge_to["children_ids"] += [merge_from["nid"]]
            for cid in merge_from["children_ids"]:
                nodes[cid]["parent_id"] = merge_to["nid"]
            merge_from["parent_id"] = merge_to["nid"]
        else:
            # Otherwise join the two roots under a brand new node.
            new_nid = len(nodes)
            new_node = {
                "nid": new_nid,
                "parent_id": -1,
                "depth": max(node_a["depth"], node_b["depth"]) + 1,
                "weight": node_a["weight"] + node_b["weight"],
                "children": [],
                "children_ids": [node_a["nid"], node_b["nid"]],
                "example_ids": [],
                "merge_threshold": current_thres,
            }
            node_a["parent_id"] = new_nid
            node_b["parent_id"] = new_nid
            nodes.append(new_node)
    return nodes
def finalize_node(node, nodes, min_cluster_size):
    """Post-process a node and, recursively, all of its children.

    The children are finalized first and sorted by descending weight. For a
    node with depth > 0, `example_ids` becomes the concatenation of its
    children's example ids. Children lighter than `min_cluster_size` are then
    dropped from `node["children"]`; their examples stay in `example_ids`.

    Args:
        node (dict): node to finalize (modified in place)
        nodes ([dict]): full node list, indexed by the ids in `children_ids`
        min_cluster_size (int): minimum weight for a child to be kept
    Returns:
        dict: the same `node` object
    Raises:
        AssertionError: if the node's weight differs from the number of
            example ids it ends up with
    """
    children = [
        finalize_node(nodes[cid], nodes, min_cluster_size)
        for cid in node["children_ids"]
    ]
    children.sort(key=lambda child: child["weight"], reverse=True)
    if node["depth"] > 0:
        # Collect the examples before pruning so that the examples of
        # dropped children are still counted.
        node["example_ids"] = [
            eid for child in children for eid in child["example_ids"]
        ]
    node["children"] = [
        child for child in children if child["weight"] >= min_cluster_size
    ]
    assert node["weight"] == len(node["example_ids"]), (
        f"node {node.get('nid')}: weight {node['weight']} does not match "
        f"its {len(node['example_ids'])} example ids"
    )
    return node
def fast_cluster(
    embeddings,
    batch_size=1000,
    approx_neighbors=1000,
    min_cluster_size=10,
    low_thres=0.5,
):
    """
    Computes an approximate hierarchical clustering based on example
    embeddings. The join criterion is min clustering, i.e. two clusters
    are joined if any pair of their descendants are closer than a threshold
    The approximate comes from the fact that only the `approx_neighbors` nearest
    neighbors of an example are considered for merges
    Args:
        embeddings (torch.Tensor): example embeddings - dimension NxD
        batch_size (int): forwarded to prepare_merges, capped at N
        approx_neighbors (int): forwarded to prepare_merges, capped at N
        min_cluster_size (int): clusters lighter than this are left out of the tree
        low_thres (float): forwarded to prepare_merges
    Returns:
        [dict]: tree nodes listed parent-first starting with the root, each
            with its centroid and its examples sorted by dot product with it
    """
    n_points = embeddings.shape[0]
    batch_size = min(n_points, batch_size)
    # Same cap as for batch_size: a point cannot have more than N neighbors.
    approx_neighbors = min(n_points, approx_neighbors)
    all_merges, all_merge_scores = prepare_merges(
        embeddings, batch_size, approx_neighbors, low_thres
    )
    # prepare leaves
    nodes = [
        {
            "nid": nid,
            "parent_id": -1,
            "depth": 0,
            "weight": 1,
            "children": [],
            "children_ids": [],
            "example_ids": [nid],
            "merge_threshold": 1.0,
        }
        for nid in range(n_points)
    ]
    # one level per threshold range
    for i in range(10):
        p_thres = 1 - i * 0.05
        c_thres = 0.95 - i * 0.05
        nodes = merge_nodes(nodes, c_thres, p_thres, all_merges, all_merge_scores)
    # make root
    root_children = [
        node
        for node in nodes
        if node["parent_id"] == -1 and node["weight"] >= min_cluster_size
    ]
    root = {
        "nid": len(nodes),
        "parent_id": -1,
        # default=0 keeps this working when no cluster is large enough
        "depth": max((node["depth"] for node in root_children), default=0) + 1,
        "weight": sum(node["weight"] for node in root_children),
        "children": [],
        "children_ids": [node["nid"] for node in root_children],
        "example_ids": [],
        "merge_threshold": -1.0,
    }
    nodes.append(root)
    for node in root_children:
        node["parent_id"] = root["nid"]
    # finalize tree
    tree = finalize_node(root, nodes, min_cluster_size)
    node_list = []

    def rec_map_nodes(node):
        # Parent first, then each child's subtree in order.
        node_list.append(node)
        for child in node["children"]:
            rec_map_nodes(child)

    rec_map_nodes(tree)
    # get centroids and distances
    for node in node_list:
        node_embeds = embeddings[node["example_ids"]]
        summed = node_embeds.sum(dim=0)
        # Clamp so a node without examples gets a zero centroid, not NaN.
        node["centroid"] = summed / summed.norm().clamp(min=1e-9)
        node["centroid_dot_prods"] = torch.mv(node_embeds, node["centroid"])
        node["sorted_examples_centroid"] = sorted(
            [
                (eid, edp.item())
                for eid, edp in zip(node["example_ids"], node["centroid_dot_prods"])
            ],
            key=lambda x: x[1],
            reverse=True,
        )
    return node_list
def make_tree_plot(node_list, nid_map, text_dset, text_field_name):
    """
    Makes a graphical representation of the tree encoded
    in node-list. The hover label for each node shows the number
    of descendants and the 5 examples that are closest to the centroid
    Args:
        node_list ([dict]): tree nodes, root first; each node gets "label",
            "X" and "Y" entries written to it
        nid_map (dict): not used by the plot; kept so the signature is unchanged
        text_dset: dataset indexed by example id, holding the example texts
        text_field_name (string): name of the text column in `text_dset`
    Returns:
        plotly.graph_objects.Figure: figure with one trace for the edges
            and one for the nodes
    """
    for nid, node in enumerate(node_list):
        # Up to 5 distinct texts, taken in the order of
        # "sorted_examples_centroid".
        node_examples = {}
        for sid, score in node["sorted_examples_centroid"]:
            node_examples[text_dset[sid][text_field_name]] = score
            if len(node_examples) >= 5:
                break
        # An existing label is left as is.
        node["label"] = node.get(
            "label",
            f"{nid:2d} - {node['weight']:5d} items <br>"
            + "<br>".join(
                [
                    # Ellipsis only when the text was actually cut at 64.
                    f" {score:.2f} > {txt[:64]}" + ("..." if len(txt) > 64 else "")
                    for txt, score in node_examples.items()
                ]
            ),
        )
    # make plot nodes
    labels = [node["label"] for node in node_list]
    root = node_list[0]
    root["X"] = 0
    root["Y"] = 0

    def rec_make_coordinates(node):
        # Children are laid out left to right, one level below the parent.
        # Each child takes a width equal to its weight plus an even share
        # of the examples that belong to none of the kept children.
        total_weight = 0
        add_weight = len(node["example_ids"]) - sum(
            child["weight"] for child in node["children"]
        )
        for child in node["children"]:
            child["X"] = node["X"] + total_weight
            child["Y"] = node["Y"] - 1
            total_weight += child["weight"] + add_weight / len(node["children"])
            rec_make_coordinates(child)

    rec_make_coordinates(root)
    Xn = []
    Yn = []
    Xe = []
    Ye = []
    for node in node_list:
        Xn.append(node["X"])
        Yn.append(node["Y"])
        for child in node["children"]:
            # The None entries separate the individual line segments.
            Xe += [node["X"], child["X"], None]
            Ye += [node["Y"], child["Y"], None]
    # make figure
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=Xe,
            y=Ye,
            mode="lines",
            line=dict(color="rgb(210,210,210)", width=1),
            hoverinfo="none",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=Xn,
            y=Yn,
            mode="markers",
            name="nodes",
            marker=dict(
                symbol="circle-dot",
                size=18,
                color="#6175c1",
                line=dict(color="rgb(50,50,50)", width=1),
            ),
            text=labels,
            hoverinfo="text",
            opacity=0.8,
        )
    )
    return fig