# visual-rag-tool / tool.py
# Source: web file view — author: paultltc, commit a924f05 ("init commit"), size 13.1 kB.
import base64
import os
from dataclasses import dataclass
from io import BytesIO
from typing import List, Optional, Tuple

import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor
from dotenv import load_dotenv
from pdf2image import convert_from_path
from PIL import Image
from pqdm.processes import pqdm
from smolagents import ChatMessage, Tool
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from utils import query_openai
# Populate the process environment from a .env file, if one is found (python-dotenv).
load_dotenv()
def encode_image_to_base64(image):
    """Encode a PIL image as a base64 JPEG string (no ``data:`` URL prefix).

    JPEG cannot store alpha or palette images (saving an ``RGBA``/``LA``/``P``
    image as JPEG raises), so any mode other than ``RGB``/``L`` is converted to
    ``RGB`` first. ``convert`` returns a new image, so the caller's image is
    left untouched.

    Args:
        image: A PIL image (anything with ``mode``, ``convert`` and ``save``).

    Returns:
        The JPEG bytes, base64-encoded and decoded to a UTF-8 ``str``.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    with BytesIO() as buffered:
        image.save(buffered, format="JPEG")
        payload = buffered.getvalue()
    return base64.b64encode(payload).decode("utf-8")
# Default system prompt for query_openai(): answer only from the supplied pages and cite them.
# NOTE(review): it asks both for a "short response" and for "detailed and extensive answers";
# consider reconciling the two instructions.
DEFAULT_SYSTEM_PROMPT = (
    """You are a smart assistant designed to answer questions about a PDF document.
You are given relevant information in the form of PDF pages preceded by their metadata: document title, page identifier, surrounding context.
Use them to construct a short response to the question, and cite your sources in the following format: (document, page number).
If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
Give detailed and extensive answers, only containing info in the pages you are given.
You can answer using information contained in plots and figures if necessary.
Answer in the same language as the query."""
)
def _build_query(query, pages):
messages = []
messages.append({"type": "text", "text": "PDF pages:\n"})
for page in pages:
capt = page.caption
if capt is not None:
messages.append({
"type": "text",
"text": capt
})
messages.append({
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{encode_image_to_base64(page.image)}"
},
})
messages.append({"type": "text", "text": f"Query:\n{query}"})
return messages
def query_openai(query, pages, api_key=None, system_prompt=DEFAULT_SYSTEM_PROMPT, model="gpt-4o-mini", max_tokens=500) -> "ChatMessage | str":
    """Send ``query`` together with the rendered ``pages`` to an OpenAI chat model.

    Args:
        query: User question, appended after the pages.
        pages: Page-like objects; each contributes its caption (if any) and image.
        api_key: OpenAI key. Surrounding whitespace is ignored; the key must start
            with ``"sk"`` or no request is made.
        system_prompt: Content of the system message.
        model: Chat-completions model name.
        max_tokens: Upper bound on generated tokens (default 500, as before).

    Returns:
        On success, a ``ChatMessage`` built from the first choice, with the raw
        API response attached as ``.raw``. When the key is missing/malformed, or
        when anything in the request fails, a plain ``str`` describing the
        problem is returned instead; callers must handle both types.
    """
    key = api_key.strip() if api_key else ""
    if not key.startswith("sk"):
        return "Enter your OpenAI API key to get a custom response"
    try:
        from openai import OpenAI

        client = OpenAI(api_key=key)
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": _build_query(query, pages)},
            ],
            max_tokens=max_tokens,
        )
        message = ChatMessage.from_dict(
            response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
        )
        message.raw = response
        return message
    except Exception as e:
        # Deliberate best-effort boundary: surface the cause on the console instead of
        # silently dropping it, but keep returning the user-facing message.
        print(f"OpenAI query failed: {e}")
        return "OpenAI API connection failure. Verify the provided key is correct (sk-***)."
# System prompt for the per-window page-context extraction requests.
DEFAULT_CONTEXT_PROMPT = (
    """You are a smart assistant designed to extract context of PDF pages.
Give concise answers, only containing info in the pages you are given.
You can answer using information contained in plots and figures if necessary."""
)

# System prompt for the final RAG answer: numbered citations "[n, p.Page Number]"
# followed by a "Sources:" list of document titles.
RAG_SYSTEM_PROMPT = (
    """ You are a smart assistant designed to answer questions about a PDF document.
