# rad_learning_companion/backend/rag/rag_context_engine.py
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from PIL import Image
from langchain.docstore.document import Document as LangchainDocument
from .knowledge_base import KnowledgeBase

logger = logging.getLogger(__name__)


def format_context_messages_to_string(context_messages: list[dict]) -> str:
"""Takes a list of context message dicts and formats them into a single string."""
if not context_messages:
return "No relevant context was retrieved from the guideline document."
full_text = [
msg.get("text", "") for msg in context_messages if msg.get("type") == "text"
]
return "\n".join(full_text)
class RAGContextEngine:
"""Uses a pre-built KnowledgeBase to retrieve and format context for queries."""

    def __init__(self, knowledge_base: KnowledgeBase, config_overrides: dict | None = None):
if not isinstance(knowledge_base, KnowledgeBase) or not knowledge_base.retriever:
raise ValueError("An initialized KnowledgeBase with a built retriever is required.")
self.kb = knowledge_base
self.config = self._get_default_config()
if config_overrides:
self.config.update(config_overrides)

    def _get_default_config(self) -> dict:
        return {
            "FINAL_CONTEXT_TOP_K": 5,  # max number of blocks kept as final context
            "CONTEXT_SELECTION_STRATEGY": "chapter_aware_window_expansion",
            "CONTEXT_WINDOW_SIZE": 0,  # neighboring pages pulled in around each hit
            "ADD_MAPPED_FIGURES_TO_PROMPT": False,  # attach linked figure images
        }
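
    # Example (hypothetical values): widen the page window and include figures.
    #
    #   engine = RAGContextEngine(
    #       knowledge_base=kb,  # a KnowledgeBase whose retriever is already built
    #       config_overrides={
    #           "CONTEXT_WINDOW_SIZE": 1,
    #           "ADD_MAPPED_FIGURES_TO_PROMPT": True,
    #       },
    #   )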

    def get_context_messages(self, query_text: str) -> list[dict] | None:
        """Public API to get final, formatted context messages for a query of any length."""
final_context_docs = self.retrieve_context_docs(query_text)
if not final_context_docs:
logger.warning(f"No relevant context found for query: {query_text}")
return None
context_messages, _ = self.build_context_messages(final_context_docs)
return context_messages

    def retrieve_context_docs(self, query_text: str) -> list:
"""Handles both short and long queries to retrieve context documents."""
logger.info(f"Retrieving context documents with query: {query_text}")
        # Heuristic: treat anything longer than five words as a "long" query that
        # benefits from being decomposed into focused sub-queries.
        if len(query_text.split()) > 5:
            logger.info("Long query detected. Decomposing into sub-queries...")
            # Reuse the knowledge base's enrich-and-chunk pipeline on the query
            # itself; each resulting chunk becomes a sub-query (deduplicated,
            # original order preserved).
            temp_doc = LangchainDocument(page_content=query_text)
            enriched_temp_docs = self.kb.document_enricher([temp_doc], summarize=False)
            query_chunks_as_docs = self.kb.chunker(enriched_docs=enriched_temp_docs, display_results=False)
            sub_queries = list(dict.fromkeys(doc.page_content for doc in query_chunks_as_docs))
else:
logger.info("Short query detected. Using direct retrieval.")
sub_queries = [query_text]
return self.retrieve_context_docs_for_simple_queries(sub_queries)
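
    # Routing sketch (hypothetical queries):
    #
    #   engine.retrieve_context_docs("pneumothorax")
    #   # <= 5 words: sent to the retriever as-is
    #
    #   engine.retrieve_context_docs(
    #       "What follow-up is recommended for an incidental 6 mm lung nodule?"
    #   )
    #   # > 5 words: enriched, chunked, and retrieved chunk-by-chunk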

    def get_context_messages_for_simple_queries(self, queries: list[str]) -> list:
"""Retrieves context docs and builds them into formatted messages."""
final_context_docs = self.retrieve_context_docs_for_simple_queries(queries)
if not final_context_docs:
logger.warning(f"No relevant context found for queries: {queries}")
return []
context_messages, _ = self.build_context_messages(final_context_docs)
return context_messages

    def retrieve_context_docs_for_simple_queries(self, queries: list[str]) -> list:
"""Invokes the retriever for a list of simple queries and selects the final documents."""
logger.info(f"Retrieving context documents with simple queries: {queries}")
retrieved_docs = []
for query in queries:
docs = self.kb.retriever.invoke(query)
retrieved_docs.extend(docs)
return RAGContextEngine.select_final_context(
retrieved_docs=retrieved_docs,
config=self.config,
page_map=self.kb.page_map,
)
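
    # Example (hypothetical): each query hits the retriever independently; the
    # pooled results are then ranked and pruned by select_final_context.
    #
    #   engine.retrieve_context_docs_for_simple_queries(
    #       ["contrast premedication", "gadolinium reaction management"]
    #   )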

    def build_context_messages(
        self, docs: list[LangchainDocument]
    ) -> tuple[list[dict], list[Image.Image]]:
"""Builds a structured list of messages by grouping consecutive text blocks."""
if not docs:
return [], []
context_messages = []
images_found = []
prose_buffer = []

        def flush_prose_buffer():
            # Emit any buffered prose as a single text message.
            if prose_buffer:
                full_prose = "\n\n".join(prose_buffer)
                context_messages.append({"type": "text", "text": full_prose})
                prose_buffer.clear()

        add_images = self.config.get("ADD_MAPPED_FIGURES_TO_PROMPT", False)
        for i, doc in enumerate(docs):
            current_page = doc.metadata.get("page_number")
            # Start a new message whenever the page changes, or when a figure
            # caption (with images enabled) must interrupt the prose stream.
            is_new_page = (i > 0) and (current_page != docs[i - 1].metadata.get("page_number"))
is_caption = doc.metadata.get("chunk_type") == "figure-caption"
if is_new_page or (add_images and is_caption):
flush_prose_buffer()
if add_images and is_caption:
source_info = f"--- Source: Page {current_page} ---"
caption_text = f"{source_info}\n{doc.page_content}"
context_messages.append({"type": "text", "text": caption_text})
image_path = doc.metadata.get("linked_figure_path")
if image_path and os.path.exists(image_path):
try:
image = Image.open(image_path).convert("RGB")
context_messages.append({"type": "image", "image": image})
images_found.append(image)
except Exception as e:
logger.warning(f"Could not load image {image_path}: {e}")
else:
if not prose_buffer:
source_info = f"--- Source: Page {current_page} ---"
prose_buffer.append(f"\n{source_info}\n")
prose_buffer.append(doc.page_content)
flush_prose_buffer()
return context_messages, images_found
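
    # Shape of the result (illustrative): with ADD_MAPPED_FIGURES_TO_PROMPT
    # enabled, a caption and its image interrupt the surrounding prose:
    #
    #   ([{"type": "text", "text": "--- Source: Page 12 ---\n..."},
    #     {"type": "text", "text": "--- Source: Page 12 ---\nFigure 3: ..."},
    #     {"type": "image", "image": some_pil_image}],
    #    [some_pil_image])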

    @staticmethod
def select_final_context(retrieved_docs: list, config: dict, page_map: dict) -> list:
"""Selects final context from retrieved documents using the specified strategy."""
strategy = config.get("CONTEXT_SELECTION_STRATEGY")
top_k = config.get("FINAL_CONTEXT_TOP_K", 5)

        def _calculate_block_frequencies(docs_list: list) -> list:
            # Group retrieved chunks by source block and rank blocks by how many
            # chunks they contributed (more hits = stronger relevance signal).
            blocks: dict = {}
            for doc in docs_list:
                if block_id := doc.metadata.get("block_id"):
                    blocks.setdefault(block_id, []).append(doc)
            return sorted(blocks.items(), key=lambda item: len(item[1]), reverse=True)
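
        # Worked example (hypothetical): if retrieval returned four chunks with
        # block_ids ["b2", "b7", "b2", "b2"], the scored list is
        # [("b2", [<3 chunks>]), ("b7", [<1 chunk>])], so block "b2" ranks first.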

        def _expand_chunks_to_blocks(chunks: list) -> list:
            # Swap each chunk for the full text of its parent block, if available.
            return [
                LangchainDocument(
                    page_content=c.metadata.get("original_block_text", c.page_content),
                    metadata=c.metadata,
                )
                for c in chunks
            ]

        final_context = []
if strategy == "chapter_aware_window_expansion":
if not retrieved_docs or not page_map:
return []
scored_blocks = _calculate_block_frequencies(retrieved_docs)
if not scored_blocks:
return _expand_chunks_to_blocks(retrieved_docs[:top_k])
primary_hit_page = scored_blocks[0][1][0].metadata.get("page_number")
important_pages = {
c[0].metadata.get("page_number")
for _, c in scored_blocks[:top_k]
if c and c[0].metadata.get("page_number")
}
window_size = config.get("CONTEXT_WINDOW_SIZE", 0)
pages_to_extract = set()
for page_num in important_pages:
current_chapter_info = page_map.get(page_num)
if not current_chapter_info:
continue
current_chapter_id = current_chapter_info["chapter_id"]
pages_to_extract.add(page_num)
for i in range(1, window_size + 1):
if (prev_info := page_map.get(page_num - i)) and prev_info["chapter_id"] == current_chapter_id:
pages_to_extract.add(page_num - i)
if (next_info := page_map.get(page_num + i)) and next_info["chapter_id"] == current_chapter_id:
pages_to_extract.add(page_num + i)
            sorted_pages = sorted(pages_to_extract)
            # Emit the strongest hit's page first, then the remaining pages in order.
            if primary_hit_page and primary_hit_page in page_map:
                final_context.extend(page_map[primary_hit_page]["blocks"])
for page_num in sorted_pages:
if page_num != primary_hit_page and page_num in page_map:
final_context.extend(page_map[page_num]["blocks"])
elif strategy == "rerank_by_frequency":
scored_blocks = _calculate_block_frequencies(retrieved_docs)
representative_chunks = [chunks[0] for _, chunks in scored_blocks[:top_k]]
final_context = _expand_chunks_to_blocks(representative_chunks)
elif strategy == "select_by_rank":
unique_docs_map = {f"{doc.metadata.get('block_id', '')}_{doc.page_content}": doc for doc in retrieved_docs}
representative_chunks = list(unique_docs_map.values())[:top_k]
final_context = _expand_chunks_to_blocks(representative_chunks)
else:
logger.warning(f"Unknown strategy '{strategy}'. Defaulting to top-k raw chunks.")
final_context = retrieved_docs[:top_k]
logger.info(f"Selected {len(final_context)} final context blocks using '{strategy}' strategy.")
return final_context
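

# Hypothetical end-to-end usage (sketch only; how a KnowledgeBase is built
# depends on knowledge_base.py, which is not shown here):
#
#   kb = ...  # an initialized KnowledgeBase exposing .retriever and .page_map
#   engine = RAGContextEngine(kb, config_overrides={"FINAL_CONTEXT_TOP_K": 3})
#   messages = engine.get_context_messages("What is the ALARA principle?")
#   if messages:
#       print(format_context_messages_to_string(messages))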