import os

from smolagents import Tool
from dotenv import load_dotenv

load_dotenv()


class VisualRAGTool(Tool):
    name = "visual_rag"
    description = """Performs a RAG query on your internal PDF documents and returns the generated text response."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents.",
        },
        "k": {
            "type": "number",
            "description": "The number of documents to retrieve.",
            "default": 1,
            "nullable": True,
        },
        "api_key": {
            "type": "string",
            "description": "The OpenAI API key to use for the query. If not provided, the key is read from the OPENAI_KEY environment variable.",
            "nullable": True,
        },
    }
    output_type = "string"

    model_name: str = "vidore/colqwen2-v1.0"
    api_key: str = os.getenv("OPENAI_KEY")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_initialized = False

    def _init_models(self, model_name: str) -> None:
        import torch
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = ColQwen2.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="flash_attention_2",
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(model_name)

    def setup(self):
        """
        Loads the ColQwen2 model and processor and resets the in-memory index.
        This is the expensive one-time initialization, executed before the tool
        is first used.
        """
        self._init_models(self.model_name)

        self.embds = []
        self.pages = []

        self.is_initialized = True

    def _extract_contexts(self, images, api_key, window=10) -> list:
        """Extracts a shared textual context for each window of pages."""
        from utils import query_openai, Page, CONTEXT_SYSTEM_PROMPT
        from pqdm.processes import pqdm

        query = (
            "Give the general context about these pages. "
            "Give the context in the same language as the documents."
        )
        try:
            # One request per window of pages, run in parallel with pqdm
            args = [
                {
                    'query': query,
                    'pages': [Page(image=im) for im in images[max(i - window + 1, 0):i + 1]],
                    'api_key': api_key,
                    'system_prompt': CONTEXT_SYSTEM_PROMPT,
                }
                for i in range(0, len(images), window)
            ]
            window_contexts = pqdm(args, query_openai, n_jobs=8, argument_type='kwargs')

            # Sequential fallback, handy when debugging without pqdm:
            # window_contexts = [query_openai(query, [Page(image=im) for im in images[max(i - window + 1, 0):i + 1]], api_key, CONTEXT_SYSTEM_PROMPT)
            #                    for i in tqdm(range(0, len(images), window))]

            # Each page inherits the context of the window it belongs to
            contexts = [window_contexts[i // window].content for i in range(len(images))]
        except Exception as e:
            print(f"Error extracting contexts: {e}")
            contexts = [None for _ in range(len(images))]

        # Ensure that the number of contexts matches the number of images
        assert len(contexts) == len(images)

        return contexts

    def _preprocess_file(self, file: str, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
        """Converts a PDF file to page images and attaches metadata."""
        from pdf2image import convert_from_path
        from utils import Metadata, Page

        title = os.path.basename(file)
        images = convert_from_path(file, thread_count=4)
        if contextualize and api_key:
            contexts = self._extract_contexts(images, api_key, window=window)
        else:
            contexts = [None for _ in range(len(images))]
        metadatas = [
            Metadata(doc_title=title, page_id=i, context=contexts[i])
            for i in range(len(images))
        ]

        return [Page(image=img, metadata=metadata) for img, metadata in zip(images, metadatas)]

    def preprocess(self, files: list, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
        """Preprocesses all files and extracts their metadata."""
        pages = [
            page
            for file in files
            for page in self._preprocess_file(file, contextualize=contextualize, api_key=api_key, window=window)
        ]

        print(f"Example metadata:\n{pages[0].metadata.context}")

        return pages

    def compute_embeddings(self, pages) -> list:
        """Embeds the page images with ColQwen2 (batched inference, as in the ColPali example script)."""
        import torch
        from torch.utils.data import DataLoader
        from tqdm import tqdm

        # Run inference over the document pages in batches
        dataloader = DataLoader(
            pages,
            batch_size=4,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_images([p.image for p in x]).to(self.device),
        )

        embds = []
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {k: v.to(self.device) for k, v in batch_doc.items()}
                embeddings_doc = self.model(**batch_doc)
            embds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

        return embds

    def index(self, files: list, contextualize: bool = True, api_key: str = None, overwrite_db: bool = False) -> int:
        """Indexes the uploaded files and returns the number of pages embedded."""
        if not self.is_initialized:
            self.setup()

        print("Converting files...")
        # Convert files to images and extract metadata
        pgs = self.preprocess(files, contextualize=contextualize, api_key=api_key or self.api_key)

        # Embed the images
        embds = self.compute_embeddings(pgs)

        # Overwrite the database if requested
        if overwrite_db:
            self.pages = []
            self.embds = []

        # Extend the in-memory index
        self.pages.extend(pgs)
        self.embds.extend(embds)

        print(f"Extracted and indexed {len(pgs)} images from {len(files)} files.")

        return len(embds)

    def retrieve(self, query: str, k: int) -> list:
        """Retrieves the top k pages for the query using late-interaction scoring."""
        import torch

        k = min(k, len(self.embds))

        qs = []
        with torch.no_grad():
            batch_query = self.processor.process_queries([query]).to(self.model.device)
            embeddings_query = self.model(**batch_query)
            qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

        # Score the query against every indexed page
        scores = self.processor.score(qs, self.embds, device=self.device)[0]
        top_k_idx = scores.topk(k).indices.tolist()

        print("Top Scores:")
        for idx in top_k_idx:
            print(f"Page {self.pages[idx].metadata.page_id}: {scores[idx]}")

        # Return the top k pages
        return [self.pages[idx] for idx in top_k_idx]

    def generate_answer(self, query: str, docs: list, api_key: str = None):
        """Generates an answer based on the query and the retrieved documents."""
        from utils import query_openai, RAG_SYSTEM_PROMPT

        return query_openai(query, docs, api_key or self.api_key, system_prompt=RAG_SYSTEM_PROMPT)

    def search(self, query: str, k: int = 1, api_key: str = None) -> tuple:
        """Searches for the most relevant pages based on the query and answers from them."""
        print(f"Searching for query: {query}")

        # Retrieve the top k documents
        context = self.retrieve(query, k)

        # Generate a response from GPT-4o-mini over the retrieved pages
        rag_answer = self.generate_answer(
            query=query,
            docs=context,
            api_key=api_key,
        )

        return context, rag_answer.content

    def forward(self, query: str, k: int = 1, api_key: str = None) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        # Online indexing could be plugged in here, e.g.:
        # if files:
        #     _ = self.index(files, api_key)

        # Retrieve the top k documents and generate the response
        return self.search(
            query=query,
            k=k,
            api_key=api_key,
        )[1]
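

# Example usage: a minimal sketch, not part of the tool itself. It assumes an
# OPENAI_KEY environment variable (or .env entry) is set, and the PDF path and
# the query below are hypothetical placeholders — substitute your own. Indexing
# loads ColQwen2 with flash attention, so a CUDA GPU is effectively required.
if __name__ == "__main__":
    tool = VisualRAGTool()

    # Index one or more PDFs; pages are embedded and kept in memory
    tool.index(files=["docs/report.pdf"], contextualize=True)

    # Query the indexed pages; forward() returns only the generated answer
    answer = tool.forward("What was the revenue growth in 2023?", k=3)
    print(answer)

    # The tool can also be handed to a smolagents agent via its `tools` list,
    # e.g. CodeAgent(tools=[tool], model=...), since it subclasses Tool.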