# visual-rag-tool / tool.py
import os
from smolagents import Tool
from dotenv import load_dotenv
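# Pull OPENAI_KEY (and any other secrets) from a local .env file, if present.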
load_dotenv()
class VisualRAGTool(Tool):
name = "visual_rag"
description = """Performs a RAG query on your internal PDF documents and returns the generated text response."""
inputs = {
"query": {
"type": "string",
"description": "The query to perform. This should be semantically close to your target documents.",
},
"k": {
"type": "number",
"description": "The number of documents to retrieve.",
"default": 1,
"nullable": True,
},
"api_key": {
"type": "string",
"description": "The OpenAI API key to use for the query. If not provided, the key will be taken from the OPENAI_KEY environment variable.",
"nullable": True,
}
}
output_type = "string"
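    # Defaults: the ColQwen2 retriever checkpoint and an OpenAI key taken
    # from the OPENAI_KEY environment variable.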
model_name: str = "vidore/colqwen2-v1.0"
api_key: str = os.getenv("OPENAI_KEY")
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_initialized = False
def _init_models(self, model_name: str) -> None:
import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor
self.device = "cuda" if torch.cuda.is_available() else "cpu"
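        # Note: flash_attention_2 requires the flash-attn package and a CUDA GPU.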
self.model = ColQwen2.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="flash_attention_2"
).eval()
self.processor = ColQwen2Processor.from_pretrained(model_name)
    def setup(self):
        """
        Expensive one-time initialization: load the ColQwen2 model and
        processor and create the in-memory page and embedding stores.
        """
self._init_models(self.model_name)
self.embds = []
self.pages = []
self.is_initialized = True
def _extract_contexts(self, images, api_key, window=10) -> list:
"""Extracts context from images."""
from utils import query_openai, Page, CONTEXT_SYSTEM_PROMPT
from pqdm.processes import pqdm
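        # Pages are summarized in windows of `window` pages: each window is
        # sent to OpenAI once (in parallel via pqdm), and the resulting
        # context string is shared by every page in that window.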
try:
args = [
{
'query': "Give the general context about these pages. Give the context in the same language as the documents.",
'pages': [Page(image=im) for im in images[max(i-window+1, 0):i+1]],
'api_key': api_key,
'system_prompt': CONTEXT_SYSTEM_PROMPT,
} for i in range(0, len(images), window)
]
window_contexts = pqdm(args, query_openai, n_jobs=8, argument_type='kwargs')
            # Sequential fallback (for the moment), using tqdm:
            # query = "Give the general context about these pages. Give the context in the same language as the documents."
            # window_contexts = [query_openai(query, [Page(image=im) for im in images[max(i-window+1, 0):i+1]], api_key, CONTEXT_SYSTEM_PROMPT)
            #                    for i in tqdm(range(0, len(images), window))]
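            # Each page inherits the context produced for its window.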
contexts = []
for i in range(len(images)):
context = window_contexts[i//window].content
contexts.append(context)
except Exception as e:
print(f"Error extracting contexts: {e}")
contexts = [None for _ in range(len(images))]
# Ensure that the number of contexts is equal to the number of images
assert len(contexts) == len(images)
return contexts
def _preprocess_file(self, file: str, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
"""Converts a file to images and extracts metadata."""
from pdf2image import convert_from_path
from utils import Metadata, Page
        title = os.path.basename(file)
images = convert_from_path(file, thread_count=4)
if contextualize and api_key:
contexts = self._extract_contexts(images, api_key, window=window)
else:
contexts = [None for _ in range(len(images))]
metadatas = [Metadata(doc_title=title, page_id=i, context=contexts[i]) for i in range(len(images))]
return [Page(image=img, metadata=metadata) for img, metadata in zip(images, metadatas)]
def preprocess(self, files: list, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
"""Preprocesses the files and extracts metadata."""
pages = [page for file in files for page in self._preprocess_file(file, contextualize=contextualize, api_key=api_key, window=window)]
print(f"Example metadata:\n{pages[0].metadata.context}")
return pages
def compute_embeddings(self, pages) -> list:
"""Embeds the images using the model."""
"""Example script to run inference with ColPali (ColQwen2)"""
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
        # Encode the document pages in batches; each page yields a
        # multi-vector (late-interaction) embedding.
dataloader = DataLoader(
pages,
batch_size=4,
shuffle=False,
collate_fn=lambda x: self.processor.process_images([p.image for p in x]).to(self.device),
)
embds = []
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                # The collate_fn already placed the batch on the target device.
                embeddings_doc = self.model(**batch_doc)
            embds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
return embds
def index(self, files: list, contextualize: bool = True, api_key: str = None, overwrite_db: bool = False) -> int:
"""Indexes the uploaded files."""
if not self.is_initialized:
self.setup()
print("Converting files...")
# Convert files to images and extract metadata
pgs = self.preprocess(files, contextualize=contextualize, api_key=api_key or self.api_key)
# Embed the images
embds = self.compute_embeddings(pgs)
# Overwrite the database if necessary
if overwrite_db:
self.pages = []
self.embds = []
# Extend the pages
self.pages.extend(pgs)
        # Extend the embeddings
self.embds.extend(embds)
print(f"Extracted and indexed {len(pgs)} images from {len(files)} files.")
return len(embds)
def retrieve(self, query: str, k: int) -> list:
"""Retrieve the top k documents based on the query."""
import torch
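        # Never request more pages than are currently indexed.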
k = min(k, len(self.embds))
qs = []
with torch.no_grad():
batch_query = self.processor.process_queries([query]).to(self.model.device)
embeddings_query = self.model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
        # Late-interaction (MaxSim) scoring of the query against all indexed pages
scores = self.processor.score(qs, self.embds, device=self.device)[0]
top_k_idx = scores.topk(k).indices.tolist()
print("Top Scores:")
[print(f'Page {self.pages[idx].metadata.page_id}: {scores[idx]}') for idx in top_k_idx]
# Get the top k results
results = [self.pages[idx] for idx in top_k_idx]
return results
def generate_answer(self, query: str, docs: list, api_key: str = None):
"""Generates an answer based on the query and the retrieved documents."""
from utils import query_openai, RAG_SYSTEM_PROMPT
result = query_openai(query, docs, api_key or self.api_key, system_prompt=RAG_SYSTEM_PROMPT)
return result
def search(self, query: str, k: int = 1, api_key: str = None) -> tuple:
"""Searches for the most relevant pages based on the query."""
print(f"Searching for query: {query}")
# Retrieve the top k documents
context = self.retrieve(query, k)
# Generate response from GPT-4o-mini
rag_answer = self.generate_answer(
query=query,
docs=context,
api_key=api_key
)
return context, rag_answer.content
def forward(self, query: str, k: int = 1, api_key: str = None) -> str:
assert isinstance(query, str), "Your search query must be a string"
# Online indexing
# if files:
# _ = self.index(files, api_key)
# Retrieve the top k documents and generate response
        return self.search(
            query=query,
            k=k,
            api_key=api_key
        )[1]
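

# Minimal usage sketch (not part of the original file): shows the
# index -> query flow. The file name and question below are hypothetical;
# it assumes a local "sample.pdf" and an OPENAI_KEY environment variable.
if __name__ == "__main__":
    tool = VisualRAGTool()
    tool.index(files=["sample.pdf"], contextualize=True)
    print(tool.forward("What are the key findings of the report?", k=3))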