"""Leaf query mechanism.""" import logging from typing import Any, Dict, Optional, cast from gpt_index.data_structs.data_structs import IndexGraph, Node from gpt_index.indices.query.base import BaseGPTIndexQuery from gpt_index.indices.query.schema import QueryBundle from gpt_index.indices.response.builder import ResponseBuilder from gpt_index.indices.utils import ( extract_numbers_given_response, get_sorted_node_list, truncate_text, ) from gpt_index.prompts.default_prompts import ( DEFAULT_QUERY_PROMPT, DEFAULT_QUERY_PROMPT_MULTIPLE, ) from gpt_index.prompts.prompts import TreeSelectMultiplePrompt, TreeSelectPrompt from gpt_index.response.schema import Response class GPTTreeIndexLeafQuery(BaseGPTIndexQuery[IndexGraph]): """GPT Tree Index leaf query. This class traverses the index graph and searches for a leaf node that can best answer the query. .. code-block:: python response = index.query("", mode="default") Args: query_template (Optional[TreeSelectPrompt]): Tree Select Query Prompt (see :ref:`Prompt-Templates`). query_template_multiple (Optional[TreeSelectMultiplePrompt]): Tree Select Query Prompt (Multiple) (see :ref:`Prompt-Templates`). child_branch_factor (int): Number of child nodes to consider at each level. If child_branch_factor is 1, then the query will only choose one child node to traverse for any given parent node. If child_branch_factor is 2, then the query will choose two child nodes. """ def __init__( self, index_struct: IndexGraph, query_template: Optional[TreeSelectPrompt] = None, query_template_multiple: Optional[TreeSelectMultiplePrompt] = None, child_branch_factor: int = 1, **kwargs: Any, ) -> None: """Initialize params.""" super().__init__(index_struct, **kwargs) self.query_template = query_template or DEFAULT_QUERY_PROMPT self.query_template_multiple = ( query_template_multiple or DEFAULT_QUERY_PROMPT_MULTIPLE ) self.child_branch_factor = child_branch_factor def _query_with_selected_node( self, selected_node: Node, query_bundle: QueryBundle, prev_response: Optional[str] = None, level: int = 0, ) -> str: """Get response for selected node. If not leaf node, it will recursively call _query on the child nodes. If prev_response is provided, we will update prev_response with the answer. """ query_str = query_bundle.query_str if len(selected_node.child_indices) == 0: response_builder = ResponseBuilder( self._prompt_helper, self._llm_predictor, self.text_qa_template, self.refine_template, ) self.response_builder.add_node_as_source(selected_node) # use response builder to get answer from node node_text, sub_response = self._get_text_from_node( query_bundle, selected_node, level=level ) if sub_response is not None: # these are source nodes from within this node (when it's an index) for source_node in sub_response.source_nodes: self.response_builder.add_source_node(source_node) cur_response = response_builder.get_response_over_chunks( query_str, [node_text], prev_response=prev_response ) cur_response = cast(str, cur_response) logging.debug(f">[Level {level}] Current answer response: {cur_response} ") else: cur_response = self._query_level( { i: self.index_struct.all_nodes[i] for i in selected_node.child_indices }, query_bundle, level=level + 1, ) if prev_response is None: return cur_response else: context_msg = "\n".join([selected_node.get_text(), cur_response]) cur_response, formatted_refine_prompt = self._llm_predictor.predict( self.refine_template, query_str=query_str, existing_answer=prev_response, context_msg=context_msg, ) logging.debug(f">[Level {level}] Refine prompt: {formatted_refine_prompt}") logging.debug(f">[Level {level}] Current refined response: {cur_response} ") return cur_response def _query_level( self, cur_nodes: Dict[int, Node], query_bundle: QueryBundle, level: int = 0, ) -> str: """Answer a query recursively.""" query_str = query_bundle.query_str cur_node_list = get_sorted_node_list(cur_nodes) if len(cur_node_list) == 1: logging.debug(f">[Level {level}] Only one node left. Querying node.") return self._query_with_selected_node( cur_node_list[0], query_bundle, level=level ) elif self.child_branch_factor == 1: query_template = self.query_template.partial_format( num_chunks=len(cur_node_list), query_str=query_str ) numbered_node_text = self._prompt_helper.get_numbered_text_from_nodes( cur_node_list, prompt=query_template ) response, formatted_query_prompt = self._llm_predictor.predict( query_template, context_list=numbered_node_text, ) else: query_template_multiple = self.query_template_multiple.partial_format( num_chunks=len(cur_node_list), query_str=query_str, branching_factor=self.child_branch_factor, ) numbered_node_text = self._prompt_helper.get_numbered_text_from_nodes( cur_node_list, prompt=query_template_multiple ) response, formatted_query_prompt = self._llm_predictor.predict( query_template_multiple, context_list=numbered_node_text, ) logging.debug( f">[Level {level}] current prompt template: {formatted_query_prompt}" ) numbers = extract_numbers_given_response(response, n=self.child_branch_factor) if numbers is None: logging.debug( f">[Level {level}] Could not retrieve response - no numbers present" ) # just join text from current nodes as response return response result_response = None for number_str in numbers: number = int(number_str) if number > len(cur_node_list): logging.debug( f">[Level {level}] Invalid response: {response} - " f"number {number} out of range" ) return response # number is 1-indexed, so subtract 1 selected_node = cur_node_list[number - 1] logging.info( f">[Level {level}] Selected node: " f"[{number}]/[{','.join([str(int(n)) for n in numbers])}]" ) debug_str = " ".join(selected_node.get_text().splitlines()) logging.debug( f">[Level {level}] Node " f"[{number}] Summary text: " f"{ truncate_text(debug_str, 100) }" ) result_response = self._query_with_selected_node( selected_node, query_bundle, prev_response=result_response, level=level, ) # result_response should not be None return cast(str, result_response) def _query(self, query_bundle: QueryBundle) -> Response: """Answer a query.""" # NOTE: this overrides the _query method in the base class logging.info(f"> Starting query: {query_bundle.query_str}") response_str = self._query_level( self.index_struct.root_nodes, query_bundle, level=0, ).strip() return Response(response_str, source_nodes=self.response_builder.get_sources())