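"""smolagents Tool that performs visual RAG over internal PDF documents:
pages are embedded as images with ColQwen2 (ColPali-style retrieval) and
answers are generated with an OpenAI model."""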
import os

from dotenv import load_dotenv
from smolagents import Tool

load_dotenv()


class VisualRAGTool(Tool):
    name = "visual_rag"
    description = """Performs a RAG query on your internal PDF documents and returns the generated text response."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents.",
        },
        "k": {
            "type": "number",
            "description": "The number of documents to retrieve.",
            "default": 1,
            "nullable": True,
        },
        "api_key": {
            "type": "string",
            "description": "The OpenAI API key to use for the query. If not provided, the key will be taken from the OPENAI_KEY environment variable.",
            "nullable": True,
        },
    }
    output_type = "string"

    model_name: str = "vidore/colqwen2-v1.0"
    api_key: str = os.getenv("OPENAI_KEY")
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_initialized = False
    def _init_models(self, model_name: str) -> None:
        import torch
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = ColQwen2.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="flash_attention_2",  # requires the flash-attn package to be installed
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(model_name)
    def setup(self):
        """
        Overwrite this method for any operation that is expensive and needs to be executed
        before you start using your tool, such as loading a big model.
        """
        self._init_models(self.model_name)

        self.embds = []
        self.pages = []

        self.is_initialized = True
    def _extract_contexts(self, images, api_key, window=10) -> list:
        """Extracts a context for each image from a sliding window of surrounding pages."""
        from pqdm.processes import pqdm
        from utils import CONTEXT_SYSTEM_PROMPT, Page, query_openai

        try:
            args = [
                {
                    "query": "Give the general context about these pages. Give the context in the same language as the documents.",
                    "pages": [Page(image=im) for im in images[max(i - window + 1, 0):i + 1]],
                    "api_key": api_key,
                    "system_prompt": CONTEXT_SYSTEM_PROMPT,
                }
                for i in range(0, len(images), window)
            ]
            window_contexts = pqdm(args, query_openai, n_jobs=8, argument_type="kwargs")

            # Sequential alternative with tqdm, kept for reference:
            # query = "Give the general context about these pages. Give the context in the same language as the documents."
            # window_contexts = [query_openai(query, [Page(image=im) for im in images[max(i - window + 1, 0):i + 1]], api_key, DEFAULT_CONTEXT_PROMPT)
            #                    for i in tqdm(range(0, len(images), window))]

            # Each page inherits the context generated for its window
            contexts = []
            for i in range(len(images)):
                context = window_contexts[i // window].content
                contexts.append(context)
        except Exception as e:
            print(f"Error extracting contexts: {e}")
            contexts = [None for _ in range(len(images))]

        # Ensure that the number of contexts equals the number of images
        assert len(contexts) == len(images)
        return contexts
    def _preprocess_file(self, file: str, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
        """Converts a file to images and extracts metadata."""
        from pdf2image import convert_from_path
        from utils import Metadata, Page

        title = file.split("/")[-1]
        images = convert_from_path(file, thread_count=4)
        if contextualize and api_key:
            contexts = self._extract_contexts(images, api_key, window=window)
        else:
            contexts = [None for _ in range(len(images))]
        metadatas = [Metadata(doc_title=title, page_id=i, context=contexts[i]) for i in range(len(images))]
        return [Page(image=img, metadata=metadata) for img, metadata in zip(images, metadatas)]
    def preprocess(self, files: list, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
        """Preprocesses the files and extracts metadata."""
        pages = [
            page
            for file in files
            for page in self._preprocess_file(file, contextualize=contextualize, api_key=api_key, window=window)
        ]
        print(f"Example metadata:\n{pages[0].metadata.context}")
        return pages
    def compute_embeddings(self, pages) -> list:
        """Embeds the page images with the model (adapted from the ColPali/ColQwen2 inference example)."""
        import torch
        from torch.utils.data import DataLoader
        from tqdm import tqdm

        # Run inference over the document pages
        dataloader = DataLoader(
            pages,
            batch_size=4,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_images([p.image for p in x]),
        )

        embds = []
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {k: v.to(self.device) for k, v in batch_doc.items()}
                embeddings_doc = self.model(**batch_doc)
            embds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

        return embds
    def index(self, files: list, contextualize: bool = True, api_key: str = None, overwrite_db: bool = False) -> int:
        """Indexes the uploaded files."""
        if not self.is_initialized:
            self.setup()

        print("Converting files...")
        # Convert the files to images and extract metadata
        pgs = self.preprocess(files, contextualize=contextualize, api_key=api_key or self.api_key)

        # Embed the images
        embds = self.compute_embeddings(pgs)

        # Overwrite the in-memory database if requested
        if overwrite_db:
            self.pages = []
            self.embds = []

        # Extend the pages and their embeddings
        self.pages.extend(pgs)
        self.embds.extend(embds)

        print(f"Extracted and indexed {len(pgs)} images from {len(files)} files.")
        return len(embds)
    def retrieve(self, query: str, k: int) -> list:
        """Retrieves the top k pages for the given query."""
        import torch

        k = min(k, len(self.embds))

        # Embed the query
        qs = []
        with torch.no_grad():
            batch_query = self.processor.process_queries([query]).to(self.model.device)
            embeddings_query = self.model(**batch_query)
            qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

        # Run late-interaction scoring between the query and page embeddings
        scores = self.processor.score(qs, self.embds, device=self.device)[0]
        top_k_idx = scores.topk(k).indices.tolist()

        print("Top Scores:")
        for idx in top_k_idx:
            print(f"Page {self.pages[idx].metadata.page_id}: {scores[idx]}")

        # Return the top k pages
        return [self.pages[idx] for idx in top_k_idx]
    def generate_answer(self, query: str, docs: list, api_key: str = None):
        """Generates an answer based on the query and the retrieved documents."""
        from utils import RAG_SYSTEM_PROMPT, query_openai

        result = query_openai(query, docs, api_key or self.api_key, system_prompt=RAG_SYSTEM_PROMPT)
        return result
    def search(self, query: str, k: int = 1, api_key: str = None) -> tuple:
        """Searches for the most relevant pages based on the query."""
        print(f"Searching for query: {query}")

        # Retrieve the top k documents
        context = self.retrieve(query, k)

        # Generate a response from GPT-4o-mini
        rag_answer = self.generate_answer(
            query=query,
            docs=context,
            api_key=api_key,
        )

        return context, rag_answer.content
    def forward(self, query: str, k: int = 1, api_key: str = None) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        # Online indexing could be done here before searching, e.g.:
        # if files:
        #     _ = self.index(files, api_key)

        # Retrieve the top k documents and generate the response
        return self.search(
            query=query,
            k=k,
            api_key=api_key,
        )[1]
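

# ---------------------------------------------------------------------------
# Minimal usage sketch. The PDF path and the question below are placeholders,
# and OPENAI_KEY is assumed to be set in the environment (or passed via
# api_key). The tool can also be handed to a smolagents agent, e.g.
# CodeAgent(tools=[tool], model=...).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tool = VisualRAGTool()

    # Build the in-memory index: converts pages to images, contextualizes
    # them with OpenAI, and embeds them with ColQwen2.
    tool.index(files=["docs/report.pdf"], contextualize=True)

    # Ask a question against the indexed documents.
    print(tool.forward("What are the main findings of the report?", k=3))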