# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import List

from PIL import Image
from langchain.docstore.document import Document as LangchainDocument

from .knowledge_base import KnowledgeBase

logger = logging.getLogger(__name__)


def format_context_messages_to_string(context_messages: list[dict]) -> str:
    """Takes a list of context message dicts and formats them into a single string."""
    if not context_messages:
        return "No relevant context was retrieved from the guideline document."
    full_text = [
        msg.get("text", "") for msg in context_messages if msg.get("type") == "text"
    ]
    return "\n".join(full_text)
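
# Example (illustrative values):
#   format_context_messages_to_string([
#       {"type": "text", "text": "First passage."},
#       {"type": "image", "image": some_pil_image},  # non-text entries are skipped
#       {"type": "text", "text": "Second passage."},
#   ])
#   returns "First passage.\nSecond passage."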


class RAGContextEngine:
    """Uses a pre-built KnowledgeBase to retrieve and format context for queries."""

    def __init__(self, knowledge_base: KnowledgeBase, config_overrides: dict | None = None):
        if not isinstance(knowledge_base, KnowledgeBase) or not knowledge_base.retriever:
            raise ValueError("An initialized KnowledgeBase with a built retriever is required.")
        self.kb = knowledge_base
        self.config = self._get_default_config()
        if config_overrides:
            self.config.update(config_overrides)

    def _get_default_config(self):
        return {
            "FINAL_CONTEXT_TOP_K": 5,
            "CONTEXT_SELECTION_STRATEGY": "chapter_aware_window_expansion",
            "CONTEXT_WINDOW_SIZE": 0,
            "ADD_MAPPED_FIGURES_TO_PROMPT": False,
        }
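
    # Config keys (defaults above; override per instance via config_overrides):
    #   FINAL_CONTEXT_TOP_K          - number of top-scoring blocks/chunks to keep.
    #   CONTEXT_SELECTION_STRATEGY   - "chapter_aware_window_expansion",
    #                                  "rerank_by_frequency", or "select_by_rank"
    #                                  (see select_final_context below).
    #   CONTEXT_WINDOW_SIZE          - how many neighboring pages (same chapter)
    #                                  to pull in around each hit page.
    #   ADD_MAPPED_FIGURES_TO_PROMPT - if True, interleave linked figure images
    #                                  with their captions in the built messages.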

    def get_context_messages(self, query_text: str) -> list[dict] | None:
        """Public API to get final, formatted context messages for a query of any length."""
        final_context_docs = self.retrieve_context_docs(query_text)
        if not final_context_docs:
            logger.warning(f"No relevant context found for query: {query_text}")
            return None
        context_messages, _ = self.build_context_messages(final_context_docs)
        return context_messages

    def retrieve_context_docs(self, query_text: str) -> list:
        """Handles both short and long queries to retrieve context documents."""
        logger.info(f"Retrieving context documents with query: {query_text}")
        if len(query_text.split()) > 5:
            logger.info("Long query detected. Decomposing into sub-queries...")
            temp_doc = LangchainDocument(page_content=query_text)
            enriched_temp_docs = self.kb.document_enricher([temp_doc], summarize=False)
            query_chunks_as_docs = self.kb.chunker(enriched_docs=enriched_temp_docs, display_results=False)
            sub_queries = list({doc.page_content for doc in query_chunks_as_docs})
        else:
            logger.info("Short query detected. Using direct retrieval.")
            sub_queries = [query_text]
        return self.retrieve_context_docs_for_simple_queries(sub_queries)
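
    # The five-word threshold above is a simple heuristic: longer queries are run
    # through the knowledge base's own enrich-and-chunk pipeline so that each
    # resulting chunk can be retrieved independently as a sub-query.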

    def get_context_messages_for_simple_queries(self, queries: list[str]) -> list:
        """Retrieves context docs and builds them into formatted messages."""
        final_context_docs = self.retrieve_context_docs_for_simple_queries(queries)
        if not final_context_docs:
            logger.warning(f"No relevant context found for queries: {queries}")
            return []
        context_messages, _ = self.build_context_messages(final_context_docs)
        return context_messages

    def retrieve_context_docs_for_simple_queries(self, queries: list[str]) -> list:
        """Invokes the retriever for a list of simple queries and selects the final documents."""
        logger.info(f"Retrieving context documents with simple queries: {queries}")
        retrieved_docs = []
        for query in queries:
            docs = self.kb.retriever.invoke(query)
            retrieved_docs.extend(docs)
        return RAGContextEngine.select_final_context(
            retrieved_docs=retrieved_docs,
            config=self.config,
            page_map=self.kb.page_map,
        )
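
    # Results from the individual queries are concatenated without deduplication:
    # a block hit by several sub-queries appears several times, which the
    # frequency-based strategies in select_final_context treat as a relevance signal.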

    def build_context_messages(
        self, docs: List[LangchainDocument]
    ) -> tuple[list[dict], list[Image.Image]]:
        """Builds a structured list of messages by grouping consecutive text blocks."""
        if not docs:
            return [], []
        context_messages = []
        images_found = []
        prose_buffer = []

        def flush_prose_buffer():
            if prose_buffer:
                full_prose = "\n\n".join(prose_buffer)
                context_messages.append({"type": "text", "text": full_prose})
                prose_buffer.clear()

        add_images = self.config.get("ADD_MAPPED_FIGURES_TO_PROMPT", False)
        for i, doc in enumerate(docs):
            current_page = doc.metadata.get("page_number")
            is_new_page = (i > 0) and (current_page != docs[i - 1].metadata.get("page_number"))
            is_caption = doc.metadata.get("chunk_type") == "figure-caption"
            if is_new_page or (add_images and is_caption):
                flush_prose_buffer()
            if add_images and is_caption:
                source_info = f"--- Source: Page {current_page} ---"
                caption_text = f"{source_info}\n{doc.page_content}"
                context_messages.append({"type": "text", "text": caption_text})
                image_path = doc.metadata.get("linked_figure_path")
                if image_path and os.path.exists(image_path):
                    try:
                        image = Image.open(image_path).convert("RGB")
                        context_messages.append({"type": "image", "image": image})
                        images_found.append(image)
                    except Exception as e:
                        logger.warning(f"Could not load image {image_path}: {e}")
            else:
                if not prose_buffer:
                    source_info = f"--- Source: Page {current_page} ---"
                    prose_buffer.append(f"\n{source_info}\n")
                prose_buffer.append(doc.page_content)
        flush_prose_buffer()
        return context_messages, images_found
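
    # Illustrative return shape (contents hypothetical):
    #   context_messages = [
    #       {"type": "text", "text": "--- Source: Page 12 --- ..."},           # merged prose
    #       {"type": "text", "text": "--- Source: Page 13 --- Figure 3 ..."},  # caption
    #       {"type": "image", "image": <PIL.Image.Image>},
    #   ]
    #   images_found = [<PIL.Image.Image>]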

    @staticmethod
    def select_final_context(retrieved_docs: list, config: dict, page_map: dict) -> list:
        """Selects final context from retrieved documents using the specified strategy."""
        strategy = config.get("CONTEXT_SELECTION_STRATEGY")
        top_k = config.get("FINAL_CONTEXT_TOP_K", 5)

        def _calculate_block_frequencies(docs_list: list) -> list:
            # Group retrieved chunks by source block and rank blocks by hit count.
            blocks = {}
            for doc in docs_list:
                if block_id := doc.metadata.get("block_id"):
                    if block_id not in blocks:
                        blocks[block_id] = []
                    blocks[block_id].append(doc)
            return sorted(blocks.items(), key=lambda item: len(item[1]), reverse=True)

        def _expand_chunks_to_blocks(chunks: list) -> list:
            # Swap each chunk for the full text of its parent block, when available.
            return [
                LangchainDocument(
                    page_content=c.metadata.get("original_block_text", c.page_content),
                    metadata=c.metadata,
                )
                for c in chunks
            ]

        final_context = []
        if strategy == "chapter_aware_window_expansion":
            # Expand the most frequently hit pages into full pages of blocks,
            # optionally widening the window without crossing chapter boundaries.
            if not retrieved_docs or not page_map:
                return []
            scored_blocks = _calculate_block_frequencies(retrieved_docs)
            if not scored_blocks:
                return _expand_chunks_to_blocks(retrieved_docs[:top_k])
            primary_hit_page = scored_blocks[0][1][0].metadata.get("page_number")
            important_pages = {
                c[0].metadata.get("page_number")
                for _, c in scored_blocks[:top_k]
                if c and c[0].metadata.get("page_number")
            }
            window_size = config.get("CONTEXT_WINDOW_SIZE", 0)
            pages_to_extract = set()
            for page_num in important_pages:
                current_chapter_info = page_map.get(page_num)
                if not current_chapter_info:
                    continue
                current_chapter_id = current_chapter_info["chapter_id"]
                pages_to_extract.add(page_num)
                for i in range(1, window_size + 1):
                    if (prev_info := page_map.get(page_num - i)) and prev_info["chapter_id"] == current_chapter_id:
                        pages_to_extract.add(page_num - i)
                    if (next_info := page_map.get(page_num + i)) and next_info["chapter_id"] == current_chapter_id:
                        pages_to_extract.add(page_num + i)
            sorted_pages = sorted(pages_to_extract)
            # The page with the strongest hit goes first; the rest follow in page order.
            if primary_hit_page and primary_hit_page in page_map:
                final_context.extend(page_map[primary_hit_page]["blocks"])
            for page_num in sorted_pages:
                if page_num != primary_hit_page and page_num in page_map:
                    final_context.extend(page_map[page_num]["blocks"])
        elif strategy == "rerank_by_frequency":
            # Keep one representative chunk per block, ranked by hit frequency,
            # then expand each representative back to its full block.
            scored_blocks = _calculate_block_frequencies(retrieved_docs)
            representative_chunks = [chunks[0] for _, chunks in scored_blocks[:top_k]]
            final_context = _expand_chunks_to_blocks(representative_chunks)
        elif strategy == "select_by_rank":
            # Deduplicate chunks while preserving retrieval order, then take the top_k.
            unique_docs_map = {f"{doc.metadata.get('block_id', '')}_{doc.page_content}": doc for doc in retrieved_docs}
            representative_chunks = list(unique_docs_map.values())[:top_k]
            final_context = _expand_chunks_to_blocks(representative_chunks)
        else:
            logger.warning(f"Unknown strategy '{strategy}'. Defaulting to top-k raw chunks.")
            final_context = retrieved_docs[:top_k]
        logger.info(f"Selected {len(final_context)} final context blocks using '{strategy}' strategy.")
        return final_context
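

# Minimal usage sketch (illustrative; assumes a KnowledgeBase was built elsewhere
# with a retriever, page_map, document_enricher, and chunker already in place):
#
#   kb = build_knowledge_base(...)  # hypothetical helper, not defined here
#   engine = RAGContextEngine(kb, config_overrides={"CONTEXT_WINDOW_SIZE": 1})
#   messages = engine.get_context_messages("What do the guidelines require for X?")
#   prompt_text = format_context_messages_to_string(messages or [])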