import os

from smolagents import Tool
from dotenv import load_dotenv

load_dotenv()
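
# NOTE: load_dotenv() reads a local .env file (if present), which is how
# OPENAI_KEY can be supplied without exporting it in the shell.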


class VisualRAGTool(Tool):
    name = "visual_rag"
    description = """Performs a RAG query on your internal PDF documents and returns the generated text response."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents.",
        },
        "k": {
            "type": "number",
            "description": "The number of documents to retrieve.",
            "default": 1,
            "nullable": True,
        },
        "api_key": {
            "type": "string",
            "description": "The OpenAI API key to use for the query. If not provided, the key will be taken from the OPENAI_KEY environment variable.",
            "nullable": True,
        },
    }
    output_type = "string"

    model_name: str = "vidore/colqwen2-v1.0"
    api_key: str = os.getenv("OPENAI_KEY")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_initialized = False

    def _init_models(self, model_name: str) -> None:
        import torch
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = ColQwen2.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="flash_attention_2",
        ).eval()
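        # NOTE: attn_implementation="flash_attention_2" requires the
        # flash-attn package and a CUDA GPU; on CPU-only hosts, omitting the
        # argument (or using attn_implementation="eager") is a fallback.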
        self.processor = ColQwen2Processor.from_pretrained(model_name)

    def setup(self):
        """
        Performs any expensive one-time initialization, such as loading the
        embedding model, before the tool is first used.
        """
        self._init_models(self.model_name)

        self.embds = []
        self.pages = []

        self.is_initialized = True
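
    # smolagents calls setup() lazily on the first tool invocation (guarded by
    # is_initialized), so the heavyweight model load is deferred until needed.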

    def _extract_contexts(self, images, api_key, window=10) -> list:
        """Extracts context from images."""
        from pqdm.processes import pqdm

        from utils import CONTEXT_SYSTEM_PROMPT, Page, query_openai

        try:
            args = [
                {
                    "query": "Give the general context about these pages. Give the context in the same language as the documents.",
                    "pages": [Page(image=im) for im in images[i:i + window]],
                    "api_key": api_key,
                    "system_prompt": CONTEXT_SYSTEM_PROMPT,
                }
                for i in range(0, len(images), window)
            ]
            window_contexts = pqdm(args, query_openai, n_jobs=8, argument_type="kwargs")
            # Sequential fallback (kept for reference): issue the same queries
            # one window at a time with tqdm instead of pqdm.
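
            # Window layout example: with 25 pages and window=10, three window
            # summaries are generated (pages 0-9, 10-19, 20-24) and every page
            # inherits the context produced for its window below.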
            contexts = []
            for i in range(len(images)):
                context = window_contexts[i // window].content
                contexts.append(context)
        except Exception as e:
            print(f"Error extracting contexts: {e}")
            contexts = [None for _ in range(len(images))]

        # Ensure that the number of contexts is equal to the number of images
        assert len(contexts) == len(images)
        return contexts

    def _preprocess_file(self, file: str, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
        """Converts a file to images and extracts metadata."""
        from pdf2image import convert_from_path

        from utils import Metadata, Page

        title = file.split("/")[-1]
        images = convert_from_path(file, thread_count=4)

        if contextualize and api_key:
            contexts = self._extract_contexts(images, api_key, window=window)
        else:
            contexts = [None for _ in range(len(images))]

        metadatas = [Metadata(doc_title=title, page_id=i, context=contexts[i]) for i in range(len(images))]

        return [Page(image=img, metadata=metadata) for img, metadata in zip(images, metadatas)]

    def preprocess(self, files: list, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
        """Preprocesses the files and extracts metadata."""
        pages = [
            page
            for file in files
            for page in self._preprocess_file(file, contextualize=contextualize, api_key=api_key, window=window)
        ]
        if pages:
            print(f"Example metadata:\n{pages[0].metadata.context}")
        return pages

    def compute_embeddings(self, pages) -> list:
        """Embeds the pages' images with the ColQwen2 model."""
        import torch
        from torch.utils.data import DataLoader
        from tqdm import tqdm

        # Batch the page images through the processor; the collate_fn moves
        # each processed batch to the target device.
        dataloader = DataLoader(
            pages,
            batch_size=4,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_images([p.image for p in x]).to(self.device),
        )

        embds = []
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {k: v.to(self.device) for k, v in batch_doc.items()}
                embeddings_doc = self.model(**batch_doc)
            embds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
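
        # Each element of embds is a multi-vector page embedding (one vector
        # per image patch), scored later against queries with late interaction
        # (MaxSim) rather than a single pooled vector.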
        return embds

    def index(self, files: list, contextualize: bool = True, api_key: str = None, overwrite_db: bool = False) -> int:
        """Indexes the uploaded files."""
        if not self.is_initialized:
            self.setup()

        print("Converting files...")
        # Convert the files to images and extract their metadata
        pgs = self.preprocess(files, contextualize=contextualize, api_key=api_key or self.api_key)

        # Embed the images
        embds = self.compute_embeddings(pgs)

        # Overwrite the in-memory database if requested
        if overwrite_db:
            self.pages = []
            self.embds = []

        # Extend the stored pages and embeddings
        self.pages.extend(pgs)
        self.embds.extend(embds)

        print(f"Extracted and indexed {len(pgs)} images from {len(files)} files.")

        return len(embds)

    def retrieve(self, query: str, k: int) -> list:
        """Retrieve the top k documents based on the query."""
        import torch

        k = min(k, len(self.embds))

        qs = []
        with torch.no_grad():
            batch_query = self.processor.process_queries([query]).to(self.model.device)
            embeddings_query = self.model(**batch_query)
            qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
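
        # processor.score computes late-interaction (MaxSim) similarity
        # between the multi-vector query embedding and every stored page
        # embedding.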
        scores = self.processor.score(qs, self.embds, device=self.device)[0]
        top_k_idx = scores.topk(k).indices.tolist()

        print("Top Scores:")
        for idx in top_k_idx:
            print(f"Page {self.pages[idx].metadata.page_id}: {scores[idx]}")

        # Return the top k pages
        return [self.pages[idx] for idx in top_k_idx]

    def generate_answer(self, query: str, docs: list, api_key: str = None):
        """Generates an answer based on the query and the retrieved documents."""
        from utils import RAG_SYSTEM_PROMPT, query_openai

        result = query_openai(query, docs, api_key or self.api_key, system_prompt=RAG_SYSTEM_PROMPT)
        return result

    def search(self, query: str, k: int = 1, api_key: str = None) -> tuple:
        """Searches for the most relevant pages based on the query."""
        print(f"Searching for query: {query}")

        # Retrieve the top k documents
        context = self.retrieve(query, k)

        # Generate a response from GPT-4o-mini
        rag_answer = self.generate_answer(
            query=query,
            docs=context,
            api_key=api_key,
        )

        return context, rag_answer.content

    def forward(self, query: str, k: int = 1, api_key: str = None) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        # Online indexing (currently disabled): a `files` argument could
        # trigger `self.index(files, api_key)` here before searching.

        # Retrieve the top k documents and generate the response
        return self.search(
            query=query,
            k=k,
            api_key=api_key,
        )[1]
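

# Minimal usage sketch. The PDF path and the query are placeholders, and an
# OpenAI key is assumed to be available via OPENAI_KEY in the environment
# (or passed explicitly as api_key).
if __name__ == "__main__":
    tool = VisualRAGTool()
    # Index one or more local PDFs; page contexts are generated with OpenAI
    # when an API key is available.
    tool.index(files=["docs/report.pdf"], contextualize=True)
    # Retrieve the top 3 pages and generate a grounded answer.
    print(tool.forward("What does the report say about Q3 revenue?", k=3))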