# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import concurrent.futures import logging import os import re from pathlib import Path from typing import Dict, List import fitz # PyMuPDF from PIL import Image from langchain.docstore.document import Document as LangchainDocument from langchain.retrievers import BM25Retriever, EnsembleRetriever from langchain.text_splitter import NLTKTextSplitter from langchain_community.vectorstores import Chroma from tqdm import tqdm logger = logging.getLogger(__name__) IMAGE_SUMMARY_PROMPT = """Summarize key findings in this image.""" class KnowledgeBase: """Processes a source PDF and builds a self-contained, searchable RAG knowledge base.""" def __init__(self, models: dict, config_overrides: dict | None = None): """Initializes the builder with necessary models and configuration.""" self.embedder = models.get("embedder") self.ner_pipeline = models.get("ner_pipeline") # Set default config and apply any overrides self.config = self._get_default_config() if config_overrides: self.config.update(config_overrides) # For consistent chunking, the RAG query uses the same enriching and chunking logic as the knowledge base. self.document_enricher = self._enrich_documents self.chunker = self._create_chunks_from_documents self.retriever: EnsembleRetriever | None = None self.page_map: Dict[int, Dict] = {} self.source_filepath = "" # Create necessary directories from config Path(self.config["IMAGE_DIR"]).mkdir(parents=True, exist_ok=True) Path(self.config["CHROMA_PERSIST_DIR"]).mkdir(parents=True, exist_ok=True) def _get_default_config(self): """Returns the default configuration for the KnowledgeBase.""" return { "IMAGE_DIR": Path("processed_figures_kb/"), "CHROMA_PERSIST_DIR": Path("chroma_db_store/"), "MEDICAL_ENTITY_TYPES_TO_EXTRACT": ["PROBLEM"], "EXTRACT_IMAGE_SUMMARIES": False, # Disabled as we don't load the LLM here "FILTER_FIRST_PAGES": 6, "FIGURE_MIN_WIDTH": 30, "FIGURE_MIN_HEIGHT": 30, "SENTENCE_CHUNK_SIZE": 250, "CHUNK_FILTER_SIZE": 20, "RETRIEVER_TOP_K": 20, "ENSEMBLE_WEIGHTS_BM25,SENTENCE,NER": [0.2, 0.3, 0.5], "SENTENCE_SCORE_THRESHOLD": 0.6, "NER_SCORE_THRESHOLD": 0.5, "MAX_PARALLEL_WORKERS": 16, } def build(self, pdf_filepath: str): """The main public method to build the knowledge base from a PDF.""" logger.info(f"--------- Building Knowledge Base from '{pdf_filepath}' ---------") pdf_path = Path(pdf_filepath) if not pdf_path.exists(): logger.error(f"ERROR: PDF file not found at {pdf_filepath}") return None self.source_filepath = pdf_path # Step 1: Process the PDF and build the structured page_map. self.page_map = self._process_and_structure_pdf(pdf_path) all_docs = [ doc for page_data in self.page_map.values() for doc in page_data["blocks"] ] # Step 2: Enrich documents with NER metadata. enriched_docs = self._enrich_documents(all_docs, self.config.get("EXTRACT_IMAGE_SUMMARIES", False)) # Step 3: Chunk the enriched documents into final searchable units. final_chunks = self._create_chunks_from_documents(enriched_docs) # Step 4: Build the final ensemble retriever. self.retriever = self._build_ensemble_retriever(final_chunks) if self.retriever: logger.info(f"--------- Knowledge Base Built Successfully ---------") else: logger.error(f"--------- Knowledge Base Building Failed ---------") return self # --- Step 1: PDF Content Extraction --- def _process_and_structure_pdf(self, pdf_path: Path) -> dict: """Processes a PDF in parallel and directly builds the final page_map. This version is more efficient by opening the PDF only once. """ logger.info("Step 1: Processing PDF and building structured page map...") page_map = {} try: # Improvement: Open the PDF ONCE to get all preliminary info with fitz.open(pdf_path) as doc: pdf_bytes_buffer = doc.write() page_count = len(doc) toc = doc.get_toc() # Improvement: Create a more robust chapter lookup map page_to_chapter_id = {} if toc: chapters = [item for item in toc if item[0] == 1] for i, (lvl, title, start_page) in enumerate(chapters): end_page = ( chapters[i + 1][2] - 1 if i + 1 < len(chapters) else page_count ) for page_num in range(start_page, end_page + 1): page_to_chapter_id[page_num] = i # Create tasks for the thread pool (using a tuple as requested) tasks = [ ( pdf_bytes_buffer, i, self.config, pdf_path.name, page_to_chapter_id, ) for i in range(self.config["FILTER_FIRST_PAGES"], page_count) ] # Parallel Processing num_workers = min( self.config["MAX_PARALLEL_WORKERS"], os.cpu_count() or 1 ) with concurrent.futures.ThreadPoolExecutor( max_workers=num_workers ) as executor: futures = [ executor.submit(self.process_single_page, task) for task in tasks ] progress_bar = tqdm( concurrent.futures.as_completed(futures), total=len(tasks), desc="Processing & Structuring Pages", ) for future in progress_bar: result = future.result() if result: # The worker now returns a fully formed dictionary for the page_map page_map[result["page_num"]] = result["content"] except Exception as e: logger.error(f"❌ Failed to process PDF {pdf_path.name}: {e}") return {} logger.info(f"✅ PDF processed. Created a map of {len(page_map)} pages.") return dict(sorted(page_map.items())) # --- Step 2: Document Enrichment --- def _enrich_documents( self, docs: List[LangchainDocument], summarize: bool = False ) -> List[LangchainDocument]: """Enriches a list of documents with NER metadata and image summaries.""" logger.info("\nStep 2: Enriching documents...") # NER Enrichment if self.ner_pipeline: logger.info("Adding NER metadata...") for doc in tqdm(docs, desc="Enriching with NER"): # 1. Skip documents that have no actual text content if not doc.page_content or not doc.page_content.strip(): continue try: # 2. Process ONLY the text of the current document processed_doc = self.ner_pipeline(doc.page_content) # 3. Extract entities from the result. This result now # unambiguously belongs to the current 'doc'. entities = [ ent.text for ent in processed_doc.ents if ent.type in self.config["MEDICAL_ENTITY_TYPES_TO_EXTRACT"] ] # 4. Assign the correctly mapped entities to the document's metadata if entities: # Using set() handles duplicates before sorting and joining unique_entities = sorted(list(set(entities))) doc.metadata["block_ner_entities"] = ", ".join(unique_entities) except Exception as e: # Add error handling for robustness in case a single block fails logger.warning( f"\nWarning: Could not process NER for a block on page {doc.metadata.get('page_number', 'N/A')}: {e}") # Image Summary Enrichment if summarize: logger.info("Generating image summaries...") docs_with_figures = [ doc for doc in docs if "linked_figure_path" in doc.metadata ] for doc in tqdm(docs_with_figures, desc="Summarizing Images"): try: img = Image.open(doc.metadata["linked_figure_path"]).convert("RGB") summary = self._summarize_image(img) if summary: doc.metadata["image_summary"] = summary except Exception as e: logger.warning( "Warning: Could not summarize image" f" {doc.metadata.get('linked_figure_path', '')}: {e}" ) return docs def _summarize_image(self, pil_image: Image.Image) -> str: """Helper method to call the LLM for image summarization.""" if not self.llm_pipeline: return "" messages = [{ "role": "user", "content": [ {"type": "text", "text": IMAGE_SUMMARY_PROMPT}, {"type": "image", "image": pil_image}, ], }] try: output = self.llm_pipeline(text=messages, max_new_tokens=150) return output[0]["generated_text"][-1]["content"].strip() except Exception: return "" # --- Step 3: Document Chunking --- def _create_chunks_from_documents( self, enriched_docs: List[LangchainDocument], display_results: bool = True ) -> List[LangchainDocument]: """Takes enriched documents and creates the final list of chunks for indexing. This method now has a single responsibility: chunking. """ if display_results: logger.info("\nStep 3: Creating final chunks...") # Sentence Splitting if display_results: logger.info("Applying NLTK Sentence Splitting...") splitter = NLTKTextSplitter(chunk_size=self.config["SENTENCE_CHUNK_SIZE"]) sentence_chunks = splitter.split_documents(enriched_docs) if display_results: logger.info(f"Generated {len(sentence_chunks)} sentence-level chunks.") # NER Entity Chunking (based on previously enriched metadata) if display_results: logger.info("Creating NER Entity Chunks...") ner_entity_chunks = [ LangchainDocument( page_content=entity, metadata={**doc.metadata, "chunk_type": "ner_entity_standalone"}, ) for doc in enriched_docs if (entities_str := doc.metadata.get("block_ner_entities")) for entity in entities_str.split(", ") if entity ] if display_results: logger.info(f"Added {len(ner_entity_chunks)} NER entity chunks.") all_chunks = sentence_chunks + ner_entity_chunks return [chunk for chunk in all_chunks if chunk.page_content] # --- Step 4: Retriever Building --- def _build_ensemble_retriever( self, chunks: List[LangchainDocument] ) -> EnsembleRetriever | None: """Builds the final ensemble retriever from the chunks. This method was already well-focused. """ if not chunks: logger.error("No chunks to build retriever from.") return None logger.info("\nStep 4: Building specialized retrievers...") sentence_chunks = [ doc for doc in chunks if doc.metadata.get("chunk_type") != "ner_entity_standalone" ] ner_chunks = [ doc for doc in chunks if doc.metadata.get("chunk_type") == "ner_entity_standalone" ] retrievers, weights = [], [] if sentence_chunks: bm25_retriever = BM25Retriever.from_documents(sentence_chunks) bm25_retriever.k = self.config["RETRIEVER_TOP_K"] retrievers.append(bm25_retriever) weights.append(self.config["ENSEMBLE_WEIGHTS_BM25,SENTENCE,NER"][0]) sentence_vs = Chroma.from_documents( documents=sentence_chunks, embedding=self.embedder, persist_directory=str( self.config["CHROMA_PERSIST_DIR"] / "sentences" ), ) vector_retriever = sentence_vs.as_retriever( search_type="similarity_score_threshold", search_kwargs={ "k": self.config["RETRIEVER_TOP_K"], "score_threshold": self.config["SENTENCE_SCORE_THRESHOLD"], }, ) retrievers.append(vector_retriever) weights.append(self.config["ENSEMBLE_WEIGHTS_BM25,SENTENCE,NER"][1]) if ner_chunks: ner_vs = Chroma.from_documents( documents=ner_chunks, embedding=self.embedder, persist_directory=str(self.config["CHROMA_PERSIST_DIR"] / "entities"), ) ner_retriever = ner_vs.as_retriever( search_type="similarity_score_threshold", search_kwargs={ "k": self.config["RETRIEVER_TOP_K"], "score_threshold": self.config["NER_SCORE_THRESHOLD"], }, ) retrievers.append(ner_retriever) weights.append(self.config["ENSEMBLE_WEIGHTS_BM25,SENTENCE,NER"][2]) if not retrievers: logger.error("⚠️ Could not create any retrievers.") return None logger.info(f"Creating final ensemble with weights: {weights}") return EnsembleRetriever(retrievers=retrievers, weights=weights) @staticmethod def process_single_page(args_tuple: tuple) -> dict | None: """Worker function for parallel PDF processing. Processes one page and returns a structured dictionary for that page. """ # Unpack arguments (still using a tuple as requested) pdf_bytes_buffer, page_num_idx, config, pdf_filename, page_to_chapter_id = ( args_tuple ) lc_documents = [] page_num = page_num_idx + 1 try: # Improvement: Use a 'with' statement for resource management with fitz.open(stream=pdf_bytes_buffer, filetype="pdf") as doc: page = doc[page_num_idx] # 1. Extract raw, potentially fragmented text blocks raw_text_blocks = page.get_text("blocks", sort=True) # 2. Immediately merge blocks into paragraphs >>> paragraph_blocks = KnowledgeBase._merge_text_blocks(raw_text_blocks) # 3. Process figures (no change) page_figures = [] for fig_j, path_dict in enumerate(page.get_drawings()): bbox = path_dict["rect"] if ( bbox.is_empty or bbox.width < config["FIGURE_MIN_WIDTH"] or bbox.height < config["FIGURE_MIN_HEIGHT"] ): continue # Improvement: More concise bounding box padding padded_bbox = bbox + (-2, -2, 2, 2) padded_bbox.intersect(page.rect) if padded_bbox.is_empty: continue pix = page.get_pixmap(clip=padded_bbox, dpi=150) if pix.width > 0 and pix.height > 0: img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) img_path = ( config["IMAGE_DIR"] / f"{Path(pdf_filename).stem}_p{page_num}_fig{fig_j + 1}.png" ) img.save(img_path) page_figures.append({ "bbox": bbox, "path": str(img_path), "id": f"Figure {fig_j + 1} on {pdf_filename}, page {page_num}", }) # 4. Process the clean PARAGRAPH blocks text_blocks_on_page = [ { "bbox": fitz.Rect(x0, y0, x1, y1), "text": text.strip(), "original_idx": b_idx, } for b_idx, (x0, y0, x1, y1, text, _, _) in enumerate( paragraph_blocks ) if text.strip() ] # 5. Link captions and create documents potential_captions = [ b for b in text_blocks_on_page if re.match(r"^\s*Figure\s*\d+", b["text"], re.I) ] mapped_caption_indices = set() for fig_data in page_figures: cap_text, cap_idx = KnowledgeBase.find_best_caption_for_figure( fig_data["bbox"], potential_captions ) if cap_text and cap_idx not in mapped_caption_indices: mapped_caption_indices.add(cap_idx) metadata = { "source_pdf": pdf_filename, "page_number": page_num, "chunk_type": "figure-caption", "linked_figure_path": fig_data["path"], "linked_figure_id": fig_data["id"], "block_id": f"{page_num}_{cap_idx}", "original_block_text": cap_text, } lc_documents.append( LangchainDocument(page_content=cap_text, metadata=metadata) ) for block_data in text_blocks_on_page: if block_data["original_idx"] in mapped_caption_indices: continue if KnowledgeBase.should_filter_text_block( block_data["text"], block_data["bbox"], page.rect.height, config["CHUNK_FILTER_SIZE"], ): continue metadata = { "source_pdf": pdf_filename, "page_number": page_num, "chunk_type": "text_block", "block_id": f"{page_num}_{block_data['original_idx']}", "original_block_text": block_data["text"], } lc_documents.append( LangchainDocument( page_content=block_data["text"], metadata=metadata ) ) except Exception as e: logger.error(f"Error processing {pdf_filename} page {page_num}: {e}") return None if not lc_documents: return None # Structure the final output lc_documents.sort( key=lambda d: int(d.metadata.get("block_id", "0_0").split("_")[-1]) ) return { "page_num": page_num, "content": { "chapter_id": page_to_chapter_id.get(page_num, -1), "blocks": lc_documents, }, } @staticmethod def _merge_text_blocks(blocks: list) -> list: """Intelligently merges fragmented text blocks into coherent paragraphs.""" if not blocks: return [] merged_blocks = [] current_text = "" current_bbox = fitz.Rect() sentence_enders = {".", "?", "!", "•"} for i, block in enumerate(blocks): block_text = block[4].strip() if not current_text: # Starting a new paragraph current_bbox = fitz.Rect(block[:4]) current_text = block_text else: # Continue existing paragraph current_bbox.include_rect(block[:4]) current_text = f"{current_text} {block_text}" is_last_block = i == len(blocks) - 1 ends_with_punctuation = block_text.endswith(tuple(sentence_enders)) if ends_with_punctuation or is_last_block: merged_blocks.append(( current_bbox.x0, current_bbox.y0, current_bbox.x1, current_bbox.y1, current_text, len(merged_blocks), 0, )) current_text = "" return merged_blocks @staticmethod def should_filter_text_block( block_text: str, block_bbox: fitz.Rect, page_height: float, filter_size: int, ) -> bool: """Determines if a text block from a header/footer should be filtered out.""" is_in_header_area = block_bbox.y0 < (page_height * 0.10) is_in_footer_area = block_bbox.y1 > (page_height * 0.80) is_short_text = len(block_text) < filter_size return (is_in_header_area or is_in_footer_area) and is_short_text @staticmethod def find_best_caption_for_figure( figure_bbox: fitz.Rect, potential_captions_on_page: list ) -> tuple: """Finds the best caption for a given figure based on proximity and alignment.""" best_caption_info = (None, -1) min_score = float("inf") for cap_info in potential_captions_on_page: cap_bbox = cap_info["bbox"] # Heuristic: Score captions directly below the figure if cap_bbox.y0 >= figure_bbox.y1 - 10: # Caption starts below the figure vertical_dist = cap_bbox.y0 - figure_bbox.y1 # Calculate horizontal overlap overlap_x_start = max(figure_bbox.x0, cap_bbox.x0) overlap_x_end = min(figure_bbox.x1, cap_bbox.x1) if ( overlap_x_end - overlap_x_start ) > 0: # If they overlap horizontally fig_center_x = (figure_bbox.x0 + figure_bbox.x1) / 2 cap_center_x = (cap_bbox.x0 + cap_bbox.x1) / 2 horizontal_center_dist = abs(fig_center_x - cap_center_x) # Score is a combination of vertical and horizontal distance score = vertical_dist + (horizontal_center_dist * 0.5) if score < min_score: min_score = score best_caption_info = (cap_info["text"], cap_info["original_idx"]) return best_caption_info