paultltc committed on
Commit 70f7106 · 1 Parent(s): 7a252b5

refactor to follow tool validation

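The commit message refers to smolagents' tool validation; the diff below moves helper code into a new utils.py and defers heavy imports into the methods that need them. For reference, a minimal sketch of a Tool subclass that passes validation, assuming the standard smolagents interface (name, description, inputs, output_type, forward); the tool shown here is purely illustrative and is not this repository's code.

from smolagents import Tool

class EchoTool(Tool):
    # Class attributes that smolagents checks when the tool is instantiated.
    name = "echo"
    description = "Returns the input text unchanged."
    inputs = {"text": {"type": "string", "description": "Text to echo back."}}
    output_type = "string"

    def forward(self, text: str) -> str:
        # The forward signature must match the declared inputs.
        return text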
.gitignore ADDED
@@ -0,0 +1 @@
+.venv
__pycache__/tool.cpython-312.pyc ADDED
Binary file (12 kB).
 
__pycache__/utils.cpython-312.pyc ADDED
Binary file (6.34 kB).
 
tool.py CHANGED
@@ -1,141 +1,10 @@
 import os
 
-from dataclasses import dataclass
-from typing import List, Optional, Tuple
-
-import torch
-from torch.utils.data import DataLoader, Dataset
-
-import base64
-from io import BytesIO
-from PIL import Image
-from pdf2image import convert_from_path
-
-from tqdm import tqdm
-from pqdm.processes import pqdm
-
-from colpali_engine.models import ColQwen2, ColQwen2Processor
-
-from smolagents import Tool, ChatMessage
+from smolagents import Tool
 
 from dotenv import load_dotenv
 load_dotenv()
 
-def encode_image_to_base64(image):
-    """Encodes a PIL image to a base64 string."""
-    buffered = BytesIO()
-    image.save(buffered, format="JPEG")
-    return base64.b64encode(buffered.getvalue()).decode("utf-8")
-
-DEFAULT_SYSTEM_PROMPT = \
-"""You are a smart assistant designed to answer questions about a PDF document.
-You are given relevant information in the form of PDF pages preceded by their metadata: document title, page identifier, surrounding context.
-Use them to construct a short response to the question, and cite your sources in the following format: (document, page number).
-If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
-Give detailed and extensive answers, only containing info in the pages you are given.
-You can answer using information contained in plots and figures if necessary.
-Answer in the same language as the query."""
-
-def _build_query(query, pages):
-    messages = []
-    messages.append({"type": "text", "text": "PDF pages:\n"})
-    for page in pages:
-        capt = page.caption
-        if capt is not None:
-            messages.append({
-                "type": "text",
-                "text": capt
-            })
-        messages.append({
-            "type": "image_url",
-            "image_url": {
-                "url": f"data:image/jpeg;base64,{encode_image_to_base64(page.image)}"
-            },
-        })
-    messages.append({"type": "text", "text": f"Query:\n{query}"})
-
-    return messages
-
-def query_openai(query, pages, api_key=None, system_prompt=DEFAULT_SYSTEM_PROMPT, model="gpt-4o-mini") -> ChatMessage:
-    """Calls OpenAI's GPT-4o-mini with the query and image data."""
-    if api_key and api_key.startswith("sk"):
-        try:
-            from openai import OpenAI
-
-            client = OpenAI(api_key=api_key.strip())
-
-            response = client.chat.completions.create(
-                model=model,
-                messages=[
-                    {
-                        "role": "system",
-                        "content": system_prompt
-                    },
-                    {
-                        "role": "user",
-                        "content": _build_query(query, pages)
-                    }
-                ],
-                max_tokens=500,
-            )
-
-            message = ChatMessage.from_dict(
-                response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
-            )
-            message.raw = response
-
-            return message
-
-        except Exception as e:
-            return "OpenAI API connection failure. Verify the provided key is correct (sk-***)."
-
-    return "Enter your OpenAI API key to get a custom response"
-
-CONTEXT_SYSTEM_PROMPT = \
-"""You are a smart assistant designed to extract context of PDF pages.
-Give concise answers, only containing info in the pages you are given.
-You can answer using information contained in plots and figures if necessary."""
-
-RAG_SYSTEM_PROMPT = \
-""" You are a smart assistant designed to answer questions about a PDF document.
-
-You are given relevant information in the form of PDF pages preceded by their metadata: document title, page identifier, surrounding context.
-Use them to construct a response to the question, and cite your sources.
-Use the following citation format:
-"Some information from a first document [1, p.Page Number]. Some information from the same first document but at a different page [1, p.Page Number]. Some more information from another document [2, p.Page Number].
-...
-Sources:
-[1] Document Title
-[2] Another Document Title"
-
-You can answer using information contained in plots and figures if necessary.
-If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
-Give detailed answers, only containing info in the pages you are given.
-Answer in the same language as the query."""
-
-@dataclass
-class Metadata:
-    doc_title: str
-    page_id: int
-    context: Optional[str] = None
-
-    def __str__(self):
-        return f"Document: {self.doc_title}, Page ID: {self.page_id}, Context: {self.context}"
-
-@dataclass
-class Page:
-    image: Image.Image
-    metadata: Optional[Metadata] = None
-
-    @property
-    def caption(self):
-        if self.metadata is None:
-            return None
-        return f"Document: {self.metadata.doc_title}, Context: {self.metadata.context}"
-
-    def __hash__(self):
-        return hash(self.image)
-
 class VisualRAGTool(Tool):
     name = "visual_rag"
     description = """Performs a RAG query on your internal PDF documents and returns the generated text response."""
@@ -161,7 +30,13 @@ class VisualRAGTool(Tool):
     model_name: str = "vidore/colqwen2-v1.0"
     api_key: str = os.getenv("OPENAI_KEY")
 
+    def __init__(self, *args, **kwargs):
+        self.is_initialized = False
+
     def _init_models(self, model_name: str) -> None:
+        import torch
+        from colpali_engine.models import ColQwen2, ColQwen2Processor
+
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model = ColQwen2.from_pretrained(
             model_name,
@@ -170,9 +45,6 @@ class VisualRAGTool(Tool):
             attn_implementation="flash_attention_2"
         ).eval()
         self.processor = ColQwen2Processor.from_pretrained(model_name)
-
-    def __init__(self, *args, **kwargs):
-        self.is_initialized = False
 
     def setup(self):
         """
@@ -186,8 +58,10 @@ class VisualRAGTool(Tool):
 
         self.is_initialized = True
 
-    def _extract_contexts(self, images, api_key, window=10) -> List[str]:
+    def _extract_contexts(self, images, api_key, window=10) -> list:
         """Extracts context from images."""
+        from utils import query_openai, Page, CONTEXT_SYSTEM_PROMPT
+        from pqdm.processes import pqdm
         try:
             args = [
                 {
@@ -218,8 +92,11 @@ class VisualRAGTool(Tool):
 
         return contexts
 
-    def _preprocess_file(self, file: str, contextualize: bool = True, api_key: str = None, window: int = 10) -> List[Page]:
+    def _preprocess_file(self, file: str, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
         """Converts a file to images and extracts metadata."""
+        from pdf2image import convert_from_path
+        from utils import Metadata, Page
+
         title = file.split("/")[-1]
         images = convert_from_path(file, thread_count=4)
         if contextualize and api_key:
@@ -230,7 +107,7 @@ class VisualRAGTool(Tool):
 
         return [Page(image=img, metadata=metadata) for img, metadata in zip(images, metadatas)]
 
-    def preprocess(self, files: List[str], contextualize: bool = True, api_key: str = None, window: int = 10) -> List[Page]:
+    def preprocess(self, files: list, contextualize: bool = True, api_key: str = None, window: int = 10) -> list:
         """Preprocesses the files and extracts metadata."""
         pages = [page for file in files for page in self._preprocess_file(file, contextualize=contextualize, api_key=api_key, window=window)]
 
@@ -238,9 +115,13 @@ class VisualRAGTool(Tool):
 
         return pages
 
-    def compute_embeddings(self, pages: List[Page]) -> List[torch.Tensor]:
+    def compute_embeddings(self, pages) -> list:
         """Embeds the images using the model."""
        """Example script to run inference with ColPali (ColQwen2)"""
+        import torch
+        from torch.utils.data import DataLoader
+        from tqdm import tqdm
+
         # run inference - docs
         dataloader = DataLoader(
             pages,
@@ -259,7 +140,8 @@ class VisualRAGTool(Tool):
 
         return embds
 
-    def index(self, files: List[str], contextualize: bool = True, api_key: str = None, overwrite_db: bool = False) -> int:
+    def index(self, files: list, contextualize: bool = True, api_key: str = None, overwrite_db: bool = False) -> int:
+        """Indexes the uploaded files."""
         if not self.is_initialized:
             self.setup()
 
@@ -285,8 +167,9 @@ class VisualRAGTool(Tool):
 
         return len(embds)
 
-    def retrieve(self, query: str, k: int) -> List[Page]:
+    def retrieve(self, query: str, k: int) -> list:
         """Retrieve the top k documents based on the query."""
+        import torch
         k = min(k, len(self.embds))
 
         qs = []
@@ -307,11 +190,14 @@ class VisualRAGTool(Tool):
 
         return results
 
-    def generate_answer(self, query: str, docs: List[Page], api_key: str = None) -> ChatMessage:
+    def generate_answer(self, query: str, docs: list, api_key: str = None):
+        """Generates an answer based on the query and the retrieved documents."""
+        from utils import query_openai, RAG_SYSTEM_PROMPT
         result = query_openai(query, docs, api_key or self.api_key, system_prompt=RAG_SYSTEM_PROMPT)
         return result
 
-    def search(self, query: str, k: int = 1, api_key: str = None) -> Tuple[list, str]:
+    def search(self, query: str, k: int = 1, api_key: str = None) -> tuple:
+        """Searches for the most relevant pages based on the query."""
         print(f"Searching for query: {query}")
 
         # Retrieve the top k documents
@@ -334,11 +220,9 @@ class VisualRAGTool(Tool):
         # _ = self.index(files, api_key)
 
         # Retrieve the top k documents and generate response
-        _, rag_answer = self.search(
+        return self.search(
             query=query,
             files=None,
             k=k,
             api_key=api_key
-        )
-
-        return rag_answer
+        )[1]
 
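A usage sketch of the refactored tool, based only on the methods visible in the diff above (index and search) and on the OPENAI_KEY environment variable it reads; the file paths and the query are placeholders, and the real class may accept further arguments not shown in these hunks.

import os
from tool import VisualRAGTool

tool = VisualRAGTool()

# Index a few PDFs; contextualization only runs when an OpenAI key is provided.
n_pages = tool.index(
    files=["docs/report.pdf", "docs/appendix.pdf"],
    contextualize=True,
    api_key=os.getenv("OPENAI_KEY"),
)
print(f"Indexed {n_pages} pages")

# Retrieve the top-k pages and generate an answer; search returns a tuple
# whose second element is the generated response.
pages, answer = tool.search("What does the revenue chart show?", k=3)
print(answer.content if hasattr(answer, "content") else answer)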
utils.py ADDED
@@ -0,0 +1,124 @@
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import base64
+from io import BytesIO
+from PIL import Image
+
+
+from smolagents import ChatMessage
+
+def encode_image_to_base64(image):
+    """Encodes a PIL image to a base64 string."""
+    buffered = BytesIO()
+    image.save(buffered, format="JPEG")
+    return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+DEFAULT_SYSTEM_PROMPT = \
+"""You are a smart assistant designed to answer questions about a PDF document.
+You are given relevant information in the form of PDF pages preceded by their metadata: document title, page identifier, surrounding context.
+Use them to construct a short response to the question, and cite your sources in the following format: (document, page number).
+If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
+Give detailed and extensive answers, only containing info in the pages you are given.
+You can answer using information contained in plots and figures if necessary.
+Answer in the same language as the query."""
+
+def _build_query(query, pages):
+    messages = []
+    messages.append({"type": "text", "text": "PDF pages:\n"})
+    for page in pages:
+        capt = page.caption
+        if capt is not None:
+            messages.append({
+                "type": "text",
+                "text": capt
+            })
+        messages.append({
+            "type": "image_url",
+            "image_url": {
+                "url": f"data:image/jpeg;base64,{encode_image_to_base64(page.image)}"
+            },
+        })
+    messages.append({"type": "text", "text": f"Query:\n{query}"})
+
+    return messages
+
+def query_openai(query, pages, api_key=None, system_prompt=DEFAULT_SYSTEM_PROMPT, model="gpt-4o-mini") -> ChatMessage:
+    """Calls OpenAI's GPT-4o-mini with the query and image data."""
+    if api_key and api_key.startswith("sk"):
+        try:
+            from openai import OpenAI
+
+            client = OpenAI(api_key=api_key.strip())
+
+            response = client.chat.completions.create(
+                model=model,
+                messages=[
+                    {
+                        "role": "system",
+                        "content": system_prompt
+                    },
+                    {
+                        "role": "user",
+                        "content": _build_query(query, pages)
+                    }
+                ],
+                max_tokens=500,
+            )
+
+            message = ChatMessage.from_dict(
+                response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
+            )
+            message.raw = response
+
+            return message
+
+        except Exception as e:
+            return "OpenAI API connection failure. Verify the provided key is correct (sk-***)."
+
+    return "Enter your OpenAI API key to get a custom response"
+
+CONTEXT_SYSTEM_PROMPT = \
+"""You are a smart assistant designed to extract context of PDF pages.
+Give concise answers, only containing info in the pages you are given.
+You can answer using information contained in plots and figures if necessary."""
+
+RAG_SYSTEM_PROMPT = \
+""" You are a smart assistant designed to answer questions about a PDF document.
+
+You are given relevant information in the form of PDF pages preceded by their metadata: document title, page identifier, surrounding context.
+Use them to construct a response to the question, and cite your sources.
+Use the following citation format:
+"Some information from a first document [1, p.Page Number]. Some information from the same first document but at a different page [1, p.Page Number]. Some more information from another document [2, p.Page Number].
+...
+Sources:
+[1] Document Title
+[2] Another Document Title"
+
+You can answer using information contained in plots and figures if necessary.
+If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
+Give detailed answers, only containing info in the pages you are given.
+Answer in the same language as the query."""
+
+@dataclass
+class Metadata:
+    doc_title: str
+    page_id: int
+    context: Optional[str] = None
+
+    def __str__(self):
+        return f"Document: {self.doc_title}, Page ID: {self.page_id}, Context: {self.context}"
+
+@dataclass
+class Page:
+    image: Image.Image
+    metadata: Optional[Metadata] = None
+
+    @property
+    def caption(self):
+        if self.metadata is None:
+            return None
+        return f"Document: {self.metadata.doc_title}, Context: {self.metadata.context}"
+
+    def __hash__(self):
+        return hash(self.image)
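
The helpers that moved into utils.py can also be exercised on their own. A small sketch, assuming a local PDF rendered with pdf2image and an OpenAI key; the path and document title are placeholders.

import os
from pdf2image import convert_from_path
from utils import Metadata, Page, query_openai

pdf_path = "docs/report.pdf"  # placeholder path
images = convert_from_path(pdf_path, thread_count=4)

# Wrap each rendered page with metadata so that page.caption is populated.
pages = [
    Page(image=img, metadata=Metadata(doc_title="report.pdf", page_id=i))
    for i, img in enumerate(images)
]

# query_openai returns a ChatMessage on success, or an explanatory string
# when the key is missing or the API call fails.
answer = query_openai("Summarize the first page.", pages[:1], api_key=os.getenv("OPENAI_KEY"))
print(answer.content if hasattr(answer, "content") else answer)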