AbeerTrial's picture
Upload folder using huggingface_hub
8a58cf3
raw
history blame
6.99 kB
"""Common classes/functions for tree index operations."""
import asyncio
import logging
from typing import Dict, List, Sequence, Tuple
from gpt_index.async_utils import run_async_tasks
from gpt_index.data_structs.data_structs import IndexGraph, Node
from gpt_index.indices.node_utils import get_text_splits_from_document
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.utils import get_sorted_node_list
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.langchain_helpers.text_splitter import TextSplitter
from gpt_index.prompts.prompts import SummaryPrompt
from gpt_index.schema import BaseDocument
class GPTTreeIndexBuilder:
    """GPT tree index builder.

    Helper class to build the tree-structured index,
    or to synthesize an answer.
    """

    def __init__(
        self,
        num_children: int,
        summary_prompt: SummaryPrompt,
        llm_predictor: LLMPredictor,
        prompt_helper: PromptHelper,
        text_splitter: TextSplitter,
        use_async: bool = False,
    ) -> None:
        """Initialize with params.

        Args:
            num_children: branching factor of the tree; must be >= 2.
            summary_prompt: prompt used to summarize each group of child nodes.
            llm_predictor: LLM wrapper used for sync/async predictions.
            prompt_helper: helper used to pack node texts into the prompt.
            text_splitter: splitter used to chunk documents into leaf nodes.
            use_async: if True, run summary predictions concurrently in
                the sync ``build_index_from_nodes`` path.

        Raises:
            ValueError: if ``num_children`` is less than 2.
        """
        if num_children < 2:
            raise ValueError("Invalid number of children.")
        self.num_children = num_children
        self.summary_prompt = summary_prompt
        self._llm_predictor = llm_predictor
        self._prompt_helper = prompt_helper
        self._text_splitter = text_splitter
        self._use_async = use_async

    def _get_nodes_from_document(
        self, start_idx: int, document: BaseDocument
    ) -> Dict[int, Node]:
        """Split a document into leaf nodes indexed from ``start_idx``.

        Returns:
            Dict[int, Node]: consecutive node indices (starting at
            ``start_idx``) mapped to the Node built from each text chunk.
        """
        # NOTE: summary prompt does not need to be partially formatted
        text_splits = get_text_splits_from_document(
            document=document, text_splitter=self._text_splitter
        )
        text_chunks = [text_split.text_chunk for text_split in text_splits]
        doc_nodes = {
            (start_idx + i): Node(
                text=t,
                index=(start_idx + i),
                ref_doc_id=document.get_doc_id(),
                embedding=document.embedding,
                extra_info=document.extra_info,
            )
            for i, t in enumerate(text_chunks)
        }
        return doc_nodes

    def build_from_text(
        self,
        documents: Sequence[BaseDocument],
        build_tree: bool = True,
    ) -> IndexGraph:
        """Build from text.

        Args:
            documents: documents to split into leaf nodes.
            build_tree: if False, only leaf nodes are created and no
                root nodes are surfaced.

        Returns:
            IndexGraph: graph object consisting of all_nodes, root_nodes
        """
        all_nodes: Dict[int, Node] = {}
        for d in documents:
            # Offset each document's node indices by the nodes added so far
            # so indices stay globally unique and consecutive.
            all_nodes.update(self._get_nodes_from_document(len(all_nodes), d))

        if build_tree:
            # Recursively summarize the leaf nodes into parent layers until
            # the top layer fits within num_children nodes.
            root_nodes = self.build_index_from_nodes(all_nodes, all_nodes)
        else:
            # if build_tree is False, then don't surface any root nodes
            root_nodes = {}
        return IndexGraph(all_nodes=all_nodes, root_nodes=root_nodes)

    def _prepare_node_and_text_chunks(
        self, cur_nodes: Dict[int, Node]
    ) -> Tuple[List[int], List[List[Node]], List[str]]:
        """Group sorted nodes into chunks of ``num_children`` and render text.

        Returns:
            Tuple of (start offsets into the sorted node list, node groups,
            concatenated text per group, ready for the summary prompt).
        """
        cur_node_list = get_sorted_node_list(cur_nodes)
        # Ceiling division: a partial trailing group still counts as a chunk
        # (plain // under-reported the count by one when not divisible).
        num_chunks = -(-len(cur_node_list) // self.num_children)
        logging.info(f"> Building index from nodes: {num_chunks} chunks")
        indices, cur_nodes_chunks, text_chunks = [], [], []
        for i in range(0, len(cur_node_list), self.num_children):
            cur_nodes_chunk = cur_node_list[i : i + self.num_children]
            text_chunk = self._prompt_helper.get_text_from_nodes(
                cur_nodes_chunk, prompt=self.summary_prompt
            )
            indices.append(i)
            cur_nodes_chunks.append(cur_nodes_chunk)
            text_chunks.append(text_chunk)
        return indices, cur_nodes_chunks, text_chunks

    def _construct_parent_nodes(
        self,
        cur_index: int,
        indices: List[int],
        cur_nodes_chunks: List[List[Node]],
        summaries: List[str],
    ) -> Dict[int, Node]:
        """Build one parent node per child group, from ``cur_index`` onward.

        Each parent stores its group's summary text and the child node
        indices; parents are assigned consecutive indices starting at
        ``cur_index``.
        """
        new_node_dict = {}
        for i, cur_nodes_chunk, new_summary in zip(
            indices, cur_nodes_chunks, summaries
        ):
            # Lazy %-formatting so the (truncated) summary is only rendered
            # when debug logging is enabled.  The previous version used a
            # plain (non-f) string, so the summary was never interpolated.
            logging.debug(
                "> %d/%d, summary: %s", i, len(cur_nodes_chunk), new_summary[:50]
            )
            new_node = Node(
                text=new_summary,
                index=cur_index,
                child_indices={n.index for n in cur_nodes_chunk},
            )
            new_node_dict[cur_index] = new_node
            cur_index += 1
        return new_node_dict

    def build_index_from_nodes(
        self,
        cur_nodes: Dict[int, Node],
        all_nodes: Dict[int, Node],
    ) -> Dict[int, Node]:
        """Consolidates chunks recursively, in a bottoms-up fashion.

        Args:
            cur_nodes: the current layer of nodes to summarize.
            all_nodes: accumulator of every node built so far; updated
                in place with the new parent nodes.

        Returns:
            Dict[int, Node]: the root layer (<= num_children nodes).
        """
        cur_index = len(all_nodes)
        indices, cur_nodes_chunks, text_chunks = self._prepare_node_and_text_chunks(
            cur_nodes
        )

        if self._use_async:
            # Fire off all summary predictions concurrently.
            tasks = [
                self._llm_predictor.apredict(
                    self.summary_prompt, context_str=text_chunk
                )
                for text_chunk in text_chunks
            ]
            outputs: List[Tuple[str, str]] = run_async_tasks(tasks)
            # Each output pairs the completion text with the formatted
            # prompt; only the completion (element 0) is kept.
            summaries = [output[0] for output in outputs]
        else:
            summaries = [
                self._llm_predictor.predict(
                    self.summary_prompt, context_str=text_chunk
                )[0]
                for text_chunk in text_chunks
            ]

        new_node_dict = self._construct_parent_nodes(
            cur_index, indices, cur_nodes_chunks, summaries
        )
        all_nodes.update(new_node_dict)

        if len(new_node_dict) <= self.num_children:
            # Small enough to serve as the root layer.
            return new_node_dict
        else:
            return self.build_index_from_nodes(new_node_dict, all_nodes)

    async def abuild_index_from_nodes(
        self,
        cur_nodes: Dict[int, Node],
        all_nodes: Dict[int, Node],
    ) -> Dict[int, Node]:
        """Consolidates chunks recursively, in a bottoms-up fashion.

        Async counterpart of ``build_index_from_nodes``: all summary
        predictions in a layer are awaited concurrently via asyncio.gather.
        """
        cur_index = len(all_nodes)
        indices, cur_nodes_chunks, text_chunks = self._prepare_node_and_text_chunks(
            cur_nodes
        )

        tasks = [
            self._llm_predictor.apredict(self.summary_prompt, context_str=text_chunk)
            for text_chunk in text_chunks
        ]
        outputs: List[Tuple[str, str]] = await asyncio.gather(*tasks)
        # Keep only the completion text from each (text, prompt) pair.
        summaries = [output[0] for output in outputs]

        new_node_dict = self._construct_parent_nodes(
            cur_index, indices, cur_nodes_chunks, summaries
        )
        all_nodes.update(new_node_dict)

        if len(new_node_dict) <= self.num_children:
            return new_node_dict
        else:
            # BUGFIX: previously recursed into the *sync* build_index_from_nodes,
            # dropping concurrency after the first tree level.
            return await self.abuild_index_from_nodes(new_node_dict, all_nodes)