Luca Foppiano committed
Commit 5e4cf95 · unverified · 2 Parent(s): b5dfde0 727fc17

Merge pull request #30 from lfoppiano/question-coefficient

.github/workflows/ci-build.yml CHANGED
@@ -26,7 +26,7 @@ jobs:
26
  - name: Install dependencies
27
  run: |
28
  python -m pip install --upgrade pip
29
- pip install --upgrade flake8 pytest pycodestyle
30
  if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
31
  - name: Lint with flake8
32
  run: |
 
26
  - name: Install dependencies
27
  run: |
28
  python -m pip install --upgrade pip
29
+ pip install --upgrade flake8 pytest pycodestyle pytest-cov
30
  if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
31
  - name: Lint with flake8
32
  run: |
Dockerfile CHANGED
@@ -15,8 +15,6 @@ RUN pip3 install -r requirements.txt
15
 
16
  COPY .streamlit ./.streamlit
17
  COPY document_qa ./document_qa
18
- COPY grobid_client_generic.py .
19
- COPY client.py .
20
  COPY streamlit_app.py .
21
 
22
  # extract version
 
15
 
16
  COPY .streamlit ./.streamlit
17
  COPY document_qa ./document_qa
 
 
18
  COPY streamlit_app.py .
19
 
20
  # extract version
document_qa/document_qa_engine.py CHANGED
@@ -1,23 +1,31 @@
1
  import copy
2
  import os
3
  from pathlib import Path
4
- from typing import Union, Any
5
 
6
  import tiktoken
7
- from grobid_client.grobid_client import GrobidClient
8
  from langchain.chains import create_extraction_chain
9
  from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
10
  map_rerank_prompt
 
11
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
12
  from langchain.retrievers import MultiQueryRetriever
13
  from langchain.schema import Document
14
- from langchain.vectorstores import Chroma
 
15
  from tqdm import tqdm
16
 
 
17
  from document_qa.grobid_processors import GrobidProcessor
 
18
 
19
 
20
  class TextMerger:
 
 
 
 
 
21
  def __init__(self, model_name=None, encoding_name="gpt2"):
22
  if model_name is not None:
23
  self.enc = tiktoken.encoding_for_model(model_name)
@@ -86,57 +94,56 @@ class TextMerger:
86
  return new_passages_struct
87
 
88
 
89
- class DocumentQAEngine:
90
- llm = None
91
- qa_chain_type = None
92
- embedding_function = None
93
  embeddings_dict = {}
94
  embeddings_map_from_md5 = {}
95
  embeddings_map_to_md5 = {}
96
 
97
- default_prompts = {
98
- 'stuff': stuff_prompt,
99
- 'refine': refine_prompts,
100
- "map_reduce": map_reduce_prompt,
101
- "map_rerank": map_rerank_prompt
102
- }
103
-
104
- def __init__(self,
105
- llm,
106
- embedding_function,
107
- qa_chain_type="stuff",
108
- embeddings_root_path=None,
109
- grobid_url=None,
110
- memory=None
111
- ):
112
  self.embedding_function = embedding_function
113
- self.llm = llm
114
- self.memory = memory
115
- self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
116
- self.text_merger = TextMerger()
117
 
118
- if embeddings_root_path is not None:
119
- self.embeddings_root_path = embeddings_root_path
120
- if not os.path.exists(embeddings_root_path):
121
- os.makedirs(embeddings_root_path)
122
  else:
123
  self.load_embeddings(self.embeddings_root_path)
124
 
125
- if grobid_url:
126
- self.grobid_url = grobid_url
127
- grobid_client = GrobidClient(
128
- grobid_server=self.grobid_url,
129
- batch_size=1000,
130
- coordinates=["p", "title", "persName"],
131
- sleep_time=5,
132
- timeout=60,
133
- check_server=True
134
- )
135
- self.grobid_processor = GrobidProcessor(grobid_client)
136
-
137
  def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None:
138
  """
139
- Load the embeddings assuming they are all persisted and stored in a single directory.
140
  The root path of the embeddings containing one data store for each document in each subdirectory
141
  """
142
 
@@ -147,8 +154,10 @@ class DocumentQAEngine:
147
  return
148
 
149
  for embedding_document_dir in embeddings_directories:
150
- self.embeddings_dict[embedding_document_dir.name] = Chroma(persist_directory=embedding_document_dir.path,
151
- embedding_function=self.embedding_function)
 
 
152
 
153
  filename_list = list(Path(embedding_document_dir).glob('*.storage_filename'))
154
  if filename_list:
@@ -167,9 +176,60 @@ class DocumentQAEngine:
167
  def get_filename_from_md5(self, md5):
168
  return self.embeddings_map_from_md5[md5]
169
 
170
- def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
171
- verbose=False) -> (
172
- Any, str):
173
  # self.load_embeddings(self.embeddings_root_path)
174
 
175
  if verbose:
@@ -198,11 +258,52 @@ class DocumentQAEngine:
198
  else:
199
  return None, response, coordinates
200
 
201
- def query_storage(self, query: str, doc_id, context_size=4):
202
- documents = self._get_context(doc_id, query, context_size)
 
 
 
203
 
204
  context_as_text = [doc.page_content for doc in documents]
205
- return context_as_text
206
 
207
  def _parse_json(self, response, output_parser):
208
  system_message = "You are an useful assistant expert in materials science, physics, and chemistry " \
@@ -225,11 +326,8 @@ class DocumentQAEngine:
225
 
226
  return parsed_output
227
 
228
- def _run_query(self, doc_id, query, context_size=4):
229
- relevant_documents = self._get_context(doc_id, query, context_size)
230
- relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
231
- for doc in
232
- relevant_documents] # filter(lambda d: d['type'] == "sentence", relevant_documents)]
233
  response = self.chain.run(input_documents=relevant_documents,
234
  question=query)
235
 
@@ -237,39 +335,46 @@ class DocumentQAEngine:
237
  self.memory.save_context({"input": query}, {"output": response})
238
  return response, relevant_document_coordinates
239
 
240
- def _get_context(self, doc_id, query, context_size=4):
241
- db = self.embeddings_dict[doc_id]
242
  retriever = db.as_retriever(search_kwargs={"k": context_size})
243
  relevant_documents = retriever.get_relevant_documents(query)
 
 
 
244
  if self.memory and len(self.memory.buffer_as_messages) > 0:
245
  relevant_documents.append(
246
  Document(
247
  page_content="""Following, the previous question and answers. Use these information only when in the question there are unspecified references:\n{}\n\n""".format(
248
  self.memory.buffer_as_str))
249
  )
250
- return relevant_documents
251
 
252
- def get_all_context_by_document(self, doc_id):
253
- """Return the full context from the document"""
254
- db = self.embeddings_dict[doc_id]
 
 
255
  docs = db.get()
256
  return docs['documents']
257
 
258
  def _get_context_multiquery(self, doc_id, query, context_size=4):
259
- db = self.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size})
260
  multi_query_retriever = MultiQueryRetriever.from_llm(retriever=db, llm=self.llm)
261
  relevant_documents = multi_query_retriever.get_relevant_documents(query)
262
  return relevant_documents
263
 
264
  def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
265
  """
266
- Extract text from documents using Grobid, if chunk_size is < 0 it keeps each paragraph separately
 
 
267
  """
268
  if verbose:
269
  print("File", pdf_file_path)
270
  filename = Path(pdf_file_path).stem
271
  coordinates = True # if chunk_size == -1 else False
272
- structure = self.grobid_processor.process(pdf_file_path, coordinates=coordinates)
273
 
274
  biblio = structure['biblio']
275
  biblio['filename'] = filename.replace(" ", "_")
@@ -303,7 +408,13 @@ class DocumentQAEngine:
303
 
304
  return texts, metadatas, ids
305
 
306
- def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1):
 
 
 
 
 
 
307
  texts, metadata, ids = self.get_text_from_document(
308
  pdf_path,
309
  chunk_size=chunk_size,
@@ -313,25 +424,17 @@ class DocumentQAEngine:
313
  else:
314
  hash = metadata[0]['hash']
315
 
316
- if hash not in self.embeddings_dict.keys():
317
- self.embeddings_dict[hash] = Chroma.from_texts(texts,
318
- embedding=self.embedding_function,
319
- metadatas=metadata,
320
- collection_name=hash)
321
- else:
322
- # if 'documents' in self.embeddings_dict[hash].get() and len(self.embeddings_dict[hash].get()['documents']) == 0:
323
- # self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids'])
324
- self.embeddings_dict[hash].delete_collection()
325
- self.embeddings_dict[hash] = Chroma.from_texts(texts,
326
- embedding=self.embedding_function,
327
- metadatas=metadata,
328
- collection_name=hash)
329
-
330
- self.embeddings_root_path = None
331
 
332
  return hash
333
 
334
- def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1, include_biblio=False):
 
 
 
 
 
 
335
  input_files = []
336
  for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
337
  for file_ in files:
@@ -343,17 +446,16 @@ class DocumentQAEngine:
343
  desc="Grobid + embeddings processing"):
344
 
345
  md5 = self.calculate_md5(input_file)
346
- data_path = os.path.join(self.embeddings_root_path, md5)
347
 
348
  if os.path.exists(data_path):
349
  print(data_path, "exists. Skipping it ")
350
  continue
351
- include = ["biblio"] if include_biblio else []
352
  texts, metadata, ids = self.get_text_from_document(
353
  input_file,
354
  chunk_size=chunk_size,
355
- perc_overlap=perc_overlap,
356
- include=include)
357
  filename = metadata[0]['filename']
358
 
359
  vector_db_document = Chroma.from_texts(texts,
 
1
  import copy
2
  import os
3
  from pathlib import Path
4
+ from typing import Union, Any, List
5
 
6
  import tiktoken
 
7
  from langchain.chains import create_extraction_chain
8
  from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
9
  map_rerank_prompt
10
+ from langchain.evaluation import PairwiseEmbeddingDistanceEvalChain, load_evaluator, EmbeddingDistance
11
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
12
  from langchain.retrievers import MultiQueryRetriever
13
  from langchain.schema import Document
14
+ from langchain_community.vectorstores.chroma import Chroma
15
+ from langchain_core.vectorstores import VectorStore
16
  from tqdm import tqdm
17
 
18
+ # from document_qa.embedding_visualiser import QueryVisualiser
19
  from document_qa.grobid_processors import GrobidProcessor
20
+ from document_qa.langchain import ChromaAdvancedRetrieval
21
 
22
 
23
  class TextMerger:
24
+ """
25
+ This class tries to replicate the RecursiveTextSplitter from LangChain, while preserving and merging the
26
+ coordinate information from the PDF document.
27
+ """
28
+
29
  def __init__(self, model_name=None, encoding_name="gpt2"):
30
  if model_name is not None:
31
  self.enc = tiktoken.encoding_for_model(model_name)
 
94
  return new_passages_struct
95
 
96
 
97
+ class BaseRetrieval:
98
+
99
+ def __init__(
100
+ self,
101
+ persist_directory: Path,
102
+ embedding_function
103
+ ):
104
+ self.embedding_function = embedding_function
105
+ self.persist_directory = persist_directory
106
+
107
+
108
+ class NER_Retrival(VectorStore):
109
+ """
110
+ This class implements a retrieval based on NER models.
111
+ This is an alternative retrieval to embeddings that relies on extracted entities.
112
+ """
113
+ pass
114
+
115
+
116
+ engines = {
117
+ 'chroma': ChromaAdvancedRetrieval,
118
+ 'ner': NER_Retrival
119
+ }
120
+
121
+
122
+ class DataStorage:
123
  embeddings_dict = {}
124
  embeddings_map_from_md5 = {}
125
  embeddings_map_to_md5 = {}
126
 
127
+ def __init__(
128
+ self,
129
+ embedding_function,
130
+ root_path: Path = None,
131
+ engine=ChromaAdvancedRetrieval,
132
+ ) -> None:
133
+ self.root_path = root_path
134
+ self.engine = engine
 
 
 
 
 
 
 
135
  self.embedding_function = embedding_function
 
 
 
 
136
 
137
+ if root_path is not None:
138
+ self.embeddings_root_path = root_path
139
+ if not os.path.exists(root_path):
140
+ os.makedirs(root_path)
141
  else:
142
  self.load_embeddings(self.embeddings_root_path)
143
 
144
  def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None:
145
  """
146
+ Load the vector storage assuming they are all persisted and stored in a single directory.
147
  The root path of the embeddings containing one data store for each document in each subdirectory
148
  """
149
 
 
154
  return
155
 
156
  for embedding_document_dir in embeddings_directories:
157
+ self.embeddings_dict[embedding_document_dir.name] = self.engine(
158
+ persist_directory=embedding_document_dir.path,
159
+ embedding_function=self.embedding_function
160
+ )
161
 
162
  filename_list = list(Path(embedding_document_dir).glob('*.storage_filename'))
163
  if filename_list:
 
176
  def get_filename_from_md5(self, md5):
177
  return self.embeddings_map_from_md5[md5]
178
 
179
+ def embed_document(self, doc_id, texts, metadatas):
180
+ if doc_id not in self.embeddings_dict.keys():
181
+ self.embeddings_dict[doc_id] = self.engine.from_texts(texts,
182
+ embedding=self.embedding_function,
183
+ metadatas=metadatas,
184
+ collection_name=doc_id)
185
+ else:
186
+ # Workaround Chroma (?) breaking change
187
+ self.embeddings_dict[doc_id].delete_collection()
188
+ self.embeddings_dict[doc_id] = self.engine.from_texts(texts,
189
+ embedding=self.embedding_function,
190
+ metadatas=metadatas,
191
+ collection_name=doc_id)
192
+
193
+ self.embeddings_root_path = None
194
+
195
+
196
+ class DocumentQAEngine:
197
+ llm = None
198
+ qa_chain_type = None
199
+
200
+ default_prompts = {
201
+ 'stuff': stuff_prompt,
202
+ 'refine': refine_prompts,
203
+ "map_reduce": map_reduce_prompt,
204
+ "map_rerank": map_rerank_prompt
205
+ }
206
+
207
+ def __init__(self,
208
+ llm,
209
+ data_storage: DataStorage,
210
+ qa_chain_type="stuff",
211
+ grobid_url=None,
212
+ memory=None
213
+ ):
214
+
215
+ self.llm = llm
216
+ self.memory = memory
217
+ self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
218
+ self.text_merger = TextMerger()
219
+ self.data_storage = data_storage
220
+
221
+ if grobid_url:
222
+ self.grobid_processor = GrobidProcessor(grobid_url)
223
+
224
+ def query_document(
225
+ self,
226
+ query: str,
227
+ doc_id,
228
+ output_parser=None,
229
+ context_size=4,
230
+ extraction_schema=None,
231
+ verbose=False
232
+ ) -> (Any, str):
233
  # self.load_embeddings(self.embeddings_root_path)
234
 
235
  if verbose:
 
258
  else:
259
  return None, response, coordinates
260
 
261
+ def query_storage(self, query: str, doc_id, context_size=4) -> (List[Document], list):
262
+ """
263
+ Returns the context related to a given query
264
+ """
265
+ documents, coordinates = self._get_context(doc_id, query, context_size)
266
 
267
  context_as_text = [doc.page_content for doc in documents]
268
+ return context_as_text, coordinates
269
+
270
+ def query_storage_and_embeddings(self, query: str, doc_id, context_size=4) -> List[Document]:
271
+ """
272
+ Returns both the context and the embedding information from a given query
273
+ """
274
+ db = self.data_storage.embeddings_dict[doc_id]
275
+ retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings")
276
+ relevant_documents = retriever.get_relevant_documents(query)
277
+
278
+ return relevant_documents
279
+
280
+ def analyse_query(self, query, doc_id, context_size=4):
281
+ db = self.data_storage.embeddings_dict[doc_id]
282
+ # retriever = db.as_retriever(
283
+ # search_kwargs={"k": context_size, 'score_threshold': 0.0},
284
+ # search_type="similarity_score_threshold"
285
+ # )
286
+ retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings")
287
+ relevant_documents = retriever.get_relevant_documents(query)
288
+ relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
289
+ for doc in
290
+ relevant_documents]
291
+ all_documents = db.get(include=['documents', 'metadatas', 'embeddings'])
292
+ # all_documents_embeddings = all_documents["embeddings"]
293
+ # query_embedding = db._embedding_function.embed_query(query)
294
+
295
+ # distance_evaluator = load_evaluator("pairwise_embedding_distance",
296
+ # embeddings=db._embedding_function,
297
+ # distance_metric=EmbeddingDistance.EUCLIDEAN)
298
+
299
+ # distance_evaluator.evaluate_string_pairs(query=query_embedding, documents="")
300
+
301
+ similarities = [doc.metadata['__similarity'] for doc in relevant_documents]
302
+ min_similarity = min(similarities)
303
+ mean_similarity = sum(similarities) / len(similarities)
304
+ coefficient = min_similarity - mean_similarity
305
+
306
+ return f"Coefficient: {coefficient}, (Min similarity {min_similarity}, Mean similarity: {mean_similarity})", relevant_document_coordinates
307
 
308
  def _parse_json(self, response, output_parser):
309
  system_message = "You are an useful assistant expert in materials science, physics, and chemistry " \
 
326
 
327
  return parsed_output
328
 
329
+ def _run_query(self, doc_id, query, context_size=4) -> (List[Document], list):
330
+ relevant_documents, relevant_document_coordinates = self._get_context(doc_id, query, context_size)
 
 
 
331
  response = self.chain.run(input_documents=relevant_documents,
332
  question=query)
333
 
 
335
  self.memory.save_context({"input": query}, {"output": response})
336
  return response, relevant_document_coordinates
337
 
338
+ def _get_context(self, doc_id, query, context_size=4) -> (List[Document], list):
339
+ db = self.data_storage.embeddings_dict[doc_id]
340
  retriever = db.as_retriever(search_kwargs={"k": context_size})
341
  relevant_documents = retriever.get_relevant_documents(query)
342
+ relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
343
+ for doc in
344
+ relevant_documents]
345
  if self.memory and len(self.memory.buffer_as_messages) > 0:
346
  relevant_documents.append(
347
  Document(
348
  page_content="""Following, the previous question and answers. Use these information only when in the question there are unspecified references:\n{}\n\n""".format(
349
  self.memory.buffer_as_str))
350
  )
351
+ return relevant_documents, relevant_document_coordinates
352
 
353
+ def get_full_context_by_document(self, doc_id):
354
+ """
355
+ Return the full context from the document
356
+ """
357
+ db = self.data_storage.embeddings_dict[doc_id]
358
  docs = db.get()
359
  return docs['documents']
360
 
361
  def _get_context_multiquery(self, doc_id, query, context_size=4):
362
+ db = self.data_storage.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size})
363
  multi_query_retriever = MultiQueryRetriever.from_llm(retriever=db, llm=self.llm)
364
  relevant_documents = multi_query_retriever.get_relevant_documents(query)
365
  return relevant_documents
366
 
367
  def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
368
  """
369
+ Extract text from documents using Grobid.
370
+ - if chunk_size is < 0, keeps each paragraph separately
371
+ - if chunk_size > 0, aggregate all paragraphs and split them again using an approximate chunk size
372
  """
373
  if verbose:
374
  print("File", pdf_file_path)
375
  filename = Path(pdf_file_path).stem
376
  coordinates = True # if chunk_size == -1 else False
377
+ structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates)
378
 
379
  biblio = structure['biblio']
380
  biblio['filename'] = filename.replace(" ", "_")
 
408
 
409
  return texts, metadatas, ids
410
 
411
+ def create_memory_embeddings(
412
+ self,
413
+ pdf_path,
414
+ doc_id=None,
415
+ chunk_size=500,
416
+ perc_overlap=0.1
417
+ ):
418
  texts, metadata, ids = self.get_text_from_document(
419
  pdf_path,
420
  chunk_size=chunk_size,
 
424
  else:
425
  hash = metadata[0]['hash']
426
 
427
+ self.data_storage.embed_document(hash, texts, metadata)
 
428
 
429
  return hash
430
 
431
+ def create_embeddings(
432
+ self,
433
+ pdfs_dir_path: Path,
434
+ chunk_size=500,
435
+ perc_overlap=0.1,
436
+ include_biblio=False
437
+ ):
438
  input_files = []
439
  for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
440
  for file_ in files:
 
446
  desc="Grobid + embeddings processing"):
447
 
448
  md5 = self.calculate_md5(input_file)
449
+ data_path = os.path.join(self.data_storage.embeddings_root_path, md5)
450
 
451
  if os.path.exists(data_path):
452
  print(data_path, "exists. Skipping it ")
453
  continue
454
+ # include = ["biblio"] if include_biblio else []
455
  texts, metadata, ids = self.get_text_from_document(
456
  input_file,
457
  chunk_size=chunk_size,
458
+ perc_overlap=perc_overlap)
 
459
  filename = metadata[0]['filename']
460
 
461
  vector_db_document = Chroma.from_texts(texts,
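
Taken together, the changes to this file split persistence from querying: DataStorage now owns the per-document vector stores behind a pluggable engine, and DocumentQAEngine receives it as a constructor argument instead of building Chroma stores itself. A minimal sketch of the new wiring, assuming the imports used elsewhere in this repository and a running GROBID server (the model names, URL, and PDF path are placeholders):

# Hedged sketch of the wiring after this refactoring; placeholders throughout.
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings

from document_qa.document_qa_engine import DataStorage, DocumentQAEngine

embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
storage = DataStorage(embedding_function=embeddings)  # engine defaults to ChromaAdvancedRetrieval
engine = DocumentQAEngine(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
    data_storage=storage,
    grobid_url="http://localhost:8070",  # placeholder GROBID endpoint
)

doc_id = engine.create_memory_embeddings("paper.pdf", chunk_size=500, perc_overlap=0.1)

# query_storage now returns the PDF coordinates alongside the text chunks
context, coordinates = engine.query_storage("What is the critical temperature?", doc_id)

# analyse_query reports the new "question coefficient" over the retrieved chunks
report, _ = engine.analyse_query("What is the critical temperature?", doc_id)
print(report)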
document_qa/grobid_processors.py CHANGED
@@ -2,12 +2,11 @@ import re
2
  from collections import OrderedDict
3
  from html import escape
4
  from pathlib import Path
5
- from typing_extensions import deprecated
6
 
7
  import dateparser
8
  import grobid_tei_xml
9
  from bs4 import BeautifulSoup
10
- from tqdm import tqdm
11
 
12
 
13
  def get_span_start(type, title=None):
@@ -55,51 +54,6 @@ def decorate_text_with_annotations(text, spans, tag="span"):
55
  return annotated_text
56
 
57
 
58
- @deprecated("Use GrobidQuantitiesProcessor.process() instead")
59
- def extract_quantities(client, x_all, column_text_index):
60
- # relevant_items = ['magnetic field strength', 'magnetic induction', 'maximum energy product',
61
- # "magnetic flux density", "magnetic flux"]
62
- # property_keywords = ['coercivity', 'remanence']
63
-
64
- output_data = []
65
-
66
- for idx, example in tqdm(enumerate(x_all), desc="extract quantities"):
67
- text = example[column_text_index]
68
- spans = GrobidQuantitiesProcessor(client).process(text)
69
-
70
- data_record = {
71
- "id": example[0],
72
- "filename": example[1],
73
- "passage_id": example[2],
74
- "text": text,
75
- "spans": spans
76
- }
77
-
78
- output_data.append(data_record)
79
-
80
- return output_data
81
-
82
-
83
- @deprecated("Use GrobidMaterialsProcessor.process() instead")
84
- def extract_materials(client, x_all, column_text_index):
85
- output_data = []
86
-
87
- for idx, example in tqdm(enumerate(x_all), desc="extract materials"):
88
- text = example[column_text_index]
89
- spans = GrobidMaterialsProcessor(client).process(text)
90
- data_record = {
91
- "id": example[0],
92
- "filename": example[1],
93
- "passage_id": example[2],
94
- "text": text,
95
- "spans": spans
96
- }
97
-
98
- output_data.append(data_record)
99
-
100
- return output_data
101
-
102
-
103
  def get_parsed_value_type(quantity):
104
  if 'parsedValue' in quantity and 'structure' in quantity['parsedValue']:
105
  return quantity['parsedValue']['structure']['type']
@@ -130,11 +84,19 @@ class BaseProcessor(object):
130
 
131
 
132
  class GrobidProcessor(BaseProcessor):
133
- def __init__(self, grobid_client):
134
  # super().__init__()
135
  self.grobid_client = grobid_client
136
 
137
- def process(self, input_path, coordinates=False):
138
  pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument",
139
  input_path,
140
  consolidate_header=True,
@@ -153,6 +115,15 @@ class GrobidProcessor(BaseProcessor):
153
 
154
  return document_object
155
 
156
  def parse_grobid_xml(self, text, coordinates=False):
157
  output_data = OrderedDict()
158
 
@@ -212,6 +183,7 @@ class GrobidProcessor(BaseProcessor):
212
  })
213
 
214
  text_blocks_body = get_xml_nodes_body(soup, verbose=False, use_paragraphs=True)
 
215
 
216
  use_paragraphs = True
217
  if not use_paragraphs:
@@ -287,7 +259,7 @@ class GrobidQuantitiesProcessor(BaseProcessor):
287
  def __init__(self, grobid_quantities_client):
288
  self.grobid_quantities_client = grobid_quantities_client
289
 
290
- def process(self, text):
291
  status, result = self.grobid_quantities_client.process_text(text.strip())
292
 
293
  if status != 200:
@@ -555,11 +527,12 @@ class GrobidMaterialsProcessor(BaseProcessor):
555
  return materials
556
 
557
 
558
- class GrobidAggregationProcessor(GrobidProcessor, GrobidQuantitiesProcessor, GrobidMaterialsProcessor):
559
- def __init__(self, grobid_client, grobid_quantities_client=None, grobid_superconductors_client=None):
560
- GrobidProcessor.__init__(self, grobid_client)
561
- self.gqp = GrobidQuantitiesProcessor(grobid_quantities_client)
562
- self.gmp = GrobidMaterialsProcessor(grobid_superconductors_client)
 
563
 
564
  def process_single_text(self, text):
565
  extracted_quantities_spans = self.process_properties(text)
@@ -569,10 +542,16 @@ class GrobidAggregationProcessor(GrobidProcessor, GrobidQuantitiesProcessor, Gro
569
  return entities
570
 
571
  def process_properties(self, text):
572
- return self.gqp.process(text)
 
 
 
573
 
574
  def process_materials(self, text):
575
- return self.gmp.process(text)
 
 
 
576
 
577
  @staticmethod
578
  def box_to_dict(box, color=None, type=None):
@@ -724,11 +703,11 @@ class XmlProcessor(BaseProcessor):
724
 
725
  # def process_single(self, input_file):
726
  # doc = self.process_structure(input_file)
727
- #
728
  # for paragraph in doc['passages']:
729
  # entities = self.process_single_text(paragraph['text'])
730
  # paragraph['spans'] = entities
731
- #
732
  # return doc
733
 
734
  def process(self, text):
@@ -822,6 +801,20 @@ def get_xml_nodes_body(soup: object, use_paragraphs: bool = True, verbose: bool
822
  return nodes
823
 
824
825
  def get_xml_nodes_figures(soup: object, verbose: bool = False) -> list:
826
  children = []
827
  for child in soup.TEI.children:
 
2
  from collections import OrderedDict
3
  from html import escape
4
  from pathlib import Path
 
5
 
6
  import dateparser
7
  import grobid_tei_xml
8
  from bs4 import BeautifulSoup
9
+ from grobid_client.grobid_client import GrobidClient
10
 
11
 
12
  def get_span_start(type, title=None):
 
54
  return annotated_text
55
 
56
 
57
  def get_parsed_value_type(quantity):
58
  if 'parsedValue' in quantity and 'structure' in quantity['parsedValue']:
59
  return quantity['parsedValue']['structure']['type']
 
84
 
85
 
86
  class GrobidProcessor(BaseProcessor):
87
+ def __init__(self, grobid_url, ping_server=True):
88
  # super().__init__()
89
+ grobid_client = GrobidClient(
90
+ grobid_server=grobid_url,
91
+ batch_size=5,
92
+ coordinates=["p", "title", "persName"],
93
+ sleep_time=5,
94
+ timeout=60,
95
+ check_server=ping_server
96
+ )
97
  self.grobid_client = grobid_client
98
 
99
+ def process_structure(self, input_path, coordinates=False):
100
  pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument",
101
  input_path,
102
  consolidate_header=True,
 
115
 
116
  return document_object
117
 
118
+ def process_single(self, input_file):
119
+ doc = self.process_structure(input_file)
120
+
121
+ for paragraph in doc['passages']:
122
+ entities = self.process_single_text(paragraph['text'])
123
+ paragraph['spans'] = entities
124
+
125
+ return doc
126
+
127
  def parse_grobid_xml(self, text, coordinates=False):
128
  output_data = OrderedDict()
129
 
 
183
  })
184
 
185
  text_blocks_body = get_xml_nodes_body(soup, verbose=False, use_paragraphs=True)
186
+ text_blocks_body.extend(get_xml_nodes_back(soup, verbose=False, use_paragraphs=True))
187
 
188
  use_paragraphs = True
189
  if not use_paragraphs:
 
259
  def __init__(self, grobid_quantities_client):
260
  self.grobid_quantities_client = grobid_quantities_client
261
 
262
+ def process(self, text) -> list:
263
  status, result = self.grobid_quantities_client.process_text(text.strip())
264
 
265
  if status != 200:
 
527
  return materials
528
 
529
 
530
+ class GrobidAggregationProcessor(GrobidQuantitiesProcessor, GrobidMaterialsProcessor):
531
+ def __init__(self, grobid_quantities_client=None, grobid_superconductors_client=None):
532
+ if grobid_quantities_client:
533
+ self.gqp = GrobidQuantitiesProcessor(grobid_quantities_client)
534
+ if grobid_superconductors_client:
535
+ self.gmp = GrobidMaterialsProcessor(grobid_superconductors_client)
536
 
537
  def process_single_text(self, text):
538
  extracted_quantities_spans = self.process_properties(text)
 
542
  return entities
543
 
544
  def process_properties(self, text):
545
+ if self.gqp:
546
+ return self.gqp.process(text)
547
+ else:
548
+ return []
549
 
550
  def process_materials(self, text):
551
+ if self.gmp:
552
+ return self.gmp.process(text)
553
+ else:
554
+ return []
555
 
556
  @staticmethod
557
  def box_to_dict(box, color=None, type=None):
 
703
 
704
  # def process_single(self, input_file):
705
  # doc = self.process_structure(input_file)
706
+ #
707
  # for paragraph in doc['passages']:
708
  # entities = self.process_single_text(paragraph['text'])
709
  # paragraph['spans'] = entities
710
+ #
711
  # return doc
712
 
713
  def process(self, text):
 
801
  return nodes
802
 
803
 
804
+ def get_xml_nodes_back(soup: object, use_paragraphs: bool = True, verbose: bool = False) -> list:
805
+ nodes = []
806
+ tag_name = "p" if use_paragraphs else "s"
807
+ for child in soup.TEI.children:
808
+ if child.name == 'text':
809
+ nodes.extend(
810
+ [subsubchild for subchild in child.find_all("back") for subsubchild in subchild.find_all(tag_name)])
811
+
812
+ if verbose:
813
+ print(str(nodes))
814
+
815
+ return nodes
816
+
817
+
818
  def get_xml_nodes_figures(soup: object, verbose: bool = False) -> list:
819
  children = []
820
  for child in soup.TEI.children:
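
With this change GrobidProcessor owns its GrobidClient, so callers pass only the server URL, and parse_grobid_xml now also harvests the <back> section of the TEI through the new get_xml_nodes_back. A minimal usage sketch, assuming a GROBID server at the placeholder URL:

# Hedged sketch; the URL and PDF path are placeholders.
from document_qa.grobid_processors import GrobidProcessor

processor = GrobidProcessor("http://localhost:8070", ping_server=False)
structure = processor.process_structure("paper.pdf", coordinates=True)

biblio = structure['biblio']               # header metadata (title, authors, ...)
for passage in structure['passages'][:3]:  # body paragraphs, now including back matter
    print(passage['text'][:80])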
document_qa/langchain.py ADDED
@@ -0,0 +1,141 @@
1
+ from pathlib import Path
2
+ from typing import Any, Optional, List, Dict, Tuple, ClassVar, Collection
3
+
4
+ from langchain.schema import Document
5
+ from langchain_community.vectorstores.chroma import Chroma, DEFAULT_K
6
+ from langchain_core.callbacks import CallbackManagerForRetrieverRun
7
+ from langchain_core.utils import xor_args
8
+ from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
9
+
10
+
11
+ class AdvancedVectorStoreRetriever(VectorStoreRetriever):
12
+ allowed_search_types: ClassVar[Collection[str]] = (
13
+ "similarity",
14
+ "similarity_score_threshold",
15
+ "mmr",
16
+ "similarity_with_embeddings"
17
+ )
18
+
19
+ def _get_relevant_documents(
20
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
21
+ ) -> List[Document]:
22
+
23
+ if self.search_type == "similarity_with_embeddings":
24
+ docs_scores_and_embeddings = (
25
+ self.vectorstore.advanced_similarity_search(
26
+ query, **self.search_kwargs
27
+ )
28
+ )
29
+
30
+ for doc, score, embeddings in docs_scores_and_embeddings:
31
+ if '__embeddings' not in doc.metadata.keys():
32
+ doc.metadata['__embeddings'] = embeddings
33
+ if '__similarity' not in doc.metadata.keys():
34
+ doc.metadata['__similarity'] = score
35
+
36
+ docs = [doc for doc, _, _ in docs_scores_and_embeddings]
37
+ elif self.search_type == "similarity_score_threshold":
38
+ docs_and_similarities = (
39
+ self.vectorstore.similarity_search_with_relevance_scores(
40
+ query, **self.search_kwargs
41
+ )
42
+ )
43
+ for doc, similarity in docs_and_similarities:
44
+ if '__similarity' not in doc.metadata.keys():
45
+ doc.metadata['__similarity'] = similarity
46
+
47
+ docs = [doc for doc, _ in docs_and_similarities]
48
+ else:
49
+ docs = super()._get_relevant_documents(query, run_manager=run_manager)
50
+
51
+ return docs
52
+
53
+
54
+ class AdvancedVectorStore(VectorStore):
55
+ def as_retriever(self, **kwargs: Any) -> AdvancedVectorStoreRetriever:
56
+ tags = kwargs.pop("tags", None) or []
57
+ tags.extend(self._get_retriever_tags())
58
+ return AdvancedVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
59
+
60
+
61
+ class ChromaAdvancedRetrieval(Chroma, AdvancedVectorStore):
62
+ def __init__(self, **kwargs):
63
+ super().__init__(**kwargs)
64
+
65
+ @xor_args(("query_texts", "query_embeddings"))
66
+ def __query_collection(
67
+ self,
68
+ query_texts: Optional[List[str]] = None,
69
+ query_embeddings: Optional[List[List[float]]] = None,
70
+ n_results: int = 4,
71
+ where: Optional[Dict[str, str]] = None,
72
+ where_document: Optional[Dict[str, str]] = None,
73
+ **kwargs: Any,
74
+ ) -> List[Document]:
75
+ """Query the chroma collection."""
76
+ try:
77
+ import chromadb # noqa: F401
78
+ except ImportError:
79
+ raise ValueError(
80
+ "Could not import chromadb python package. "
81
+ "Please install it with `pip install chromadb`."
82
+ )
83
+ return self._collection.query(
84
+ query_texts=query_texts,
85
+ query_embeddings=query_embeddings,
86
+ n_results=n_results,
87
+ where=where,
88
+ where_document=where_document,
89
+ **kwargs,
90
+ )
91
+
92
+ def advanced_similarity_search(
93
+ self,
94
+ query: str,
95
+ k: int = DEFAULT_K,
96
+ filter: Optional[Dict[str, str]] = None,
97
+ **kwargs: Any,
98
+ ) -> [List[Document], float, List[float]]:
99
+ docs_scores_and_embeddings = self.similarity_search_with_scores_and_embeddings(query, k, filter=filter)
100
+ return docs_scores_and_embeddings
101
+
102
+ def similarity_search_with_scores_and_embeddings(
103
+ self,
104
+ query: str,
105
+ k: int = DEFAULT_K,
106
+ filter: Optional[Dict[str, str]] = None,
107
+ where_document: Optional[Dict[str, str]] = None,
108
+ **kwargs: Any,
109
+ ) -> List[Tuple[Document, float, List[float]]]:
110
+
111
+ if self._embedding_function is None:
112
+ results = self.__query_collection(
113
+ query_texts=[query],
114
+ n_results=k,
115
+ where=filter,
116
+ where_document=where_document,
117
+ include=['metadatas', 'documents', 'embeddings', 'distances']
118
+ )
119
+ else:
120
+ query_embedding = self._embedding_function.embed_query(query)
121
+ results = self.__query_collection(
122
+ query_embeddings=[query_embedding],
123
+ n_results=k,
124
+ where=filter,
125
+ where_document=where_document,
126
+ include=['metadatas', 'documents', 'embeddings', 'distances']
127
+ )
128
+
129
+ return _results_to_docs_scores_and_embeddings(results)
130
+
131
+
132
+ def _results_to_docs_scores_and_embeddings(results: Any) -> List[Tuple[Document, float, List[float]]]:
133
+ return [
134
+ (Document(page_content=result[0], metadata=result[1] or {}), result[2], result[3])
135
+ for result in zip(
136
+ results["documents"][0],
137
+ results["metadatas"][0],
138
+ results["distances"][0],
139
+ results["embeddings"][0],
140
+ )
141
+ ]
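
The new retriever piggybacks the similarity score and the raw embedding of each hit onto the Document metadata, under __similarity and __embeddings, so callers keep the standard retriever interface. A sketch of how it is consumed, with a placeholder embedding model:

# Hedged sketch, assuming the ChromaAdvancedRetrieval class added above.
from langchain.embeddings import HuggingFaceEmbeddings

from document_qa.langchain import ChromaAdvancedRetrieval

db = ChromaAdvancedRetrieval.from_texts(
    ["The sample shows superconductivity below 93 K.",
     "The film was annealed at 500 C."],
    embedding=HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
)
retriever = db.as_retriever(search_kwargs={"k": 2}, search_type="similarity_with_embeddings")
for doc in retriever.get_relevant_documents("critical temperature"):
    # attached by AdvancedVectorStoreRetriever._get_relevant_documents
    print(doc.metadata["__similarity"], len(doc.metadata["__embeddings"]))

One subtlety of the design: __query_collection is name-mangled, so the copy defined here is private to ChromaAdvancedRetrieval; inherited Chroma methods keep calling their own private version, and only the new similarity_search_with_scores_and_embeddings goes through this override.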
client.py → document_qa/ner_client_generic.py RENAMED
@@ -1,3 +1,15 @@
1
  """ Generic API Client """
2
  from copy import deepcopy
3
  import json
@@ -121,7 +133,7 @@ class ApiClient(object):
121
  params = deepcopy(params) or {}
122
  data = data or {}
123
  files = files or {}
124
- #if self.username is not None and self.api_key is not None:
125
  # params.update(self.get_credentials())
126
  r = requests.request(
127
  method,
@@ -223,3 +235,227 @@ class ApiClient(object):
223
  params={'format': 'json'},
224
  **kwargs
225
  )
1
+ import os
2
+ import time
3
+
4
+ import yaml
5
+
6
+ '''
7
+ This client is a generic client for any Grobid application and sub-modules.
8
+ At the moment, it supports only single document processing.
9
+
10
+ Source: https://github.com/kermitt2/grobid-client-python
11
+ '''
12
+
13
  """ Generic API Client """
14
  from copy import deepcopy
15
  import json
 
133
  params = deepcopy(params) or {}
134
  data = data or {}
135
  files = files or {}
136
+ # if self.username is not None and self.api_key is not None:
137
  # params.update(self.get_credentials())
138
  r = requests.request(
139
  method,
 
235
  params={'format': 'json'},
236
  **kwargs
237
  )
238
+
239
+
240
+ class NERClientGeneric(ApiClient):
241
+
242
+ def __init__(self, config_path=None, ping=False):
243
+ self.config = None
244
+ if config_path is not None:
245
+ self.config = self._load_yaml_config_from_file(path=config_path)
246
+ super().__init__(self.config['grobid']['server'])
247
+
248
+ if ping:
249
+ result = self.ping_service()
250
+ if not result:
251
+ raise Exception("Grobid is down.")
252
+
253
+ os.environ['NO_PROXY'] = "nims.go.jp"
254
+
255
+ @staticmethod
256
+ def _load_json_config_from_file(path='./config.json'):
257
+ """
258
+ Load the json configuration
259
+ """
260
+ config = {}
261
+ with open(path, 'r') as fp:
262
+ config = json.load(fp)
263
+
264
+ return config
265
+
266
+ @staticmethod
267
+ def _load_yaml_config_from_file(path='./config.yaml'):
268
+ """
269
+ Load the YAML configuration
270
+ """
271
+ config = {}
272
+ try:
273
+ with open(path, 'r') as the_file:
274
+ raw_configuration = the_file.read()
275
+
276
+ config = yaml.safe_load(raw_configuration)
277
+ except Exception as e:
278
+ print("Configuration could not be loaded: ", str(e))
279
+ exit(1)
280
+
281
+ return config
282
+
283
+ def set_config(self, config, ping=False):
284
+ self.config = config
285
+ if ping:
286
+ try:
287
+ result = self.ping_service()
288
+ if not result:
289
+ raise Exception("Grobid is down.")
290
+ except Exception as e:
291
+ raise Exception("Grobid is down or other problems were encountered. ", e)
292
+
293
+ def ping_service(self):
294
+ # test if the server is up and running...
295
+ ping_url = self.get_url("ping")
296
+
297
+ r = requests.get(ping_url)
298
+ status = r.status_code
299
+
300
+ if status != 200:
301
+ print('GROBID server does not appear up and running ' + str(status))
302
+ return False
303
+ else:
304
+ print("GROBID server is up and running")
305
+ return True
306
+
307
+ def get_url(self, action):
308
+ grobid_config = self.config['grobid']
309
+ base_url = grobid_config['server']
310
+ action_url = base_url + grobid_config['url_mapping'][action]
311
+
312
+ return action_url
313
+
314
+ def process_texts(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}):
315
+
316
+ files = {
317
+ 'texts': input
318
+ }
319
+
320
+ the_url = self.get_url(method_name)
321
+ params, the_url = self.get_params_from_url(the_url)
322
+
323
+ res, status = self.post(
324
+ url=the_url,
325
+ files=files,
326
+ data=params,
327
+ headers=headers
328
+ )
329
+
330
+ if status == 503:
331
+ time.sleep(self.config['sleep_time'])
332
+ return self.process_texts(input, method_name, params, headers)
333
+ elif status != 200:
334
+ print('Processing failed with error ' + str(status))
335
+ return status, None
336
+ else:
337
+ return status, json.loads(res.text)
338
+
339
+ def process_text(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}):
340
+
341
+ files = {
342
+ 'text': input
343
+ }
344
+
345
+ the_url = self.get_url(method_name)
346
+ params, the_url = self.get_params_from_url(the_url)
347
+
348
+ res, status = self.post(
349
+ url=the_url,
350
+ files=files,
351
+ data=params,
352
+ headers=headers
353
+ )
354
+
355
+ if status == 503:
356
+ time.sleep(self.config['sleep_time'])
357
+ return self.process_text(input, method_name, params, headers)
358
+ elif status != 200:
359
+ print('Processing failed with error ' + str(status))
360
+ return status, None
361
+ else:
362
+ return status, json.loads(res.text)
363
+
364
+ def process_pdf(self,
365
+ form_data: dict,
366
+ method_name='superconductors',
367
+ params={},
368
+ headers={"Accept": "application/json"}
369
+ ):
370
+
371
+ the_url = self.get_url(method_name)
372
+ params, the_url = self.get_params_from_url(the_url)
373
+
374
+ res, status = self.post(
375
+ url=the_url,
376
+ files=form_data,
377
+ data=params,
378
+ headers=headers
379
+ )
380
+
381
+ if status == 503:
382
+ time.sleep(self.config['sleep_time'])
383
+ return self.process_text(input, method_name, params, headers)
384
+ elif status != 200:
385
+ print('Processing failed with error ' + str(status))
386
+ else:
387
+ return res.text
388
+
389
+ def process_pdfs(self, pdf_files, params={}):
390
+ pass
391
+
392
+ def process_pdf(
393
+ self,
394
+ pdf_file,
395
+ method_name,
396
+ params={},
397
+ headers={"Accept": "application/json"},
398
+ verbose=False,
399
+ retry=None
400
+ ):
401
+
402
+ files = {
403
+ 'input': (
404
+ pdf_file,
405
+ open(pdf_file, 'rb'),
406
+ 'application/pdf',
407
+ {'Expires': '0'}
408
+ )
409
+ }
410
+
411
+ the_url = self.get_url(method_name)
412
+
413
+ params, the_url = self.get_params_from_url(the_url)
414
+
415
+ res, status = self.post(
416
+ url=the_url,
417
+ files=files,
418
+ data=params,
419
+ headers=headers
420
+ )
421
+
422
+ if status == 503 or status == 429:
423
+ if retry is None:
424
+ retry = self.config['max_retry'] - 1
425
+ else:
426
+ if retry - 1 == 0:
427
+ if verbose:
428
+ print("re-try exhausted. Aborting request")
429
+ return None, status
430
+ else:
431
+ retry -= 1
432
+
433
+ sleep_time = self.config['sleep_time']
434
+ if verbose:
435
+ print("Server is saturated, waiting", sleep_time, "seconds and trying again. ")
436
+ time.sleep(sleep_time)
437
+ return self.process_pdf(pdf_file, method_name, params, headers, verbose=verbose, retry=retry)
438
+ elif status != 200:
439
+ desc = None
440
+ if res.content:
441
+ c = json.loads(res.text)
442
+ desc = c['description'] if 'description' in c else None
443
+ return desc, status
444
+ elif status == 204:
445
+ # print('No content returned. Moving on. ')
446
+ return None, status
447
+ else:
448
+ return res.text, status
449
+
450
+ def get_params_from_url(self, the_url):
451
+ """
452
+ Extract predefined parameters that are embedded directly in the URL query string
453
+ """
454
+ params = {}
455
+ if "?" in the_url:
456
+ split = the_url.split("?")
457
+ the_url = split[0]
458
+ params = split[1]
459
+
460
+ params = {param.split("=")[0]: param.split("=")[1] for param in params.split("&")}
461
+ return params, the_url
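
NERClientGeneric keeps the YAML-driven configuration of the old GrobidClientGeneric (see the deleted file below); streamlit_app.py builds the config dict in code and injects it via set_config. A hedged sketch of standalone use through a config file; the service URL and endpoint mapping below are invented placeholders, not values from this repository:

# Hedged sketch; server URL and url_mapping are invented placeholders.
import yaml

from document_qa.ner_client_generic import NERClientGeneric

config = {
    'grobid': {
        'server': 'http://localhost:8072',
        'url_mapping': {'superconductors': '/service/process/text'},
    },
    'sleep_time': 5,
    'timeout': 60,
    'max_retry': 5,
}
with open('config.yaml', 'w') as f:
    yaml.dump(config, f)

client = NERClientGeneric(config_path='config.yaml', ping=False)
status, result = client.process_text("MgB2 becomes superconducting at 39 K.",
                                     method_name='superconductors')
if status == 200:
    print(result)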
grobid_client_generic.py DELETED
@@ -1,264 +0,0 @@
1
- import json
2
- import os
3
- import time
4
-
5
- import requests
6
- import yaml
7
-
8
- from client import ApiClient
9
-
10
- '''
11
- This client is a generic client for any Grobid application and sub-modules.
12
- At the moment, it supports only single document processing.
13
-
14
- Source: https://github.com/kermitt2/grobid-client-python
15
- '''
16
-
17
-
18
- class GrobidClientGeneric(ApiClient):
19
-
20
- def __init__(self, config_path=None, ping=False):
21
- self.config = None
22
- if config_path is not None:
23
- self.config = self.load_yaml_config_from_file(path=config_path)
24
- super().__init__(self.config['grobid']['server'])
25
-
26
- if ping:
27
- result = self.ping_grobid()
28
- if not result:
29
- raise Exception("Grobid is down.")
30
-
31
- os.environ['NO_PROXY'] = "nims.go.jp"
32
-
33
- @staticmethod
34
- def load_json_config_from_file(self, path='./config.json', ping=False):
35
- """
36
- Load the json configuration
37
- """
38
- config = {}
39
- with open(path, 'r') as fp:
40
- config = json.load(fp)
41
-
42
- if ping:
43
- result = self.ping_grobid()
44
- if not result:
45
- raise Exception("Grobid is down.")
46
-
47
- return config
48
-
49
- def load_yaml_config_from_file(self, path='./config.yaml'):
50
- """
51
- Load the YAML configuration
52
- """
53
- config = {}
54
- try:
55
- with open(path, 'r') as the_file:
56
- raw_configuration = the_file.read()
57
-
58
- config = yaml.safe_load(raw_configuration)
59
- except Exception as e:
60
- print("Configuration could not be loaded: ", str(e))
61
- exit(1)
62
-
63
- return config
64
-
65
- def set_config(self, config, ping=False):
66
- self.config = config
67
- if ping:
68
- try:
69
- result = self.ping_grobid()
70
- if not result:
71
- raise Exception("Grobid is down.")
72
- except Exception as e:
73
- raise Exception("Grobid is down or other problems were encountered. ", e)
74
-
75
- def ping_grobid(self):
76
- # test if the server is up and running...
77
- ping_url = self.get_grobid_url("ping")
78
-
79
- r = requests.get(ping_url)
80
- status = r.status_code
81
-
82
- if status != 200:
83
- print('GROBID server does not appear up and running ' + str(status))
84
- return False
85
- else:
86
- print("GROBID server is up and running")
87
- return True
88
-
89
- def get_grobid_url(self, action):
90
- grobid_config = self.config['grobid']
91
- base_url = grobid_config['server']
92
- action_url = base_url + grobid_config['url_mapping'][action]
93
-
94
- return action_url
95
-
96
- def process_texts(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}):
97
-
98
- files = {
99
- 'texts': input
100
- }
101
-
102
- the_url = self.get_grobid_url(method_name)
103
- params, the_url = self.get_params_from_url(the_url)
104
-
105
- res, status = self.post(
106
- url=the_url,
107
- files=files,
108
- data=params,
109
- headers=headers
110
- )
111
-
112
- if status == 503:
113
- time.sleep(self.config['sleep_time'])
114
- return self.process_texts(input, method_name, params, headers)
115
- elif status != 200:
116
- print('Processing failed with error ' + str(status))
117
- return status, None
118
- else:
119
- return status, json.loads(res.text)
120
-
121
- def process_text(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}):
122
-
123
- files = {
124
- 'text': input
125
- }
126
-
127
- the_url = self.get_grobid_url(method_name)
128
- params, the_url = self.get_params_from_url(the_url)
129
-
130
- res, status = self.post(
131
- url=the_url,
132
- files=files,
133
- data=params,
134
- headers=headers
135
- )
136
-
137
- if status == 503:
138
- time.sleep(self.config['sleep_time'])
139
- return self.process_text(input, method_name, params, headers)
140
- elif status != 200:
141
- print('Processing failed with error ' + str(status))
142
- return status, None
143
- else:
144
- return status, json.loads(res.text)
145
-
146
- def process(self, form_data: dict, method_name='superconductors', params={}, headers={"Accept": "application/json"}):
147
-
148
- the_url = self.get_grobid_url(method_name)
149
- params, the_url = self.get_params_from_url(the_url)
150
-
151
- res, status = self.post(
152
- url=the_url,
153
- files=form_data,
154
- data=params,
155
- headers=headers
156
- )
157
-
158
- if status == 503:
159
- time.sleep(self.config['sleep_time'])
160
- return self.process_text(input, method_name, params, headers)
161
- elif status != 200:
162
- print('Processing failed with error ' + str(status))
163
- else:
164
- return res.text
165
-
166
- def process_pdf_batch(self, pdf_files, params={}):
167
- pass
168
-
169
- def process_pdf(self, pdf_file, method_name, params={}, headers={"Accept": "application/json"}, verbose=False,
170
- retry=None):
171
-
172
- files = {
173
- 'input': (
174
- pdf_file,
175
- open(pdf_file, 'rb'),
176
- 'application/pdf',
177
- {'Expires': '0'}
178
- )
179
- }
180
-
181
- the_url = self.get_grobid_url(method_name)
182
-
183
- params, the_url = self.get_params_from_url(the_url)
184
-
185
- res, status = self.post(
186
- url=the_url,
187
- files=files,
188
- data=params,
189
- headers=headers
190
- )
191
-
192
- if status == 503 or status == 429:
193
- if retry is None:
194
- retry = self.config['max_retry'] - 1
195
- else:
196
- if retry - 1 == 0:
197
- if verbose:
198
- print("re-try exhausted. Aborting request")
199
- return None, status
200
- else:
201
- retry -= 1
202
-
203
- sleep_time = self.config['sleep_time']
204
- if verbose:
205
- print("Server is saturated, waiting", sleep_time, "seconds and trying again. ")
206
- time.sleep(sleep_time)
207
- return self.process_pdf(pdf_file, method_name, params, headers, verbose=verbose, retry=retry)
208
- elif status != 200:
209
- desc = None
210
- if res.content:
211
- c = json.loads(res.text)
212
- desc = c['description'] if 'description' in c else None
213
- return desc, status
214
- elif status == 204:
215
- # print('No content returned. Moving on. ')
216
- return None, status
217
- else:
218
- return res.text, status
219
-
220
- def get_params_from_url(self, the_url):
221
- params = {}
222
- if "?" in the_url:
223
- split = the_url.split("?")
224
- the_url = split[0]
225
- params = split[1]
226
-
227
- params = {param.split("=")[0]: param.split("=")[1] for param in params.split("&")}
228
- return params, the_url
229
-
230
- def process_json(self, text, method_name="processJson", params={}, headers={"Accept": "application/json"},
231
- verbose=False):
232
- files = {
233
- 'input': (
234
- None,
235
- text,
236
- 'application/json',
237
- {'Expires': '0'}
238
- )
239
- }
240
-
241
- the_url = self.get_grobid_url(method_name)
242
-
243
- params, the_url = self.get_params_from_url(the_url)
244
-
245
- res, status = self.post(
246
- url=the_url,
247
- files=files,
248
- data=params,
249
- headers=headers
250
- )
251
-
252
- if status == 503:
253
- time.sleep(self.config['sleep_time'])
254
- return self.process_json(text, method_name, params, headers), status
255
- elif status != 200:
256
- if verbose:
257
- print('Processing failed with error ', status)
258
- return None, status
259
- elif status == 204:
260
- if verbose:
261
- print('No content returned. Moving on. ')
262
- return None, status
263
- else:
264
- return res.text, status
requirements.txt CHANGED
@@ -4,10 +4,10 @@ grobid-client-python==0.0.7
4
  grobid_tei_xml==0.1.3
5
 
6
  # Utils
7
- tqdm==4.66.1
8
  pyyaml==6.0.1
9
- pytest==7.4.3
10
- streamlit==1.32.2
11
  lxml
12
  Beautifulsoup4
13
  python-dotenv
@@ -15,13 +15,15 @@ watchdog
15
  dateparser
16
 
17
  # LLM
18
- chromadb==0.4.19
19
- tiktoken==0.4.0
20
- openai==0.27.7
21
- langchain==0.0.350
22
- langchain-core==0.1.0
23
  typing-inspect==0.9.0
24
- typing_extensions==4.8.0
25
- pydantic==2.4.2
26
- sentence_transformers==2.2.2
27
- streamlit-pdf-viewer==0.0.13
 
 
 
4
  grobid_tei_xml==0.1.3
5
 
6
  # Utils
7
+ tqdm==4.66.2
8
  pyyaml==6.0.1
9
+ pytest==8.1.1
10
+ streamlit==1.33.0
11
  lxml
12
  Beautifulsoup4
13
  python-dotenv
 
15
  dateparser
16
 
17
  # LLM
18
+ chromadb==0.4.24
19
+ tiktoken==0.6.0
20
+ openai==1.16.2
21
+ langchain==0.1.14
22
+ langchain-core==0.1.40
23
  typing-inspect==0.9.0
24
+ typing_extensions==4.11.0
25
+ pydantic==2.6.4
26
+ sentence_transformers==2.6.1
27
+ streamlit-pdf-viewer
28
+ umap-learn
29
+ plotly
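
The pins move the stack onto the split langchain 0.1.x packages (matching the langchain_core and langchain_community imports above). A small convenience snippet, not part of the commit, to confirm an installed environment matches the new pins:

# Hedged helper for reproducing the environment; the package list is illustrative.
from importlib.metadata import version

for pkg in ("langchain", "langchain-core", "chromadb", "openai", "tiktoken"):
    print(pkg, version(pkg))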
streamlit_app.py CHANGED
@@ -5,27 +5,41 @@ from tempfile import NamedTemporaryFile
5
 
6
  import dotenv
7
  from grobid_quantities.quantities import QuantitiesAPI
8
- from langchain.llms.huggingface_hub import HuggingFaceHub
9
  from langchain.memory import ConversationBufferWindowMemory
 
 
 
 
10
  from streamlit_pdf_viewer import pdf_viewer
11
 
 
 
12
  dotenv.load_dotenv(override=True)
13
 
14
  import streamlit as st
15
- from langchain.chat_models import ChatOpenAI
16
- from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
17
-
18
- from document_qa.document_qa_engine import DocumentQAEngine
19
  from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
20
- from grobid_client_generic import GrobidClientGeneric
21
 
22
  OPENAI_MODELS = ['gpt-3.5-turbo',
23
  "gpt-4",
24
  "gpt-4-1106-preview"]
25
 
 
 
 
 
 
 
26
  OPEN_MODELS = {
27
- 'mistral-7b-instruct-v0.1': 'mistralai/Mistral-7B-Instruct-v0.1',
28
- "zephyr-7b-beta": 'HuggingFaceH4/zephyr-7b-beta'
 
  }
30
 
31
  DISABLE_MEMORY = ['zephyr-7b-beta']
@@ -82,6 +96,9 @@ if 'pdf' not in st.session_state:
82
  if 'pdf_rendering' not in st.session_state:
83
  st.session_state['pdf_rendering'] = None
84
 
 
 
 
85
  st.set_page_config(
86
  page_title="Scientific Document Insights Q/A",
87
  page_icon="📝",
@@ -138,44 +155,57 @@ def clear_memory():
138
 
139
 
140
  # @st.cache_resource
141
- def init_qa(model, api_key=None):
142
  ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
143
  if model in OPENAI_MODELS:
 
 
 
144
  st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
145
  if api_key:
146
  chat = ChatOpenAI(model_name=model,
147
  temperature=0,
148
  openai_api_key=api_key,
149
  frequency_penalty=0.1)
150
- embeddings = OpenAIEmbeddings(openai_api_key=api_key)
 
 
 
 
151
 
152
  else:
153
  chat = ChatOpenAI(model_name=model,
154
  temperature=0,
155
  frequency_penalty=0.1)
156
- embeddings = OpenAIEmbeddings()
157
 
158
  elif model in OPEN_MODELS:
159
- chat = HuggingFaceHub(
 
 
 
160
  repo_id=OPEN_MODELS[model],
161
- model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048}
 
 
162
  )
163
  embeddings = HuggingFaceEmbeddings(
164
- model_name="all-MiniLM-L6-v2")
165
  st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
166
  else:
167
  st.error("The model was not loaded properly. Try reloading. ")
168
  st.stop()
169
  return
170
 
171
- return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
 
172
 
173
 
174
  @st.cache_resource
175
  def init_ner():
176
  quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True)
177
 
178
- materials_client = GrobidClientGeneric(ping=True)
179
  config_materials = {
180
  'grobid': {
181
  "server": os.environ['GROBID_MATERIALS_URL'],
@@ -190,10 +220,8 @@ def init_ner():
190
 
191
  materials_client.set_config(config_materials)
192
 
193
- gqa = GrobidAggregationProcessor(None,
194
- grobid_quantities_client=quantities_client,
195
- grobid_superconductors_client=materials_client
196
- )
197
  return gqa
198
 
199
 
@@ -229,15 +257,25 @@ with st.sidebar:
229
  "Model:",
230
  options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
231
  index=(OPENAI_MODELS + list(OPEN_MODELS.keys())).index(
232
- "zephyr-7b-beta") if "DEFAULT_MODEL" not in os.environ or not os.environ["DEFAULT_MODEL"] else (
233
  OPENAI_MODELS + list(OPEN_MODELS.keys())).index(os.environ["DEFAULT_MODEL"]),
234
  placeholder="Select model",
235
  help="Select the LLM model:",
236
  disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
237
  )
 
238
 
239
  st.markdown(
240
- ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa/tree/review-interface#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
241
 
242
  if (model in OPEN_MODELS) and model not in st.session_state['api_keys']:
243
  if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
@@ -254,7 +292,7 @@ with st.sidebar:
254
  st.session_state['api_keys'][model] = api_key
255
  # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
256
  # os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
257
- st.session_state['rqa'][model] = init_qa(model)
258
 
259
  elif model in OPENAI_MODELS and model not in st.session_state['api_keys']:
260
  if 'OPENAI_API_KEY' not in os.environ:
@@ -268,9 +306,9 @@ with st.sidebar:
268
  with st.spinner("Preparing environment"):
269
  st.session_state['api_keys'][model] = api_key
270
  if 'OPENAI_API_KEY' not in os.environ:
271
- st.session_state['rqa'][model] = init_qa(model, api_key)
272
  else:
273
- st.session_state['rqa'][model] = init_qa(model)
274
  # else:
275
  # is_api_key_provided = st.session_state['api_key']
276
 
@@ -305,16 +343,24 @@ question = st.chat_input(
305
  disabled=not uploaded_file
306
  )
307
 
308
  with st.sidebar:
309
  st.header("Settings")
310
  mode = st.radio(
311
  "Query mode",
312
- ("LLM", "Embeddings"),
313
  disabled=not uploaded_file,
314
  index=0,
315
  horizontal=True,
316
  help="LLM will respond the question, Embedding will show the "
317
- "paragraphs relevant to the question in the paper."
318
  )
319
 
320
  # Add a checkbox for showing annotations
@@ -340,9 +386,12 @@ with st.sidebar:
340
 
341
  st.session_state['pdf_rendering'] = st.radio(
342
  "PDF rendering mode",
343
- {"PDF.JS", "Native browser engine"},
344
  index=0,
345
  disabled=not uploaded_file,
346
  )
347
 
348
  st.divider()
@@ -358,10 +407,13 @@ with st.sidebar:
358
 
359
  st.header("Query mode (Advanced use)")
360
  st.markdown(
361
- """By default, the mode is set to LLM (Language Model) which enables question/answering. You can directly ask questions related to the document content, and the system will answer the question using content from the document.""")
362
 
363
  st.markdown(
364
- """If you switch the mode to "Embedding," the system will return specific chunks from the document that are semantically related to your query. This mode helps to test why sometimes the answers are not satisfying or incomplete. """)
365
 
366
  if uploaded_file and not st.session_state.loaded_embeddings:
367
  if model not in st.session_state['api_keys']:
@@ -426,10 +478,12 @@ with right_column:
426
  if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id:
427
  for message in st.session_state.messages:
428
  with st.chat_message(message["role"]):
429
- if message['mode'] == "LLM":
430
  st.markdown(message["content"], unsafe_allow_html=True)
431
- elif message['mode'] == "Embeddings":
432
  st.write(message["content"])
433
  if model not in st.session_state['rqa']:
434
  st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `")
435
  st.stop()
@@ -439,30 +493,43 @@ with right_column:
439
  st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
440
 
441
  text_response = None
442
- if mode == "Embeddings":
443
  with st.spinner("Generating LLM response..."):
444
- text_response = st.session_state['rqa'][model].query_storage(question, st.session_state.doc_id,
445
- context_size=context_size)
446
- elif mode == "LLM":
447
- with st.spinner("Generating response..."):
448
- _, text_response, coordinates = st.session_state['rqa'][model].query_document(question,
449
- st.session_state.doc_id,
450
- context_size=context_size)
451
-
452
- annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc]
453
- for coord_doc in coordinates]
454
- gradients = generate_color_gradient(len(annotations))
455
- for i, color in enumerate(gradients):
456
- for annotation in annotations[i]:
457
- annotation['color'] = color
458
- st.session_state['annotations'] = [annotation for annotation_doc in annotations for annotation in
459
- annotation_doc]
460
 
461
  if not text_response:
462
  st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
463
 
464
  with st.chat_message("assistant"):
465
- if mode == "LLM":
466
  if st.session_state['ner_processing']:
467
  with st.spinner("Processing NER on LLM response..."):
468
  entities = gqa.process_single_text(text_response)
@@ -486,6 +553,6 @@ with left_column:
486
  height=800,
487
  annotation_outline_size=1,
488
  annotations=st.session_state['annotations'],
489
- rendering='unwrap' if st.session_state['pdf_rendering'] == 'PDF.JS' else 'legacy_embed',
490
  render_text=True
491
  )
 
5
 
6
  import dotenv
7
  from grobid_quantities.quantities import QuantitiesAPI
8
  from langchain.memory import ConversationBufferWindowMemory
9
+ from langchain_community.chat_models.openai import ChatOpenAI
10
+ from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
11
+ from langchain_community.embeddings.openai import OpenAIEmbeddings
12
+ from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
13
  from streamlit_pdf_viewer import pdf_viewer
14
 
15
+ from document_qa.ner_client_generic import NERClientGeneric
16
+
17
  dotenv.load_dotenv(override=True)
18
 
19
  import streamlit as st
20
+ from document_qa.document_qa_engine import DocumentQAEngine, DataStorage
21
  from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
22
 
23
  OPENAI_MODELS = ['gpt-3.5-turbo',
24
  "gpt-4",
25
  "gpt-4-1106-preview"]
26
 
27
+ OPENAI_EMBEDDINGS = [
28
+ 'text-embedding-ada-002',
29
+ 'text-embedding-3-large',
30
+ 'text-embedding-3-small'
31
+ ]
32
+
33
  OPEN_MODELS = {
34
+ 'mistral-7b-instruct-v0.3': 'mistralai/Mistral-7B-Instruct-v0.3',
35
+ # 'Phi-3-mini-128k-instruct': "microsoft/Phi-3-mini-128k-instruct",
36
+ 'Phi-3-mini-4k-instruct': "microsoft/Phi-3-mini-4k-instruct"
37
+ }
38
+
39
+ DEFAULT_OPEN_EMBEDDING_NAME = 'Default (all-MiniLM-L6-v2)'
40
+ OPEN_EMBEDDINGS = {
41
+ DEFAULT_OPEN_EMBEDDING_NAME: 'all-MiniLM-L6-v2',
42
+ 'Salesforce/SFR-Embedding-Mistral': 'Salesforce/SFR-Embedding-Mistral'
43
  }
44
 
45
  DISABLE_MEMORY = ['zephyr-7b-beta']
 
96
  if 'pdf_rendering' not in st.session_state:
97
  st.session_state['pdf_rendering'] = None
98
 
99
+ if 'embeddings' not in st.session_state:
100
+ st.session_state['embeddings'] = None
101
+
102
  st.set_page_config(
103
  page_title="Scientific Document Insights Q/A",
104
  page_icon="📝",
 
155
 
156
 
157
  # @st.cache_resource
158
+ def init_qa(model, embeddings_name=None, api_key=None):
159
  ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
160
  if model in OPENAI_MODELS:
161
+ if embeddings_name is None:
162
+ embeddings_name = 'text-embedding-ada-002'
163
+
164
  st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
165
  if api_key:
166
  chat = ChatOpenAI(model_name=model,
167
  temperature=0,
168
  openai_api_key=api_key,
169
  frequency_penalty=0.1)
170
+ if embeddings_name not in OPENAI_EMBEDDINGS:
171
+ st.error(f"The embeddings provided {embeddings_name} are not supported by this model {model}.")
172
+ st.stop()
173
+ return
174
+ embeddings = OpenAIEmbeddings(model=embeddings_name, openai_api_key=api_key)
175
 
176
  else:
177
  chat = ChatOpenAI(model_name=model,
178
  temperature=0,
179
  frequency_penalty=0.1)
180
+ embeddings = OpenAIEmbeddings(model=embeddings_name)
181
 
182
  elif model in OPEN_MODELS:
183
+ if embeddings_name is None:
184
+ embeddings_name = DEFAULT_OPEN_EMBEDDING_NAME
185
+
186
+ chat = HuggingFaceEndpoint(
187
  repo_id=OPEN_MODELS[model],
188
+ temperature=0.01,
189
+ max_new_tokens=2048,
190
+ model_kwargs={"max_length": 4096}
191
  )
192
  embeddings = HuggingFaceEmbeddings(
193
+ model_name=OPEN_EMBEDDINGS[embeddings_name])
194
  st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None
195
  else:
196
  st.error("The model was not loaded properly. Try reloading. ")
197
  st.stop()
198
  return
199
 
200
+ storage = DataStorage(embeddings)
201
+ return DocumentQAEngine(chat, storage, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
202
 
203
 
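The embeddings-name fallback above is easy to exercise outside Streamlit. A minimal sketch, assuming the OPENAI_MODELS, OPENAI_EMBEDDINGS and DEFAULT_OPEN_EMBEDDING_NAME constants defined at the top of streamlit_app.py; resolve_embeddings_name is a hypothetical helper, not part of this commit:

# Hypothetical helper mirroring init_qa's embeddings-name fallback.
def resolve_embeddings_name(model, embeddings_name=None):
    if model in OPENAI_MODELS:
        name = embeddings_name or 'text-embedding-ada-002'
        if name not in OPENAI_EMBEDDINGS:
            raise ValueError(f"Embeddings {name} are not supported for model {model}")
        return name
    return embeddings_name or DEFAULT_OPEN_EMBEDDING_NAME

# e.g. resolve_embeddings_name('gpt-4') -> 'text-embedding-ada-002'
# e.g. resolve_embeddings_name('Phi-3-mini-4k-instruct') -> 'Default (all-MiniLM-L6-v2)'
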
204
  @st.cache_resource
205
  def init_ner():
206
  quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True)
207
 
208
+ materials_client = NERClientGeneric(ping=True)
209
  config_materials = {
210
  'grobid': {
211
  "server": os.environ['GROBID_MATERIALS_URL'],
 
220
 
221
  materials_client.set_config(config_materials)
222
 
223
+ gqa = GrobidAggregationProcessor(grobid_quantities_client=quantities_client,
224
+ grobid_superconductors_client=materials_client)
225
  return gqa
226
 
227
 
 
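Note that init_qa and init_ner read GROBID_URL, GROBID_QUANTITIES_URL and GROBID_MATERIALS_URL straight from os.environ, which surfaces as a bare KeyError when a variable is unset. A small sketch of a friendlier guard; require_env is a hypothetical helper, not part of this commit:

import os
import streamlit as st

def require_env(name: str) -> str:
    # Fail with a readable message in the UI instead of a raw KeyError.
    value = os.environ.get(name)
    if not value:
        st.error(f"Missing required environment variable: {name}")
        st.stop()
    return value

# usage: QuantitiesAPI(require_env('GROBID_QUANTITIES_URL'), check_server=True)
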
257
  "Model:",
258
  options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
259
  index=(OPENAI_MODELS + list(OPEN_MODELS.keys())).index(
260
+ "mistral-7b-instruct-v0.2") if "DEFAULT_MODEL" not in os.environ or not os.environ["DEFAULT_MODEL"] else (
261
  OPENAI_MODELS + list(OPEN_MODELS.keys())).index(os.environ["DEFAULT_MODEL"]),
262
  placeholder="Select model",
263
  help="Select the LLM model:",
264
  disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
265
  )
266
+ embedding_choices = OPENAI_EMBEDDINGS if model in OPENAI_MODELS else OPEN_EMBEDDINGS
267
+
268
+ st.session_state['embeddings'] = embedding_name = st.selectbox(
269
+ "Embeddings:",
270
+ options=embedding_choices,
271
+ index=0,
272
+ placeholder="Select embedding",
273
+ help="Select the Embedding function:",
274
+ disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
275
+ )
276
 
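One caveat with the default-model lookup above: options.index(...) raises ValueError when DEFAULT_MODEL names a model that is not in the list, which is why the default label must match an OPEN_MODELS key exactly. A safer variant as a sketch, assuming the constants from the top of this file; not part of this commit:

options = OPENAI_MODELS + list(OPEN_MODELS.keys())
default = os.environ.get("DEFAULT_MODEL") or "mistral-7b-instruct-v0.3"
# Fall back to the first option instead of crashing on an unknown name.
default_index = options.index(default) if default in options else 0
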
277
  st.markdown(
278
+ ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ")
279
 
280
  if (model in OPEN_MODELS) and model not in st.session_state['api_keys']:
281
  if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
 
292
  st.session_state['api_keys'][model] = api_key
293
  # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
294
  # os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
295
+ st.session_state['rqa'][model] = init_qa(model, embedding_name)
296
 
297
  elif model in OPENAI_MODELS and model not in st.session_state['api_keys']:
298
  if 'OPENAI_API_KEY' not in os.environ:
 
306
  with st.spinner("Preparing environment"):
307
  st.session_state['api_keys'][model] = api_key
308
  if 'OPENAI_API_KEY' not in os.environ:
309
+ st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'], api_key)
310
  else:
311
+ st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'])
312
  # else:
313
  # is_api_key_provided = st.session_state['api_key']
314
 
 
343
  disabled=not uploaded_file
344
  )
345
 
346
+ query_modes = {
347
+ "llm": "LLM Q/A",
348
+ "embeddings": "Embeddings",
349
+ "question_coefficient": "Question coefficient"
350
+ }
351
+
352
  with st.sidebar:
353
  st.header("Settings")
354
  mode = st.radio(
355
  "Query mode",
356
+ ("llm", "embeddings", "question_coefficient"),
357
  disabled=not uploaded_file,
358
  index=0,
359
  horizontal=True,
360
+ format_func=lambda x: query_modes[x],
361
  help="LLM will respond the question, Embedding will show the "
362
+ "relevant paragraphs to the question in the paper. "
363
+ "Question coefficient attempt to estimate how effective the question will be answered."
364
  )
365
 
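The radio above stores stable internal keys ("llm", "embeddings", "question_coefficient") while format_func renders the human-readable labels, so downstream checks such as mode == "llm" never depend on UI wording. The pattern in isolation, as a self-contained sketch:

import streamlit as st

labels = {"llm": "LLM Q/A", "embeddings": "Embeddings"}
# The widget returns a dict key; format_func only affects what is displayed.
choice = st.radio("Mode", tuple(labels), format_func=lambda key: labels[key])
st.write(f"internal value: {choice}")  # always 'llm' or 'embeddings', never a label
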
366
  # Add a checkbox for showing annotations
 
386
 
387
  st.session_state['pdf_rendering'] = st.radio(
388
  "PDF rendering mode",
389
+ ("unwrap", "legacy_embed"),
390
  index=0,
391
  disabled=not uploaded_file,
392
+ help="PDF rendering engine."
393
+ "Note: The Legacy PDF viewer does not support annotations and might not work on Chrome.",
394
+ format_func=lambda q: "Legacy PDF Viewer" if q == "legacy_embed" else "Streamlit PDF Viewer (Pdf.js)"
395
  )
396
 
397
  st.divider()
 
407
 
408
  st.header("Query mode (Advanced use)")
409
  st.markdown(
410
+ """By default, the mode is set to LLM (Language Model) which enables question/answering.
411
+ You can directly ask questions related to the document content, and the system will answer the question using content from the document.""")
412
 
413
  st.markdown(
414
+ """If you switch the mode to "Embedding," the system will return specific chunks from the document
415
+ that are semantically related to your query. This mode helps diagnose why some answers are
416
+ unsatisfying or incomplete. """)
417
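The three modes map onto three engine calls. A condensed sketch of the dispatch, using the method signatures and return shapes exactly as they appear in the chat handler further below in this commit:

engine = st.session_state['rqa'][model]
doc_id = st.session_state.doc_id
if mode == "llm":
    # query_document returns (chunks, answer, coordinates)
    _, answer, coordinates = engine.query_document(question, doc_id, context_size=context_size)
elif mode == "embeddings":
    # query_storage returns the retrieved passages and their PDF coordinates
    answer, coordinates = engine.query_storage(question, doc_id, context_size=context_size)
elif mode == "question_coefficient":
    # analyse_query estimates question/context relevancy
    answer, coordinates = engine.analyse_query(question, doc_id, context_size=context_size)
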
 
418
  if uploaded_file and not st.session_state.loaded_embeddings:
419
  if model not in st.session_state['api_keys']:
 
478
  if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id:
479
  for message in st.session_state.messages:
480
  with st.chat_message(message["role"]):
481
+ if message['mode'] == "llm":
482
  st.markdown(message["content"], unsafe_allow_html=True)
483
+ elif message['mode'] == "embeddings":
484
  st.write(message["content"])
485
+ elif message['mode'] == "question_coefficient":
486
+ st.markdown(message["content"], unsafe_allow_html=True)
487
  if model not in st.session_state['rqa']:
488
  st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `")
489
  st.stop()
 
493
  st.session_state.messages.append({"role": "user", "mode": mode, "content": question})
494
 
495
  text_response = None
496
+ if mode == "embeddings":
497
+ with st.spinner("Fetching the relevant context..."):
498
+ text_response, coordinates = st.session_state['rqa'][model].query_storage(
499
+ question,
500
+ st.session_state.doc_id,
501
+ context_size=context_size
502
+ )
503
+ elif mode == "llm":
504
  with st.spinner("Generating LLM response..."):
505
+ _, text_response, coordinates = st.session_state['rqa'][model].query_document(
506
+ question,
507
+ st.session_state.doc_id,
508
+ context_size=context_size
509
+ )
510
+
511
+ elif mode == "question_coefficient":
512
+ with st.spinner("Estimate question/context relevancy..."):
513
+ text_response, coordinates = st.session_state['rqa'][model].analyse_query(
514
+ question,
515
+ st.session_state.doc_id,
516
+ context_size=context_size
517
+ )
518
+
519
+ annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc]
520
+ for coord_doc in coordinates]
521
+ gradients = generate_color_gradient(len(annotations))
522
+ for i, color in enumerate(gradients):
523
+ for annotation in annotations[i]:
524
+ annotation['color'] = color
525
+ st.session_state['annotations'] = [annotation for annotation_doc in annotations for annotation in
526
+ annotation_doc]
527
 
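generate_color_gradient is defined elsewhere in streamlit_app.py and is not shown in this diff. A minimal sketch of what such a helper might look like, an assumption rather than the actual implementation: n evenly spaced hues as hex strings, so each retrieved passage gets a distinct highlight color.

import colorsys

def generate_color_gradient(n: int):
    # Sketch only: n evenly spaced hues rendered as '#rrggbb' strings.
    return [
        '#%02x%02x%02x' % tuple(int(c * 255) for c in colorsys.hsv_to_rgb(i / max(n, 1), 0.6, 0.9))
        for i in range(n)
    ]
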
528
  if not text_response:
529
  st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
530
 
531
  with st.chat_message("assistant"):
532
+ if mode == "llm":
533
  if st.session_state['ner_processing']:
534
  with st.spinner("Processing NER on LLM response..."):
535
  entities = gqa.process_single_text(text_response)
 
553
  height=800,
554
  annotation_outline_size=1,
555
  annotations=st.session_state['annotations'],
556
+ rendering=st.session_state['pdf_rendering'],
557
  render_text=True
558
  )
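
For reference, the two rendering values wired above come from streamlit-pdf-viewer: 'unwrap' renders through pdf.js and supports the annotation overlays, while 'legacy_embed' delegates to the browser's built-in viewer and, per the help text above, drops annotations. A standalone sketch; paper.pdf is a hypothetical input:

from streamlit_pdf_viewer import pdf_viewer

with open("paper.pdf", "rb") as f:  # hypothetical input file
    pdf_viewer(f.read(), height=800, rendering="unwrap", render_text=True)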