# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gzip
import json
import math
import os
from os.path import exists
from os.path import join as pjoin

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import transformers
from datasets import load_dataset
from huggingface_hub import HfApi
from tqdm import tqdm

# required by ClusteringBuilder.set_features_dataset
from .dataset_utils import prepare_clustering_dataset

pd.options.display.max_colwidth = 256

_CACHE_DIR = "cache_dir"

_DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"

_MAX_MERGE = 20000000  # to run on 64GB RAM laptop

def sentence_mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[
        0
    ]  # First element of model_output contains all token embeddings
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )

# get nearest neighbors of a centroid by dot product
def get_examplars(example_ids, centroid, embeddings, dset, n_examplars):
    example_embeds = embeddings[example_ids]
    example_scores = torch.mv(example_embeds, centroid)
    s_scores, s_ids = example_scores.sort(dim=-1, descending=True)
    examplars = [
        (example_ids[i.item()], s.item())
        for i, s in zip(s_ids[:n_examplars], s_scores[:n_examplars])
    ]
    res = []
    for eid, score in examplars:
        dct = dict(dset[eid])
        dct["score"] = score
        res += [dct]
    return res

# order node children so that the large ones are in the middle
# makes visualization more balanced
def pretty_order(nodes, node_ids):
    sorted_ids = sorted(node_ids, key=lambda nid: nodes[nid]["weight"])
    sorted_a = [nid for i, nid in enumerate(sorted_ids) if i % 2 == 0]
    sorted_b = [nid for i, nid in enumerate(sorted_ids) if i % 2 == 1]
    sorted_b.reverse()
    return sorted_a + sorted_b

def make_tree_plot(node_list, root_id, max_depth=-1):
    # make plot nodes
    plot_nodes = [{} for _ in node_list]
    root = {
        "parent_id": -1,
        "node_id": root_id,
        "label": node_list[root_id]["hover_text"],
        "weight": node_list[root_id]["weight"],
        "num_leaves": 0,
        "children_ids": node_list[root_id]["children_ids"],
        "Xmin": 0,
        "Y": 0,
    }
    plot_nodes[root_id] = root
    root_depth = node_list[root_id]["depth"]

    def rec_make_coordinates(node):
        total_weight = 0
        recurse = (max_depth == -1) or (
            node_list[node["node_id"]]["depth"] - root_depth < max_depth - 1
        )
        for cid in node["children_ids"]:
            plot_nodes[cid] = {
                "parent_id": node["node_id"],
                "node_id": cid,
                "label": node_list[cid]["hover_text"],
                "weight": node_list[cid]["weight"],
                "children_ids": node_list[cid]["children_ids"] if recurse else [],
                "Xmin": node["Xmin"] + total_weight,
                "Y": node["Y"] - 1,
            }
            plot_nodes[cid]["num_leaves"] = (
                1 if len(plot_nodes[cid]["children_ids"]) == 0 else 0
            )
            rec_make_coordinates(plot_nodes[cid])
            total_weight += plot_nodes[cid]["num_leaves"]
            node["num_leaves"] += plot_nodes[cid]["num_leaves"]
        node["Xmax"] = node["Xmin"] + node["num_leaves"]
        node["X"] = node["Xmin"] + (node["num_leaves"] / 2)

    rec_make_coordinates(root)

    subtree_nodes = [node for node in plot_nodes if len(node) > 0]
    nid_map = dict([(node["node_id"], nid) for nid, node in enumerate(subtree_nodes)])
    labels = [node["label"] for node in subtree_nodes]

    E = []  # list of edges
    Xn = []
    Yn = []
    Xe = []
    Ye = []
    for nid, node in enumerate(subtree_nodes):
        Xn += [node["X"]]
        Yn += [node["Y"]]
        for cid in node["children_ids"]:
            child = plot_nodes[cid]
            E += [(nid, nid_map[child["node_id"]])]
            Xe += [node["X"], child["X"], None]
            Ye += [node["Y"], child["Y"], None]

    # make figure
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=Xe,
            y=Ye,
            mode="lines",
            name="",
            line=dict(color="rgb(210,210,210)", width=1),
            hoverinfo="none",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=Xn,
            y=Yn,
            mode="markers",
            name="nodes",
            marker=dict(
                symbol="circle-dot",
                size=18,
                color="#6175c1",
                line=dict(color="rgb(50,50,50)", width=1)
                # '#DB4551',
            ),
            text=labels,
            hoverinfo="text",
            opacity=0.8,
        )
    )
    fig.layout.showlegend = False
    return fig

class ClusteringBuilder:
    def __init__(
        self,
        dataset_name,
        config_name,
        split_name,
        input_field_path,
        label_name,
        num_rows,
        model_name=_DEFAULT_MODEL,
    ):
        """Item embeddings and clustering"""
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.split_name = split_name
        self.input_field_path = input_field_path
        self.label_name = label_name
        self.num_rows = num_rows
        self.cache_path_list = [
            _CACHE_DIR,
            dataset_name.replace("/", "---"),
            f"{'default' if config_name is None else config_name}",
            f"{'train' if split_name is None else split_name}",
            f"field-{'->'.join(input_field_path)}-label-{label_name}",
            f"{num_rows}_rows",
            model_name.replace("/", "---"),
        ]
        self.cache_path = pjoin(*self.cache_path_list)
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name

    # prepare embeddings for the dataset
    def set_model(self):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        self.model = transformers.AutoModel.from_pretrained(self.model_name).to(
            self.device
        )
    def set_features_dataset(self, use_streaming, use_auth_token, use_dataset):
        dset, dset_path = prepare_clustering_dataset(
            dataset_name=self.dataset_name,
            input_field_path=self.input_field_path,
            label_name=self.label_name,
            config_name=self.config_name,
            split_name=self.split_name,
            num_rows=self.num_rows,
            use_streaming=use_streaming,
            use_auth_token=use_auth_token,
            use_dataset=use_dataset,
        )
        self.features_dset = dset

    def compute_feature_embeddings(self, sentences):
        batch = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )
        batch = {k: v.to(self.device) for k, v in batch.items()}
        with torch.no_grad():
            model_output = self.model(**batch)
            sentence_embeds = sentence_mean_pooling(
                model_output, batch["attention_mask"]
            )
            sentence_embeds /= sentence_embeds.norm(dim=-1, keepdim=True)
        return sentence_embeds
    def set_embeddings_dataset(self):
        def batch_embed(examples):
            return {
                "embedding": [
                    embed.tolist()
                    for embed in self.compute_feature_embeddings(examples["field"])
                ]
            }

        if not exists(self.cache_path):
            # create the (nested) cache directory for the embeddings
            os.makedirs(self.cache_path)
        self.embeddings_dset = self.features_dset.map(
            batch_embed,
            batched=True,
            batch_size=32,
            cache_file_name=pjoin(self.cache_path, "embeddings_dset"),
        )

    def prepare_embeddings(
        self,
        use_streaming=True,
        use_auth_token=None,
        use_dataset=None,
    ):
        self.set_model()
        self.set_features_dataset(use_streaming, use_auth_token, use_dataset)
        self.set_embeddings_dataset()
    # make cluster tree
    def prepare_merges(self, batch_size, low_thres):
        # compute pairwise cosine similarities block by block (the embeddings are
        # L2-normalized) and keep candidate pairs scoring above low_thres,
        # sorted by decreasing similarity
        self.embeddings = torch.Tensor(self.embeddings_dset["embedding"])
        all_indices = torch.LongTensor(torch.Size([0, 2]))
        all_scores = torch.Tensor(torch.Size([0]))
        n_batches = math.ceil(self.embeddings_dset.num_rows / batch_size)
        for a in range(n_batches):
            for b in tqdm(range(a, n_batches)):
                cos_scores = torch.mm(
                    self.embeddings[a * batch_size : (a + 1) * batch_size],
                    self.embeddings[b * batch_size : (b + 1) * batch_size].t(),
                )
                if a == b:
                    cos_scores = cos_scores.triu(diagonal=1)
                merge_indices = torch.nonzero(cos_scores > low_thres)
                merge_indices[:, 0] += a * batch_size
                merge_indices[:, 1] += b * batch_size
                merge_scores = cos_scores[cos_scores > low_thres]
                all_indices = torch.cat([all_indices, merge_indices], dim=0)
                all_scores = torch.cat([all_scores, merge_scores], dim=0)
        self.sorted_scores, sorted_score_ids = all_scores.sort(dim=0, descending=True)
        self.sorted_scores = self.sorted_scores[:_MAX_MERGE]
        sorted_score_ids = sorted_score_ids[:_MAX_MERGE]
        self.sorted_indices = all_indices[sorted_score_ids]
    def make_starting_nodes(self, identical_threshold):
        # pairs scoring at or above identical_threshold are treated as near-duplicates:
        # chase parent pointers so each group keeps a single representative leaf node
        identical_indices = self.sorted_indices[
            self.sorted_scores >= identical_threshold
        ]
        identical_inter = identical_indices[
            identical_indices[:, 1].sort(stable=True).indices
        ]
        identical_sorted = identical_inter[
            identical_inter[:, 0].sort(stable=True).indices
        ]
        self.parents = {}
        for a_pre, b_pre in identical_sorted:
            a = a_pre.item()
            b = b_pre.item()
            while self.parents.get(a, -1) != -1:
                a = self.parents[a]
            self.parents[b] = a
        self.duplicates = {}
        for a, b in self.parents.items():
            self.duplicates[b] = self.duplicates.get(b, []) + [a]
        self.nodes = {}
        for node_id in range(self.features_dset.num_rows):
            if node_id in self.parents:
                continue
            else:
                self.nodes[node_id] = {
                    "node_id": node_id,
                    "parent_id": -1,
                    "children": [],
                    "children_ids": [],
                    "example_ids": [node_id],
                    "weight": 1,
                    "merge_threshold": 0.98,
                    "depth": 0,
                }
    def make_merge_nodes(self, identical_threshold, thres_step):
        # greedily merge the remaining candidate pairs in decreasing similarity order,
        # stepping the merge threshold down by thres_step as scores drop
        new_node_id = self.features_dset.num_rows
        current_thres = identical_threshold
        depth = 1
        merge_ids = self.sorted_indices[self.sorted_scores < identical_threshold]
        merge_scores = self.sorted_scores[self.sorted_scores < identical_threshold]
        for (node_id_a, node_id_b), merge_score in tqdm(
            zip(merge_ids, merge_scores), total=len(merge_ids)
        ):
            if merge_score.item() < current_thres:
                current_thres -= thres_step
            merge_a = node_id_a.item()
            while self.parents.get(merge_a, -1) != -1:
                merge_a = self.parents[merge_a]
            self.parents[node_id_a] = merge_a
            merge_b = node_id_b.item()
            while self.parents.get(merge_b, -1) != -1:
                merge_b = self.parents[merge_b]
            self.parents[node_id_b] = merge_b
            if merge_a == merge_b:
                continue
            else:
                merge_b, merge_a = sorted([merge_a, merge_b])
                node_a = self.nodes[merge_a]
                node_b = self.nodes[merge_b]
                if (node_a["depth"]) > 0 and min(
                    node_a["merge_threshold"], node_b["merge_threshold"]
                ) == current_thres:
                    # both clusters were formed at the current threshold level:
                    # absorb node_b into node_a instead of adding a new tree level
                    node_a["depth"] = max(node_a["depth"], node_b["depth"])
                    node_a["weight"] += node_b["weight"]
                    node_a["children_ids"] += (
                        node_b["children_ids"]
                        if node_b["depth"] > 0
                        else [node_b["node_id"]]
                    )
                    for cid in node_b["children_ids"]:
                        self.nodes[cid]["parent_id"] = node_a["node_id"]
                        self.parents[cid] = node_a["node_id"]
                    node_b["parent_id"] = node_a["node_id"]
                    self.parents[node_b["node_id"]] = node_a["node_id"]
                else:
                    new_nid = new_node_id
                    new_node_id += 1
                    new_node = {
                        "node_id": new_nid,
                        "parent_id": -1,
                        "children_ids": [node_a["node_id"], node_b["node_id"]],
                        "example_ids": [],
                        "weight": node_a["weight"] + node_b["weight"],
                        "merge_threshold": current_thres,
                        "depth": max(node_a["depth"], node_b["depth"]) + 1,
                    }
                    depth = max(depth, new_node["depth"])
                    node_a["parent_id"] = new_nid
                    node_b["parent_id"] = new_nid
                    self.parents[node_a["node_id"]] = new_nid
                    self.parents[node_b["node_id"]] = new_nid
                    self.parents[node_id_a] = new_nid
                    self.parents[node_id_b] = new_nid
                    self.nodes[new_nid] = new_node
        return new_node_id
    def collapse_nodes(self, node, min_weight):
        # prune children lighter than min_weight, folding their examples into the
        # parent's "extras", and accumulate embedding sums for the node centroids
        children = [
            self.collapse_nodes(self.nodes[cid], min_weight)
            for cid in node["children_ids"]
            if self.nodes[cid]["weight"] >= min_weight
        ]
        extras = [
            lid
            for cid in node["children_ids"]
            if self.nodes[cid]["weight"] < min_weight
            for lid in self.collapse_nodes(self.nodes[cid], min_weight)["example_ids"]
        ] + node["example_ids"]
        extras_embed = (
            torch.cat(
                [self.embeddings[eid][None, :] for eid in extras],
                dim=0,
            ).sum(dim=0)
            if len(extras) > 0
            else torch.zeros(self.embeddings.shape[-1])
        )
        if len(children) == 0:
            node["extras"] = extras
            node["children_ids"] = []
            node["example_ids"] = extras
            node["embedding_sum"] = extras_embed
        elif len(children) == 1:
            node["extras"] = extras + children[0]["extras"]
            node["children_ids"] = children[0]["children_ids"]
            node["example_ids"] = extras + children[0]["example_ids"]
            node["embedding_sum"] = extras_embed + children[0]["embedding_sum"]
        else:
            node["extras"] = extras
            node["children_ids"] = [child["node_id"] for child in children]
            node["example_ids"] = extras + [
                eid for child in children for eid in child["example_ids"]
            ]
            node["embedding_sum"] = (
                extras_embed
                + torch.cat(
                    [child["embedding_sum"][None, :] for child in children],
                    dim=0,
                ).sum(dim=0)
            )
        assert (
            len(node["example_ids"]) == node["weight"]
        ), f"stuck at {node['node_id']} - {len(node['example_ids'])} - {node['weight']}"
        return node
    def finalize_node(self, node, parent_id, n_examplars, with_labels):
        # renumber nodes into the flat tree_node_list, compute each node's centroid
        # and nearest exemplars, and attach an "extra" child holding the examples
        # folded up by collapse_nodes
        new_node_id = len(self.tree_node_list)
        new_node = {
            "node_id": new_node_id,
            "parent_id": parent_id,
            "depth": 0
            if parent_id == -1
            else self.tree_node_list[parent_id]["depth"] + 1,
            "merged_at": node["merge_threshold"],
            "weight": node["weight"],
            "is_extra": False,
        }
        self.tree_node_list += [new_node]
        centroid = node["embedding_sum"] / node["embedding_sum"].norm()
        new_node["centroid"] = centroid.tolist()
        new_node["examplars"] = get_examplars(
            node["example_ids"],
            centroid,
            self.embeddings,
            self.features_dset,
            n_examplars,
        )
        label_counts = {}
        if with_labels:
            for eid in node["example_ids"]:
                label = self.features_dset[eid]["label"]
                label_counts[label] = label_counts.get(label, 0) + 1
        new_node["label_counts"] = sorted(
            label_counts.items(), key=lambda x: x[1], reverse=True
        )
        if len(node["children_ids"]) == 0:
            new_node["children_ids"] = []
        else:
            children = [
                self.nodes[cid]
                for cid in pretty_order(self.nodes, node["children_ids"])
            ]
            children_ids = [
                self.finalize_node(child, new_node_id, n_examplars, with_labels)
                for child in children
            ]
            new_node["children_ids"] = children_ids
        if len(node["extras"]) > 0:
            extra_node = {
                "node_id": len(self.tree_node_list),
                "parent_id": new_node_id,
                "depth": new_node["depth"] + 1,
                "merged_at": node["merge_threshold"],
                "weight": len(node["extras"]),
                "is_extra": True,
                "centroid": new_node["centroid"],
                "examplars": get_examplars(
                    node["extras"],
                    centroid,
                    self.embeddings,
                    self.features_dset,
                    n_examplars,
                ),
            }
            self.tree_node_list += [extra_node]
            label_counts = {}
            if with_labels:
                for eid in node["extras"]:
                    label = self.features_dset[eid]["label"]
                    label_counts[label] = label_counts.get(label, 0) + 1
            extra_node["label_counts"] = sorted(
                label_counts.items(), key=lambda x: x[1], reverse=True
            )
            extra_node["children_ids"] = []
            new_node["children_ids"] += [extra_node["node_id"]]
        return new_node_id
    def make_hover_text(self, num_examples=5, text_width=64, with_labels=False):
        for nid, node in enumerate(self.tree_node_list):
            line_list = [
                f"Node {nid:3d} - {node['weight']:6d} items - Linking threshold: {node['merged_at']:.2f}"
            ]
            for examplar in node["examplars"][:num_examples]:
                line_list += [
                    f"{examplar['ids']:6d}:{examplar['score']:.2f} - {examplar['field'][:text_width]}"
                    + (f" - {examplar['label']}" if with_labels else "")
                ]
            if with_labels:
                line_list += ["Label distribution"]
                for label, count in node["label_counts"]:
                    line_list += [f" - label: {label} - {count} items"]
            node["hover_text"] = "<br>".join(line_list)
    def build_tree(
        self,
        batch_size=10000,
        low_thres=0.5,
        identical_threshold=0.95,
        thres_step=0.05,
        min_weight=10,
        n_examplars=25,
        hover_examples=5,
        hover_text_width=64,
    ):
        self.prepare_merges(batch_size, low_thres)
        self.make_starting_nodes(identical_threshold)
        # make a root to join all trees
        root_node_id = self.make_merge_nodes(identical_threshold, thres_step)
        top_nodes = [node for node in self.nodes.values() if node["parent_id"] == -1]
        root_node = {
            "node_id": root_node_id,
            "parent_id": -1,
            "children_ids": [node["node_id"] for node in top_nodes],
            "example_ids": [],
            "weight": sum([node["weight"] for node in top_nodes]),
            "merge_threshold": -1.0,
            "depth": 1 + max([node["depth"] for node in top_nodes]),
        }
        for node in top_nodes:
            node["parent_id"] = root_node_id
        self.nodes[root_node_id] = root_node
        _ = self.collapse_nodes(root_node, min_weight)
        self.tree_node_list = []
        self.finalize_node(
            root_node,
            -1,
            n_examplars,
            with_labels=(self.label_name is not None),
        )
        self.make_hover_text(
            num_examples=hover_examples,
            text_width=hover_text_width,
            with_labels=(self.label_name is not None),
        )
    def push_to_hub(self, use_auth_token=None, file_name=None):
        path_list = self.cache_path_list
        name = "tree" if file_name is None else file_name
        tree_file = pjoin(pjoin(*path_list), f"{name}.jsonl.gz")
        fout = gzip.open(tree_file, "w")
        for node in tqdm(self.tree_node_list):
            _ = fout.write((json.dumps(node) + "\n").encode("utf-8"))
        fout.close()
        api = HfApi()
        file_loc = api.upload_file(
            path_or_fileobj=tree_file,
            path_in_repo=pjoin(pjoin(*path_list[1:]), f"{name}.jsonl.gz"),
            repo_id="yjernite/datasets_clusters",
            token=use_auth_token,
            repo_type="dataset",
        )
        return file_loc

class Clustering:
    def __init__(
        self,
        dataset_name,
        config_name,
        split_name,
        input_field_path,
        label_name,
        num_rows,
        n_examplars=10,
        model_name=_DEFAULT_MODEL,
        file_name=None,
        max_depth_subtree=3,
    ):
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.split_name = split_name
        self.input_field_path = input_field_path
        self.label_name = label_name
        self.num_rows = num_rows
        self.model_name = model_name
        self.n_examplars = n_examplars
        self.file_name = "tree" if file_name is None else file_name
        self.repo_path_list = [
            dataset_name.replace("/", "---"),
            f"{'default' if config_name is None else config_name}",
            f"{'train' if split_name is None else split_name}",
            f"field-{'->'.join(input_field_path)}-label-{label_name}",
            f"{num_rows}_rows",
            model_name.replace("/", "---"),
            f"{self.file_name}.jsonl.gz",
        ]
        self.repo_path = pjoin(*self.repo_path_list)
        self.node_list = load_dataset(
            "yjernite/datasets_clusters", data_files=[self.repo_path]
        )["train"]
        self.node_reps = [{} for node in self.node_list]
        self.max_depth_subtree = max_depth_subtree
    def set_full_tree(self):
        self.node_reps[0]["tree"] = self.node_reps[0].get(
            "tree",
            make_tree_plot(
                self.node_list,
                0,
            ),
        )

    def get_full_tree(self):
        self.set_full_tree()
        return self.node_reps[0]["tree"]

    def set_node_subtree(self, node_id):
        self.node_reps[node_id]["subtree"] = self.node_reps[node_id].get(
            "subtree",
            make_tree_plot(
                self.node_list,
                node_id,
                self.max_depth_subtree,
            ),
        )

    def get_node_subtree(self, node_id):
        self.set_node_subtree(node_id)
        return self.node_reps[node_id]["subtree"]

    def set_node_examplars(self, node_id):
        self.node_reps[node_id]["examplars"] = self.node_reps[node_id].get(
            "examplars",
            pd.DataFrame(
                [
                    {
                        "id": exple["ids"],
                        "score": exple["score"],
                        "field": exple["field"],
                        "label": exple.get("label", "N/A"),
                    }
                    for exple in self.node_list[node_id]["examplars"]
                ][: self.n_examplars]
            ),
        )

    def get_node_examplars(self, node_id):
        self.set_node_examplars(node_id)
        return self.node_reps[node_id]["examplars"]

    def set_node_label_chart(self, node_id):
        self.node_reps[node_id]["label_chart"] = self.node_reps[node_id].get(
            "label_chart",
            px.pie(
                values=[ct for lab, ct in self.node_list[node_id]["label_counts"]],
                names=[
                    f"Label {lab}"
                    for lab, ct in self.node_list[node_id]["label_counts"]
                ],
                color_discrete_sequence=px.colors.sequential.Rainbow,
                width=400,
                height=400,
            ),
        )

    def get_node_label_chart(self, node_id):
        self.set_node_label_chart(node_id)
        return self.node_reps[node_id]["label_chart"]
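

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows one
# plausible way to drive ClusteringBuilder and Clustering end to end; the
# dataset name, field path, label name, and row count below are placeholder
# assumptions, and the file is assumed to be executed inside its package
# (e.g. with `python -m <package>.<module>`) so the relative import resolves.
if __name__ == "__main__":
    builder = ClusteringBuilder(
        dataset_name="ag_news",  # hypothetical example dataset
        config_name=None,
        split_name="train",
        input_field_path=["text"],  # assumed path to the text field
        label_name="label",
        num_rows=10000,
    )
    builder.prepare_embeddings()  # load the model and embed the selected rows
    builder.build_tree()  # merge the embeddings into a hierarchical cluster tree
    # builder.push_to_hub(use_auth_token="...")  # optional: upload the tree

    # Once a tree has been pushed, it can be reloaded and visualized:
    # clusters = Clustering("ag_news", None, "train", ["text"], "label", 10000)
    # clusters.get_full_tree().show()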