AbeerTrial's picture
Upload folder using huggingface_hub
8a58cf3
raw
history blame
8.37 kB
"""Leaf query mechanism."""
import logging
from typing import Any, Dict, Optional, cast
from gpt_index.data_structs.data_structs import IndexGraph, Node
from gpt_index.indices.query.base import BaseGPTIndexQuery
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.indices.response.builder import ResponseBuilder
from gpt_index.indices.utils import (
extract_numbers_given_response,
get_sorted_node_list,
truncate_text,
)
from gpt_index.prompts.default_prompts import (
DEFAULT_QUERY_PROMPT,
DEFAULT_QUERY_PROMPT_MULTIPLE,
)
from gpt_index.prompts.prompts import TreeSelectMultiplePrompt, TreeSelectPrompt
from gpt_index.response.schema import Response
class GPTTreeIndexLeafQuery(BaseGPTIndexQuery[IndexGraph]):
    """GPT Tree Index leaf query.

    This class traverses the index graph and searches for a leaf node that can best
    answer the query.

    .. code-block:: python

        response = index.query("<query_str>", mode="default")

    Args:
        query_template (Optional[TreeSelectPrompt]): Tree Select Query Prompt
            (see :ref:`Prompt-Templates`).
        query_template_multiple (Optional[TreeSelectMultiplePrompt]): Tree Select
            Query Prompt (Multiple)
            (see :ref:`Prompt-Templates`).
        child_branch_factor (int): Number of child nodes to consider at each level.
            If child_branch_factor is 1, then the query will only choose one child node
            to traverse for any given parent node.
            If child_branch_factor is 2, then the query will choose two child nodes.
    """

    def __init__(
        self,
        index_struct: IndexGraph,
        query_template: Optional[TreeSelectPrompt] = None,
        query_template_multiple: Optional[TreeSelectMultiplePrompt] = None,
        child_branch_factor: int = 1,
        **kwargs: Any,
    ) -> None:
        """Initialize params.

        Falls back to ``DEFAULT_QUERY_PROMPT`` / ``DEFAULT_QUERY_PROMPT_MULTIPLE``
        when the corresponding template is not supplied.
        """
        super().__init__(index_struct, **kwargs)
        self.query_template = query_template or DEFAULT_QUERY_PROMPT
        self.query_template_multiple = (
            query_template_multiple or DEFAULT_QUERY_PROMPT_MULTIPLE
        )
        self.child_branch_factor = child_branch_factor

    def _query_with_selected_node(
        self,
        selected_node: Node,
        query_bundle: QueryBundle,
        prev_response: Optional[str] = None,
        level: int = 0,
    ) -> str:
        """Get response for selected node.

        If the node is a leaf, the query is answered from the node's text.
        Otherwise this recurses into the node's children via ``_query_level``.
        If ``prev_response`` is provided, the answer is then refined against it
        using ``self.refine_template``.

        Args:
            selected_node: Node chosen at the current level.
            query_bundle: Query being answered.
            prev_response: Answer accumulated from nodes selected earlier at
                this level, if any.
            level: Current depth in the tree (used only for logging).

        Returns:
            The (possibly refined) answer string.
        """
        query_str = query_bundle.query_str

        if len(selected_node.child_indices) == 0:
            # Leaf node. A fresh ResponseBuilder produces the answer, while
            # self.response_builder only accumulates the source nodes that are
            # reported in the final Response.
            response_builder = ResponseBuilder(
                self._prompt_helper,
                self._llm_predictor,
                self.text_qa_template,
                self.refine_template,
            )
            self.response_builder.add_node_as_source(selected_node)
            # use response builder to get answer from node
            node_text, sub_response = self._get_text_from_node(
                query_bundle, selected_node, level=level
            )
            if sub_response is not None:
                # these are source nodes from within this node (when it's an index)
                for source_node in sub_response.source_nodes:
                    self.response_builder.add_source_node(source_node)
            cur_response = response_builder.get_response_over_chunks(
                query_str, [node_text], prev_response=prev_response
            )
            cur_response = cast(str, cur_response)
            logging.debug(f">[Level {level}] Current answer response: {cur_response} ")
        else:
            # Internal node: descend one level into its children.
            cur_response = self._query_level(
                {
                    i: self.index_struct.all_nodes[i]
                    for i in selected_node.child_indices
                },
                query_bundle,
                level=level + 1,
            )

        if prev_response is None:
            return cur_response

        # NOTE(review): for a leaf node, prev_response was also passed to
        # get_response_over_chunks above, which presumably already refines
        # against it; the answer is refined against prev_response again here.
        # Behavior kept as-is -- confirm whether the second pass is intended.
        context_msg = "\n".join([selected_node.get_text(), cur_response])
        cur_response, formatted_refine_prompt = self._llm_predictor.predict(
            self.refine_template,
            query_str=query_str,
            existing_answer=prev_response,
            context_msg=context_msg,
        )
        logging.debug(f">[Level {level}] Refine prompt: {formatted_refine_prompt}")
        logging.debug(f">[Level {level}] Current refined response: {cur_response} ")
        return cur_response

    def _query_level(
        self,
        cur_nodes: Dict[int, Node],
        query_bundle: QueryBundle,
        level: int = 0,
    ) -> str:
        """Answer a query recursively.

        The LLM is asked to select up to ``child_branch_factor`` nodes from
        ``cur_nodes`` (by their 1-indexed position in the sorted node list).
        The query is then answered through each selected node in turn, with
        each answer refined against the previous one.

        Args:
            cur_nodes: Candidate nodes at this level, keyed by node index.
            query_bundle: Query being answered.
            level: Current depth in the tree (used only for logging).

        Returns:
            The answer string. If the LLM's selection contains no numbers, or
            no valid node has been answered before an out-of-range selection is
            encountered, the raw selection response from the LLM is returned.
        """
        query_str = query_bundle.query_str
        cur_node_list = get_sorted_node_list(cur_nodes)

        if len(cur_node_list) == 1:
            logging.debug(f">[Level {level}] Only one node left. Querying node.")
            return self._query_with_selected_node(
                cur_node_list[0], query_bundle, level=level
            )
        elif self.child_branch_factor == 1:
            query_template = self.query_template.partial_format(
                num_chunks=len(cur_node_list), query_str=query_str
            )
            numbered_node_text = self._prompt_helper.get_numbered_text_from_nodes(
                cur_node_list, prompt=query_template
            )
            response, formatted_query_prompt = self._llm_predictor.predict(
                query_template,
                context_list=numbered_node_text,
            )
        else:
            query_template_multiple = self.query_template_multiple.partial_format(
                num_chunks=len(cur_node_list),
                query_str=query_str,
                branching_factor=self.child_branch_factor,
            )
            numbered_node_text = self._prompt_helper.get_numbered_text_from_nodes(
                cur_node_list, prompt=query_template_multiple
            )
            response, formatted_query_prompt = self._llm_predictor.predict(
                query_template_multiple,
                context_list=numbered_node_text,
            )

        logging.debug(
            f">[Level {level}] current prompt template: {formatted_query_prompt}"
        )

        numbers = extract_numbers_given_response(response, n=self.child_branch_factor)
        if numbers is None:
            logging.debug(
                f">[Level {level}] Could not retrieve response - no numbers present"
            )
            # No node could be selected: fall back to the raw LLM response.
            return response

        result_response = None
        for number_str in numbers:
            number = int(number_str)
            # Selections are 1-indexed. Reject 0 as well as numbers past the
            # end: 0 would otherwise index cur_node_list[-1] and silently pick
            # the last node.
            if number < 1 or number > len(cur_node_list):
                logging.debug(
                    f">[Level {level}] Invalid response: {response} - "
                    f"number {number} out of range"
                )
                # Keep an answer already built from earlier valid selections
                # instead of discarding it in favor of the raw selection text.
                return response if result_response is None else result_response
            # number is 1-indexed, so subtract 1
            selected_node = cur_node_list[number - 1]

            logging.info(
                f">[Level {level}] Selected node: "
                f"[{number}]/[{','.join([str(int(n)) for n in numbers])}]"
            )
            debug_str = " ".join(selected_node.get_text().splitlines())
            logging.debug(
                f">[Level {level}] Node "
                f"[{number}] Summary text: "
                f"{ truncate_text(debug_str, 100) }"
            )
            result_response = self._query_with_selected_node(
                selected_node,
                query_bundle,
                prev_response=result_response,
                level=level,
            )

        if result_response is None:
            # The number list was empty, so no node was queried. Fall back to
            # the raw response rather than returning None from a `-> str`
            # method (which would make _query fail on .strip()).
            return response
        return result_response

    def _query(self, query_bundle: QueryBundle) -> Response:
        """Answer a query by traversing the tree from its root nodes.

        Returns:
            A Response holding the stripped answer text and the source nodes
            collected in ``self.response_builder`` during the traversal.
        """
        # NOTE: this overrides the _query method in the base class
        logging.info(f"> Starting query: {query_bundle.query_str}")
        response_str = self._query_level(
            self.index_struct.root_nodes,
            query_bundle,
            level=0,
        ).strip()
        return Response(response_str, source_nodes=self.response_builder.get_sources())