Spaces:
Runtime error
Runtime error
File size: 4,988 Bytes
8a58cf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
"""Query Tree using embedding similarity between query and node text."""
import logging
from typing import Any, Dict, List, Optional, Tuple, cast
from gpt_index.data_structs.data_structs import IndexGraph, Node
from gpt_index.embeddings.base import BaseEmbedding
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.indices.query.tree.leaf_query import GPTTreeIndexLeafQuery
from gpt_index.indices.utils import get_sorted_node_list
from gpt_index.prompts.prompts import TreeSelectMultiplePrompt, TreeSelectPrompt
class GPTTreeIndexEmbeddingQuery(GPTTreeIndexLeafQuery):
    """GPT Tree Index embedding query.

    This class traverses the index graph using the embedding similarity
    between the query and the node text.

    .. code-block:: python

        response = index.query("<query_str>", mode="embedding")

    Args:
        query_template (Optional[TreeSelectPrompt]): Tree Select Query Prompt
            (see :ref:`Prompt-Templates`).
        query_template_multiple (Optional[TreeSelectMultiplePrompt]): Tree
            Select Query Prompt (Multiple)
            (see :ref:`Prompt-Templates`).
        text_qa_template (Optional[QuestionAnswerPrompt]): Question-Answer
            Prompt (see :ref:`Prompt-Templates`).
        refine_template (Optional[RefinePrompt]): Refinement Prompt
            (see :ref:`Prompt-Templates`).
        child_branch_factor (int): Number of child nodes to consider at each
            level. If child_branch_factor is 1, then the query will only
            choose one child node to traverse for any given parent node.
            If child_branch_factor is 2, then the query will choose two
            child nodes.
        embed_model (Optional[BaseEmbedding]): Embedding model to use for
            embedding similarity.

    """

    def __init__(
        self,
        index_struct: IndexGraph,
        query_template: Optional[TreeSelectPrompt] = None,
        query_template_multiple: Optional[TreeSelectMultiplePrompt] = None,
        child_branch_factor: int = 1,
        embed_model: Optional[BaseEmbedding] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params.

        All arguments are forwarded to the parent query class; any extra
        keyword arguments are passed through unchanged via ``**kwargs``.
        """
        super().__init__(
            index_struct,
            query_template=query_template,
            query_template_multiple=query_template_multiple,
            child_branch_factor=child_branch_factor,
            embed_model=embed_model,
            **kwargs,
        )
        # Stored locally as well because _get_most_similar_nodes reads it
        # to decide how many children to keep at each level.
        self.child_branch_factor = child_branch_factor

    def _query_level(
        self,
        cur_nodes: Dict[int, Node],
        query_bundle: QueryBundle,
        level: int = 0,
    ) -> str:
        """Answer the query at one tree level using the most similar nodes.

        Args:
            cur_nodes: Candidate nodes at the current level.
            query_bundle: The query whose embedding is compared to the nodes.
            level: Depth of the current level, used for logging and passed
                on to ``_query_with_selected_node``.

        Returns:
            The response produced after visiting the selected nodes in
            descending similarity order, each visit receiving the previous
            response as ``prev_response``. NOTE: if no node is selected
            (e.g. ``cur_nodes`` is empty) the value returned is ``None``
            despite the ``str`` annotation.
        """
        cur_node_list = get_sorted_node_list(cur_nodes)

        # Pick the node(s) with the highest similarity to the query.
        selected_nodes, selected_indices = self._get_most_similar_nodes(
            cur_node_list, query_bundle
        )

        result_response = None
        for node, index in zip(selected_nodes, selected_indices):
            logging.debug(
                f">[Level {level}] Node [{index+1}] Summary text: "
                f"{' '.join(node.get_text().splitlines())}"
            )

            # Get the response for the selected node, threading the
            # response accumulated so far into the next call.
            result_response = self._query_with_selected_node(
                node, query_bundle, level=level, prev_response=result_response
            )

        return cast(str, result_response)

    def _get_query_text_embedding_similarities(
        self, query_bundle: QueryBundle, nodes: List[Node]
    ) -> List[float]:
        """Get the similarity between the query and each node's text.

        The query embedding is computed once per call from
        ``query_bundle.embedding_strs``. A node's text embedding is computed
        only when ``node.embedding`` is ``None`` and is then stored back on
        the node so later calls can reuse it.

        Args:
            query_bundle: The query to embed.
            nodes: Nodes to score; may be mutated (``embedding`` filled in).

        Returns:
            One similarity score per node, in the same order as ``nodes``.
        """
        # NOTE(review): self._embed_model is not assigned in this class;
        # presumably the parent sets it from ``embed_model`` - confirm.
        query_embedding = self._embed_model.get_agg_embedding_from_queries(
            query_bundle.embedding_strs
        )

        similarities = []
        for node in nodes:
            if node.embedding is not None:
                text_embedding = node.embedding
            else:
                text_embedding = self._embed_model.get_text_embedding(
                    node.get_text()
                )
                # Cache on the node to avoid re-embedding the same text.
                node.embedding = text_embedding

            similarity = self._embed_model.similarity(
                query_embedding, text_embedding
            )
            similarities.append(similarity)

        return similarities

    def _get_most_similar_nodes(
        self, nodes: List[Node], query_bundle: QueryBundle
    ) -> Tuple[List[Node], List[int]]:
        """Get the node(s) with the highest similarity to the query.

        Args:
            nodes: Candidate nodes.
            query_bundle: The query to compare against.

        Returns:
            A pair ``(selected_nodes, selected_indices)`` holding at most
            ``self.child_branch_factor`` nodes in descending similarity
            order, together with each node's position in ``nodes``. Nodes
            with equal similarity keep their original relative order.
        """
        similarities = self._get_query_text_embedding_similarities(query_bundle, nodes)

        # A non-positive branch factor selects nothing; clamp so the slice
        # below cannot be interpreted as a negative (from-the-end) bound.
        top_k = max(self.child_branch_factor, 0)

        # Rank positions rather than nodes: looking the position up later
        # with ``nodes.index(node)`` is an equality-based linear scan that
        # can return the wrong index when two nodes compare equal.
        ranked_indices = sorted(
            range(len(nodes)),
            key=lambda position: similarities[position],
            reverse=True,
        )[:top_k]

        selected_nodes: List[Node] = [nodes[position] for position in ranked_indices]
        selected_indices: List[int] = list(ranked_indices)
        return selected_nodes, selected_indices
|