# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

from PIL import Image
from langchain.docstore.document import Document as LangchainDocument

from .knowledge_base import KnowledgeBase

logger = logging.getLogger(__name__)


def format_context_messages_to_string(context_messages: list[dict]) -> str:
    """Takes a list of context message dicts and formats them into a single string."""
    if not context_messages:
        return "No relevant context was retrieved from the guideline document."

    full_text = [
        msg.get("text", "") for msg in context_messages if msg.get("type") == "text"
    ]
    return "\n".join(full_text)


class RAGContextEngine:
    """Uses a pre-built KnowledgeBase to retrieve and format context for queries."""

    def __init__(self, knowledge_base: KnowledgeBase, config_overrides: dict | None = None):
        if not isinstance(knowledge_base, KnowledgeBase) or not knowledge_base.retriever:
            raise ValueError("An initialized KnowledgeBase with a built retriever is required.")
        self.kb = knowledge_base
        self.config = self._get_default_config()
        if config_overrides:
            self.config.update(config_overrides)

    def _get_default_config(self) -> dict:
        return {
            # Maximum number of blocks kept after context selection.
            "FINAL_CONTEXT_TOP_K": 5,
            # One of "chapter_aware_window_expansion", "rerank_by_frequency",
            # or "select_by_rank"; see select_final_context below.
            "CONTEXT_SELECTION_STRATEGY": "chapter_aware_window_expansion",
            # Neighbouring pages to pull in on each side of a hit, restricted
            # to the same chapter.
            "CONTEXT_WINDOW_SIZE": 0,
            # When True, figure captions are emitted alongside their linked
            # images in the context messages.
            "ADD_MAPPED_FIGURES_TO_PROMPT": False,
        }

    def get_context_messages(self, query_text: str) -> list[dict] | None:
        """Public API to get final, formatted context messages for a long query."""
        final_context_docs = self.retrieve_context_docs(query_text)
        if not final_context_docs:
            logger.warning(f"No relevant context found for query: {query_text}")
            return None
        context_messages, _ = self.build_context_messages(final_context_docs)
        return context_messages

    def retrieve_context_docs(self, query_text: str) -> list:
        """Handles both short and long queries to retrieve context documents."""
        logger.info(f"Retrieving context documents with query: {query_text}")
        # Heuristic: queries longer than five words are treated as long-form
        # and decomposed into sub-queries via the knowledge base's own
        # enrich + chunk pipeline.
        if len(query_text.split()) > 5:
            logger.info("Long query detected. Decomposing into sub-queries...")
            temp_doc = LangchainDocument(page_content=query_text)
            enriched_temp_docs = self.kb.document_enricher([temp_doc], summarize=False)
            query_chunks_as_docs = self.kb.chunker(enriched_docs=enriched_temp_docs, display_results=False)
            # Deduplicate chunk texts; set() ordering is not significant here.
            sub_queries = list(set(doc.page_content for doc in query_chunks_as_docs))
        else:
            logger.info("Short query detected. Using direct retrieval.")
            sub_queries = [query_text]
        return self.retrieve_context_docs_for_simple_queries(sub_queries)

    def get_context_messages_for_simple_queries(self, queries: list[str]) -> list:
        """Retrieves context docs and builds them into formatted messages."""
        final_context_docs = self.retrieve_context_docs_for_simple_queries(queries)
        if not final_context_docs:
            logger.warning(f"No relevant context found for queries: {queries}")
            return []
        context_messages, _ = self.build_context_messages(final_context_docs)
        return context_messages

    def retrieve_context_docs_for_simple_queries(self, queries: list[str]) -> list:
        """Invokes the retriever for a list of simple queries and selects the final documents."""
        logger.info(f"Retrieving context documents with simple queries: {queries}")
        retrieved_docs = []
        for query in queries:
            docs = self.kb.retriever.invoke(query)
            retrieved_docs.extend(docs)

        return RAGContextEngine.select_final_context(
            retrieved_docs=retrieved_docs,
            config=self.config,
            page_map=self.kb.page_map,
        )

    def build_context_messages(
        self, docs: list[LangchainDocument]
    ) -> tuple[list[dict], list[Image.Image]]:
        """Builds a structured list of messages by grouping consecutive text blocks."""
        if not docs:
            return [], []

        context_messages = []
        images_found = []
        prose_buffer = []

        def flush_prose_buffer():
            if prose_buffer:
                full_prose = "\n\n".join(prose_buffer)
                context_messages.append({"type": "text", "text": full_prose})
                prose_buffer.clear()

        add_images = self.config.get("ADD_MAPPED_FIGURES_TO_PROMPT", False)
        for i, doc in enumerate(docs):
            current_page = doc.metadata.get("page_number")
            is_new_page = (i > 0) and (current_page != docs[i - 1].metadata.get("page_number"))
            is_caption = doc.metadata.get("chunk_type") == "figure-caption"

            if is_new_page or (add_images and is_caption):
                flush_prose_buffer()

            if add_images and is_caption:
                source_info = f"--- Source: Page {current_page} ---"
                caption_text = f"{source_info}\n{doc.page_content}"
                context_messages.append({"type": "text", "text": caption_text})
                image_path = doc.metadata.get("linked_figure_path")
                if image_path and os.path.exists(image_path):
                    try:
                        image = Image.open(image_path).convert("RGB")
                        context_messages.append({"type": "image", "image": image})
                        images_found.append(image)
                    except Exception as e:
                        logger.warning(f"Could not load image {image_path}: {e}")
            else:
                if not prose_buffer:
                    source_info = f"--- Source: Page {current_page} ---"
                    prose_buffer.append(f"\n{source_info}\n")
                prose_buffer.append(doc.page_content)

        flush_prose_buffer()
        return context_messages, images_found
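
    # Shape of the returned messages (illustrative example, not real data):
    #   [{"type": "text", "text": "--- Source: Page 3 ---\n<page 3 prose>"},
    #    {"type": "image", "image": <PIL.Image.Image>},
    #    {"type": "text", "text": "--- Source: Page 4 ---\n<page 4 prose>"}]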

    @staticmethod
    def select_final_context(retrieved_docs: list, config: dict, page_map: dict) -> list:
        """Selects final context from retrieved documents using the specified strategy."""
        strategy = config.get("CONTEXT_SELECTION_STRATEGY")
        top_k = config.get("FINAL_CONTEXT_TOP_K", 5)

        def _calculate_block_frequencies(docs_list: list) -> list:
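            # Group retrieved chunks by their parent block and sort blocks by
            # how many chunks hit them (most-hit first).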
            blocks = {}
            for doc in docs_list:
                if block_id := doc.metadata.get("block_id"):
                    if block_id not in blocks:
                        blocks[block_id] = []
                    blocks[block_id].append(doc)
            return sorted(blocks.items(), key=lambda item: len(item[1]), reverse=True)

        def _expand_chunks_to_blocks(chunks: list) -> list:
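            # Swap each chunk's text for the full block it came from, when the
            # enricher stored it in metadata; fall back to the chunk text.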
            return [
                LangchainDocument(
                    page_content=c.metadata.get("original_block_text", c.page_content),
                    metadata=c.metadata,
                )
                for c in chunks
            ]

        final_context = []
        if strategy == "chapter_aware_window_expansion":
            if not retrieved_docs or not page_map:
                return []

            scored_blocks = _calculate_block_frequencies(retrieved_docs)
            if not scored_blocks:
                return _expand_chunks_to_blocks(retrieved_docs[:top_k])

            primary_hit_page = scored_blocks[0][1][0].metadata.get("page_number")
            important_pages = {
                c[0].metadata.get("page_number")
                for _, c in scored_blocks[:top_k]
                if c and c[0].metadata.get("page_number")
            }

            window_size = config.get("CONTEXT_WINDOW_SIZE", 0)
            pages_to_extract = set()
            for page_num in important_pages:
                current_chapter_info = page_map.get(page_num)
                if not current_chapter_info:
                    continue
                current_chapter_id = current_chapter_info["chapter_id"]
                pages_to_extract.add(page_num)
                for i in range(1, window_size + 1):
                    if (prev_info := page_map.get(page_num - i)) and prev_info["chapter_id"] == current_chapter_id:
                        pages_to_extract.add(page_num - i)
                    if (next_info := page_map.get(page_num + i)) and next_info["chapter_id"] == current_chapter_id:
                        pages_to_extract.add(page_num + i)

            sorted_pages = sorted(pages_to_extract)
            if primary_hit_page and primary_hit_page in page_map:
                final_context.extend(page_map[primary_hit_page]["blocks"])
            for page_num in sorted_pages:
                if page_num != primary_hit_page and page_num in page_map:
                    final_context.extend(page_map[page_num]["blocks"])

        elif strategy == "rerank_by_frequency":
            scored_blocks = _calculate_block_frequencies(retrieved_docs)
            representative_chunks = [chunks[0] for _, chunks in scored_blocks[:top_k]]
            final_context = _expand_chunks_to_blocks(representative_chunks)

        elif strategy == "select_by_rank":
            unique_docs_map = {f"{doc.metadata.get('block_id', '')}_{doc.page_content}": doc for doc in retrieved_docs}
            representative_chunks = list(unique_docs_map.values())[:top_k]
            final_context = _expand_chunks_to_blocks(representative_chunks)

        else:
            logger.warning(f"Unknown strategy '{strategy}'. Defaulting to top-k raw chunks.")
            final_context = retrieved_docs[:top_k]

        logger.info(f"Selected {len(final_context)} final context blocks using '{strategy}' strategy.")
        return final_context
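

# --- Usage sketch (illustrative only) ----------------------------------------
# Assumes `kb` is an already-initialized KnowledgeBase with a built retriever;
# how it is constructed lives in knowledge_base.py and is not shown here.
#
#   engine = RAGContextEngine(kb, config_overrides={"FINAL_CONTEXT_TOP_K": 3})
#   messages = engine.get_context_messages("What does the guideline recommend?")
#   print(format_context_messages_to_string(messages or []))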