import os
from typing import Any, Dict, Optional

from dotenv import load_dotenv
from PIL import Image
from smolagents import Tool

load_dotenv()

class VisualRAGTool(Tool):
    name = "visual_rag"
    description = """Performs a RAG query on your internal PDF documents and returns the generated text response."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents.",
        },
        "k": {
            "type": "number",
            "description": "The number of documents to retrieve.",
            "default": 1,
            "nullable": True,
        },
        "api_key": {
            "type": "string",
            "description": "The OpenAI API key to use for the query. If not provided, the key will be taken from the OPENAI_KEY environment variable.",
            "nullable": True,
        }
    }
    output_type = "string"

    model_name: str = "vidore/colqwen2-v1.0"
    api_key: str = os.getenv("OPENAI_KEY")
        
    class Page:
        """A single PDF page: its rendered image plus optional metadata (document title, page id, context)."""

        image: Image.Image
        metadata: Optional[Dict[str, Any]] = None

        def __init__(self, image, metadata=None):
            self.image = image
            self.metadata = metadata

        @property
        def caption(self):
            if self.metadata is None:
                return None
            return f"Document: {self.metadata.get('doc_title')}, Context: {self.metadata.get('context')}"
        
        def __hash__(self):
            return hash(self.image)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_initialized = False

    def _init_models(self, model_name: str) -> None:
        import torch
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = ColQwen2.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                attn_implementation="flash_attention_2"
            ).eval()
        self.processor = ColQwen2Processor.from_pretrained(model_name)
    
    def setup(self):
        """
        Runs the expensive one-time initialization: loads the ColQwen2 model and processor
        and creates the empty in-memory page and embedding stores.
        """
        self._init_models(self.model_name)

        self.embds = []
        self.pages = []

        self.is_initialized = True

    def _encode_image_to_base64(self, image):
        """Encodes a PIL image to a base64 string."""
        from io import BytesIO
        import base64

        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    
    def _build_query(self, query: str, pages: list) -> list:
        """Builds the query for OpenAI based on the pages and the query."""
        messages = []
        messages.append({"type": "text", "text": "PDF pages:\n"})
        for page in pages:
            capt = page.caption
            if capt is not None:
                messages.append({
                        "type": "text",
                        "text": capt
                    })
            messages.append({
                    "type": "image_url",
                    "image_url": {
                    "url": f"data:image/jpeg;base64,{self._encode_image_to_base64(page.image)}"
                    },
                })
        messages.append({"type": "text", "text": f"Query:\n{query}"})

        return messages

    def query_openai(self, query, pages, api_key=None, system_prompt=None, model="gpt-4o-mini"):
        """Calls OpenAI's GPT-4o-mini with the query and image data."""
        from smolagents import ChatMessage

        system_prompt = system_prompt or \
        """You are a smart assistant designed to answer questions about a PDF document.
            You are given relevant information in the form of PDF pages preceded by their metadata: document title, page identifier, surrounding context. 
            Use them to construct a short response to the question, and cite your sources in the following format: (document, page number).
            If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
            Give detailed and extensive answers, only containing info in the pages you are given.
            You can answer using information contained in plots and figures if necessary.
            Answer in the same language as the query."""
        
        api_key = api_key or self.api_key

        if api_key and api_key.startswith("sk"):
            try:
                from openai import OpenAI
        
                client = OpenAI(api_key=api_key.strip())

                response = client.chat.completions.create(
                    model=model,
                    messages=[
                        {
                            "role": "system",
                            "content": system_prompt
                        },
                        {
                        "role": "user",
                        "content": self._build_query(query, pages)
                        }
                    ],
                    max_tokens=500,
                )

                message = ChatMessage.from_dict(
                    response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
                )
                message.raw = response

                return message
            
            except Exception as e:
                # Surface the underlying error to make key / connection issues easier to diagnose
                return f"OpenAI API connection failure ({e}). Verify the provided key is correct (sk-***)."
            
        return "Enter your OpenAI API key to get a custom response"

    def _extract_contexts(self, images, api_key, window=10) -> list:
        """Extracts context from images."""
        from pqdm.threads import pqdm

        CONTEXT_SYSTEM_PROMPT = \
        """You are a smart assistant designed to extract context of PDF pages.
            Give concise answers, only containing info in the pages you are given.
            You can answer using information contained in plots and figures if necessary."""

        try:
            args = [
                {
                    'query': "Give the general context about these pages. Give the context in the same language as the documents.",
                    'pages': [self.Page(image=im) for im in images[max(i-window+1, 0):i+1]],
                    'api_key': api_key,
                    'system_prompt': CONTEXT_SYSTEM_PROMPT,
                } for i in range(0, len(images), window)
            ]
            window_contexts = pqdm(args, self.query_openai, n_jobs=8, argument_type='kwargs')

            # Sequential alternative (kept for reference), using tqdm:
            # query = "Give the general context about these pages. Give the context in the same language as the documents."
            # window_contexts = [query_openai(query, [Page(image=im) for im in images[max(i-window+1, 0):i+1]], api_key, DEFAULT_CONTEXT_PROMPT)\
            #                      for i in tqdm(range(0, len(images), window))]

            contexts = []
            for i in range(len(images)):
                context = window_contexts[i//window].content
                contexts.append(context)

        except Exception as e:
            print(f"Error extracting contexts: {e}")
            contexts = [None for _ in range(len(images))]

        # Ensure that the number of contexts is equal to the number of images
        assert len(contexts) == len(images)
        
        return contexts

    def _preprocess_file(self, file: str, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
        """Converts a file to images and extracts metadata."""
        from pdf2image import convert_from_path

        title = file.split("/")[-1]
        images = convert_from_path(file, thread_count=4)
        if contextualize and api_key:
            contexts = self._extract_contexts(images, api_key, window=window)
        else:
            contexts = [None for _ in range(len(images))]
        metadatas = [{'doc_title': title, 'page_id': i, 'context': contexts[i]} for i in range(len(images))]

        return [self.Page(image=img, metadata=metadata) for img, metadata in zip(images, metadatas)]

    def preprocess(self, files: list, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
        """Preprocesses the files and extracts metadata."""
        pages = [page for file in files for page in self._preprocess_file(file, contextualize=contextualize, api_key=api_key, window=window)]

        print(f"Example page context:\n{pages[0].metadata.get('context')}")

        return pages
    
    def compute_embeddings(self, pages) -> list:
        """Embeds the page images with the ColQwen2 model (ColPali-style late-interaction inference)."""
        import torch
        from torch.utils.data import DataLoader
        from tqdm import tqdm

        # run inference - docs
        dataloader = DataLoader(
            pages,
            batch_size=4,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_images([p.image for p in x]).to(self.device),
        )

        embds = []

        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {k: v.to(self.device) for k, v in batch_doc.items()}
                embeddings_doc = self.model(**batch_doc)
            embds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

        return embds

    def index(self, files: list, contextualize: bool = True, api_key: str = None, overwrite_db: bool = False) -> int:
        """Indexes the uploaded files."""
        if not self.is_initialized:
            self.setup()
        
        print("Converting files...")
        # Convert files to images and extract metadata
        pgs = self.preprocess(files, contextualize=contextualize, api_key=api_key or self.api_key)

        # Embed the images
        embds = self.compute_embeddings(pgs)

        # Overwrite the database if necessary
        if overwrite_db:
            self.pages = []
            self.embds = []
        
        # Extend the pages
        self.pages.extend(pgs)

        # Extend the datasets
        self.embds.extend(embds)

        print(f"Extracted and indexed {len(pgs)} images from {len(files)} files.")
        
        return len(embds)

    def retrieve(self, query: str, k: int) -> list:
        """Retrieve the top k documents based on the query."""
        import torch
        k = min(k, len(self.embds))

        qs = []
        with torch.no_grad():
            batch_query = self.processor.process_queries([query]).to(self.model.device)
            embeddings_query = self.model(**batch_query)
            qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

        # Run scoring
        scores = self.processor.score(qs, self.embds, device=self.device)[0]
        top_k_idx = scores.topk(k).indices.tolist()

        print("Top Scores:")
        [print(f"Page {self.pages[idx].metadata.get('page_id')}: {scores[idx]}") for idx in top_k_idx]

        # Get the top k results
        results = [self.pages[idx] for idx in top_k_idx]

        return results
        
    def generate_answer(self, query: str, docs: list, api_key: str = None):
        """Generates an answer based on the query and the retrieved documents."""

        RAG_SYSTEM_PROMPT = \
        """ You are a smart assistant designed to answer questions about a PDF document.

            You are given relevant information in the form of PDF pages preceded by their metadata: document title, page identifier, surrounding context.
            Use them to construct a response to the question, and cite your sources.
            Use the following citation format:
            "Some information from a first document [1, p.Page Number]. Some information from the same first document but at a different page [1, p.Page Number]. Some more information from another document [2, p.Page Number].
            ...
            Sources:
            [1] Document Title
            [2] Another Document Title"

            You can answer using information contained in plots and figures if necessary.
            If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
            Give detailed answers, only containing info in the pages you are given.
            Answer in the same language as the query."""
        
        result = self.query_openai(query, docs, api_key or self.api_key, system_prompt=RAG_SYSTEM_PROMPT)
        return result

    def search(self, query: str, k: int = 1, api_key: str = None) -> tuple:
        """Searches for the most relevant pages based on the query."""
        print(f"Searching for query: {query}")
    
        # Retrieve the top k documents
        context = self.retrieve(query, k)

        # Generate a response with the OpenAI model (gpt-4o-mini by default)
        rag_answer = self.generate_answer(
            query=query,
            docs=context,
            api_key=api_key
        )

        # query_openai falls back to a plain error string when no valid API key is
        # available, so only read `.content` when a ChatMessage was returned.
        answer = rag_answer.content if hasattr(rag_answer, "content") else rag_answer

        return context, answer

    def forward(self, query: str, k: int = 1, api_key: str = None) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        # Online indexing
        # if files:
        #     _ = self.index(files, api_key)

        # Retrieve the top k documents and generate response
        return self.search(
            query=query, 
            k=k, 
            api_key=api_key
        )[1]
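
# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming OPENAI_KEY is set in the environment (or a key
# is passed to index()/forward()). The PDF path and query below are hypothetical
# placeholders; replace them with your own documents and question.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tool = VisualRAGTool()

    # Index local PDFs: converts pages to images, optionally extracts per-page
    # context with OpenAI, and embeds every page with ColQwen2.
    tool.index(files=["docs/example_report.pdf"], contextualize=True)

    # Retrieve the top-k pages for a question and return the generated answer.
    print(tool.forward(query="What is the main conclusion of the report?", k=3))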