You are given relevant information in the form of PDF pages preceded by their metadata: document title, page identifier, surrounding context.
Use them to construct a response to the question, and cite your sources.
Use the following citation format:
"Some information from a first document [1, p.Page Number]. Some information from the same first document but at a different page [1, p.Page Number]. Some more information from another document [2, p.Page Number].
...
Sources:
[1] Document Title
[2] Another Document Title"
You can answer using information contained in plots and figures if necessary.
If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
Give detailed answers, only containing info in the pages you are given.
Answer in the same language as the query."""
)
@dataclass
class Metadata:
    """Descriptive information attached to one indexed page."""

    doc_title: str  # name of the source document (e.g. its file name)
    page_id: int  # identifier of the page within its document
    context: Optional[str] = None  # optional free-text context; None when absent

    def __str__(self):
        labelled = (
            ("Document", self.doc_title),
            ("Page ID", self.page_id),
            ("Context", self.context),
        )
        return ", ".join(f"{label}: {value}" for label, value in labelled)
@dataclass
class Page:
    """A rendered page image plus optional metadata used when prompting the LLM."""

    image: "Image.Image"  # rendered page (PIL image)
    metadata: Optional["Metadata"] = None

    @property
    def caption(self):
        """Text header sent before this page's image, or None when there is no metadata.

        The header carries the document title and the page identifier, because the
        prompts instruct the model to cite pages. The context is appended only when
        it is not None, so un-contextualized pages do not produce "Context: None".
        """
        if self.metadata is None:
            return None
        meta = self.metadata
        parts = [f"Document: {meta.doc_title}", f"Page ID: {meta.page_id}"]
        if meta.context is not None:
            parts.append(f"Context: {meta.context}")
        return ", ".join(parts)
class VisualRAGTool(Tool):
    """smolagents tool that answers questions over an in-memory index of PDF pages.

    Pipeline:
    1. PDFs are rendered to page images, optionally annotated with LLM-extracted context.
    2. The pages are embedded with ColQwen2.
    3. At query time the pages are ranked with the processor's ``score``.
    4. The top pages are sent with the query to ``query_openai``.
    """

    name = "visual_rag"
    description = """Performs a RAG query on your internal PDF documents and returns the generated text response."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents.",
        },
        "k": {
            "type": "number",
            "description": "The number of documents to retrieve.",
            "default": 1,
            "nullable": True,
        },
        "api_key": {
            "type": "string",
            "description": "The OpenAI API key to use for the query. If not provided, the key will be taken from the OPENAI_KEY environment variable.",
            "nullable": True,
        }
    }
    output_type = "string"

    def _init_models(self, model_name: str) -> None:
        """Load the ColQwen2 retriever and its processor.

        FlashAttention-2 is requested only when CUDA is available and the
        ``flash_attn`` package is importable. Otherwise the library's default
        attention implementation is used, so loading does not hard-fail on
        machines without it.
        """
        import importlib.util

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        use_flash_attn = self.device == "cuda" and importlib.util.find_spec("flash_attn") is not None
        self.model = ColQwen2.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="flash_attention_2" if use_flash_attn else None,
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(model_name)

    def __init__(self, model_name: str = "vidore/colqwen2-v1.0", api_key: str = None, files: List[str] = None, **kwargs):
        """Create the tool and load the retrieval model.

        Args:
            model_name: Hugging Face id of the ColQwen2 checkpoint.
            api_key: OpenAI key; falls back to the ``OPENAI_KEY`` environment
                variable, as stated in the ``api_key`` input description.
            files: PDF paths to index when ``setup()`` runs.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        self.api_key = api_key or os.environ.get("OPENAI_KEY")
        self.embds = []
        self.pages = []
        self.files = files
        self._init_models(self.model_name)
        self.is_initialized = False

    def setup(self):
        """
        Overwrite this method here for any operation that is expensive and needs to be executed before you start using
        your tool. Such as loading a big model.

        Here: index ``self.files`` (if any), then mark the tool initialized.
        """
        if self.files:
            # Keyword argument: the second positional parameter of index() is `contextualize`.
            _ = self.index(self.files, api_key=self.api_key)
        self.is_initialized = True

    def _extract_contexts(self, images, api_key, window=10) -> List[Optional[str]]:
        """Ask the LLM for a context summary of consecutive page windows.

        ``images`` is split into consecutive, non-overlapping chunks of ``window``
        pages. Every page of a chunk receives the context generated from that
        chunk. A chunk whose request did not yield a ChatMessage gets ``None``.
        That covers both the ``str`` that query_openai returns on failure and an
        exception object handed back by pqdm. If the whole fan-out fails, every
        page gets ``None``.

        Returns:
            One context (or None) per image, in image order.
        """
        if not images:
            return []
        window = max(1, int(window))
        query = "Give the general context about these pages. Give the context in the same language as the documents."
        args = [
            {
                'query': query,
                # Chunk [start, start + window): exactly the pages this context is assigned to.
                'pages': [Page(image=im) for im in images[start:start + window]],
                'api_key': api_key,
                'system_prompt': DEFAULT_CONTEXT_PROMPT,
            }
            for start in range(0, len(images), window)
        ]
        try:
            window_results = pqdm(args, query_openai, n_jobs=8, argument_type='kwargs')
            # Results without `.content` (failure strings, returned exceptions) become None
            # for their chunk only, instead of discarding every chunk's context.
            window_contexts = [getattr(res, "content", None) for res in window_results]
            contexts = [window_contexts[i // window] for i in range(len(images))]
        except Exception as e:
            print(f"Error extracting contexts: {e}")
            contexts = [None for _ in range(len(images))]
        # Ensure that the number of contexts is equal to the number of images
        assert len(contexts) == len(images)
        return contexts

    def _process_file(self, file: str, contextualize: bool = True, api_key: str = None, window: int = 10) -> List[Page]:
        """Render ``file`` to page images and attach per-page metadata.

        Context is extracted only when ``contextualize`` is true and an API key is given.
        """
        # basename also handles OS-specific separators and path-like objects.
        title = os.path.basename(file)
        images = convert_from_path(file, thread_count=4)
        if contextualize and api_key:
            contexts = self._extract_contexts(images, api_key, window=window)
        else:
            contexts = [None for _ in range(len(images))]
        # NOTE(review): page_id is 0-based; readers of citations may expect 1-based page numbers.
        return [
            Page(image=img, metadata=Metadata(doc_title=title, page_id=i, context=ctx))
            for i, (img, ctx) in enumerate(zip(images, contexts))
        ]

    def preprocess(self, files: List[str], contextualize: bool = True, api_key: str = None, window: int = 10) -> List[Page]:
        """Convert every file to pages (with metadata), concatenated in file order."""
        pages = [
            page
            for file in files
            for page in self._process_file(file, contextualize=contextualize, api_key=api_key, window=window)
        ]
        if pages:
            print(f"Example metadata:\n{pages[0].metadata.context}")
        return pages

    def _embed_images(self, pages: List[Page]) -> List[torch.Tensor]:
        """Embed page images with the ColQwen2 model; returns one CPU tensor per page, in order."""
        dataloader = DataLoader(
            pages,
            batch_size=4,
            shuffle=False,
            collate_fn=lambda batch: self.processor.process_images([p.image for p in batch]),
        )
        embds = []
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {k: v.to(self.device) for k, v in batch_doc.items()}
                embeddings_doc = self.model(**batch_doc)
                embds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
        return embds

    def index(self, files: List[str], contextualize: bool = True, api_key: str = None, overwrite_db: bool = False) -> int:
        """Convert, embed and add ``files`` to the index.

        Args:
            files: PDF paths.
            contextualize: Whether to extract LLM context for each page.
            api_key: OpenAI key; defaults to ``self.api_key``.
            overwrite_db: Clear previously indexed pages first.

        Returns:
            Number of pages embedded from ``files``.
        """
        print("Converting files...")
        pgs = self.preprocess(files, contextualize=contextualize, api_key=api_key or self.api_key)
        embds = self._embed_images(pgs)
        if overwrite_db:
            self.pages = []
            self.embds = []
        self.pages.extend(pgs)
        self.embds.extend(embds)
        print(f"Extracted and indexed {len(pgs)} images from {len(files)} files.")
        return len(embds)

    def retrieve(self, query: str, k: int) -> List[Page]:
        """Return the top ``k`` indexed pages for ``query`` (fewer if the index is smaller).

        ``k`` may be a float (the tool input type is "number") or None (the input
        is nullable); None means 1. Returns an empty list when nothing is indexed.
        """
        if not self.embds:
            print("No documents indexed.")
            return []
        k = 1 if k is None else int(k)
        k = max(0, min(k, len(self.embds)))
        with torch.no_grad():
            batch_query = self.processor.process_queries([query]).to(self.model.device)
            embeddings_query = self.model(**batch_query)
            qs = list(torch.unbind(embeddings_query.to("cpu")))
        # Run scoring
        scores = self.processor.score(qs, self.embds, device=self.device)[0]
        top_k_idx = scores.topk(k).indices.tolist()
        print("Top Scores:")
        for idx in top_k_idx:
            print(f'Page {self.pages[idx].metadata.page_id}: {scores[idx]}')
        return [self.pages[idx] for idx in top_k_idx]

    def generate_answer(self, query: str, docs: List[Page], api_key: str = None) -> "ChatMessage | str":
        """Answer ``query`` from ``docs`` with RAG_SYSTEM_PROMPT (str on failure, see query_openai)."""
        result = query_openai(query, docs, api_key or self.api_key, system_prompt=RAG_SYSTEM_PROMPT)
        return result

    def search(self, query: str, k: int = 1, api_key: str = None) -> Tuple[list, str]:
        """Retrieve the top ``k`` pages and generate an answer; returns (pages, answer text)."""
        print(f"Searching for query: {query}")
        context = self.retrieve(query, k)
        rag_answer = self.generate_answer(
            query=query,
            docs=context,
            api_key=api_key
        )
        # query_openai returns a plain str (no `.content`) when the key is missing or the call fails.
        answer = rag_answer if isinstance(rag_answer, str) else rag_answer.content
        return context, answer

    def forward(self, query: str, k: int = 1, api_key: str = None) -> str:
        """Tool entry point: return the generated answer text for ``query``."""
        assert isinstance(query, str), "Your search query must be a string"
        _, rag_answer = self.search(
            query=query,
            k=k,
            api_key=api_key
        )
        return rag_answer