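"""Visual RAG tool for smolagents.

Indexes PDF files as page images with ColQwen2 (a ColPali-style late-interaction
retriever) and answers queries by passing the retrieved pages to an OpenAI model.
"""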
import os

from smolagents import Tool

from dotenv import load_dotenv
load_dotenv()

class VisualRAGTool(Tool):
    name = "visual_rag"
    description = """Performs a RAG query on your internal PDF documents and returns the generated text response."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents.",
        },
        "k": {
            "type": "number",
            "description": "The number of documents to retrieve.",
            "default": 1,
            "nullable": True,
        },
        "api_key": {
            "type": "string",
            "description": "The OpenAI API key to use for the query. If not provided, the key will be taken from the OPENAI_KEY environment variable.",
            "nullable": True,
        }
    }
    output_type = "string"

    model_name: str = "vidore/colqwen2-v1.0"
    # Resolved once at import time; may be None if OPENAI_KEY is not set.
    api_key: str = os.getenv("OPENAI_KEY")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_initialized = False

    def _init_models(self, model_name: str) -> None:
        import torch
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Note: flash_attention_2 requires a compatible GPU and the flash-attn package.
        self.model = ColQwen2.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="flash_attention_2",
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(model_name)
    
    def setup(self):
        """
        One-time, expensive initialization: loads the ColQwen2 model and processor
        and resets the in-memory page and embedding stores.
        """
        self._init_models(self.model_name)

        self.embds = []
        self.pages = []

        self.is_initialized = True

    def _extract_contexts(self, images, api_key, window=10) -> list:
        """Extracts a shared context string for each window of pages via the OpenAI API."""
        from utils import query_openai, Page, CONTEXT_SYSTEM_PROMPT
        from pqdm.processes import pqdm
        try:
            args = [
                {
                    'query': "Give the general context about these pages. Give the context in the same language as the documents.",
                    'pages': [Page(image=im) for im in images[max(i-window+1, 0):i+1]],
                    'api_key': api_key,
                    'system_prompt': CONTEXT_SYSTEM_PROMPT,
                } for i in range(0, len(images), window)
            ]
            window_contexts = pqdm(args, query_openai, n_jobs=8, argument_type='kwargs')

            # Sequential fallback (kept for reference):
            # query = "Give the general context about these pages. Give the context in the same language as the documents."
            # window_contexts = [query_openai(query, [Page(image=im) for im in images[max(i-window+1, 0):i+1]], api_key, CONTEXT_SYSTEM_PROMPT)
            #                    for i in tqdm(range(0, len(images), window))]

            contexts = []
            for i in range(len(images)):
                context = window_contexts[i//window].content
                contexts.append(context)

        except Exception as e:
            print(f"Error extracting contexts: {e}")
            contexts = [None for _ in range(len(images))]

        # Ensure that the number of contexts is equal to the number of images
        assert len(contexts) == len(images)
        
        return contexts

    def _preprocess_file(self, file: str, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
        """Converts a file to images and extracts metadata."""
        from pdf2image import convert_from_path
        from utils import Metadata, Page

        title = os.path.basename(file)
        images = convert_from_path(file, thread_count=4)
        if contextualize and api_key:
            contexts = self._extract_contexts(images, api_key, window=window)
        else:
            contexts = [None for _ in range(len(images))]
        metadatas = [Metadata(doc_title=title, page_id=i, context=contexts[i]) for i in range(len(images))]

        return [Page(image=img, metadata=metadata) for img, metadata in zip(images, metadatas)]

    def preprocess(self, files: list, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
        """Preprocesses the files and extracts metadata."""
        pages = [page for file in files for page in self._preprocess_file(file, contextualize=contextualize, api_key=api_key, window=window)]

        if pages:
            print(f"Example page context:\n{pages[0].metadata.context}")

        return pages
    
    def compute_embeddings(self, pages) -> list:
        """Embeds the page images with the model (ColPali / ColQwen2 inference)."""
        import torch
        from torch.utils.data import DataLoader
        from tqdm import tqdm

        # Run inference over the document pages in batches
        dataloader = DataLoader(
            pages,
            batch_size=4,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_images([p.image for p in x]).to(self.device),
        )

        embds = []

        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {k: v.to(self.device) for k, v in batch_doc.items()}
                embeddings_doc = self.model(**batch_doc)
            embds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

        return embds

    def index(self, files: list, contextualize: bool = True, api_key: str = None, overwrite_db: bool = False) -> int:
        """Indexes the uploaded files."""
        if not self.is_initialized:
            self.setup()
        
        print("Converting files...")
        # Convert files to images and extract metadata
        pgs = self.preprocess(files, contextualize=contextualize, api_key=api_key or self.api_key)

        # Embed the images
        embds = self.compute_embeddings(pgs)

        # Overwrite the database if necessary
        if overwrite_db:
            self.pages = []
            self.embds = []
        
        # Extend the pages
        self.pages.extend(pgs)

        # Extend the datasets
        self.embds.extend(embds)

        print(f"Extracted and indexed {len(pgs)} images from {len(files)} files.")
        
        return len(embds)

    def retrieve(self, query: str, k: int) -> list:
        """Retrieve the top k documents based on the query."""
        import torch
        # Clamp k to the index size; coerce to int since the tool input type is "number".
        k = min(int(k), len(self.embds))

        qs = []
        with torch.no_grad():
            batch_query = self.processor.process_queries([query]).to(self.model.device)
            embeddings_query = self.model(**batch_query)
            qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

        # Run scoring
        scores = self.processor.score(qs, self.embds, device=self.device)[0]
        top_k_idx = scores.topk(k).indices.tolist()

        print("Top Scores:")
        [print(f'Page {self.pages[idx].metadata.page_id}: {scores[idx]}') for idx in top_k_idx]

        # Get the top k results
        results = [self.pages[idx] for idx in top_k_idx]

        return results
        
    def generate_answer(self, query: str, docs: list, api_key: str = None):
        """Generates an answer based on the query and the retrieved documents."""
        from utils import query_openai, RAG_SYSTEM_PROMPT
        result = query_openai(query, docs, api_key or self.api_key, system_prompt=RAG_SYSTEM_PROMPT)
        return result

    def search(self, query: str, k: int = 1, api_key: str = None) -> tuple:
        """Searches for the most relevant pages based on the query."""
        print(f"Searching for query: {query}")
    
        # Retrieve the top k documents
        context = self.retrieve(query, k)

        # Generate response from GPT-4o-mini
        rag_answer = self.generate_answer(
            query=query, 
            docs=context, 
            api_key=api_key
        )

        return context, rag_answer.content

    def forward(self, query: str, k: int = 1, api_key: str = None) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        # Online indexing (disabled: `files` is not part of the tool inputs):
        # if files:
        #     _ = self.index(files, api_key=api_key)

        # Retrieve the top k documents and generate response
        return self.search(
            query=query, 
            k=k, 
            api_key=api_key
        )[1]
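

if __name__ == "__main__":
    # Minimal usage sketch, not part of the tool itself. Assumptions: a local
    # "report.pdf" (hypothetical file name) exists, and OPENAI_KEY is set in
    # the environment or a .env file, since contextualization and answer
    # generation both call the OpenAI API.
    tool = VisualRAGTool()

    # One-off indexing: converts PDF pages to images, optionally adds
    # window-level context, and embeds everything with ColQwen2.
    tool.index(files=["report.pdf"], contextualize=True)

    # Retrieval + generation; returns the generated answer as a string.
    answer = tool.forward("What are the key findings of the report?", k=3)
    print(answer)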