Luca Foppiano committed
Commit cdfc7ae · unverified · 2 parents: cbdc1a4, b0a0e1a

Merge pull request #21 from lfoppiano/include-biblio-in-embeddings

document_qa/document_qa_engine.py CHANGED
@@ -3,6 +3,7 @@ import os
 from pathlib import Path
 from typing import Union, Any
 
+from document_qa.grobid_processors import GrobidProcessor
 from grobid_client.grobid_client import GrobidClient
 from langchain.chains import create_extraction_chain
 from langchain.chains.question_answering import load_qa_chain
@@ -12,8 +13,6 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 from tqdm import tqdm
 
-from document_qa.grobid_processors import GrobidProcessor
-
 
 class DocumentQAEngine:
     llm = None
@@ -173,8 +172,10 @@ class DocumentQAEngine:
         relevant_documents = multi_query_retriever.get_relevant_documents(query)
         return relevant_documents
 
-    def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
-        """Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
+    def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, include=(), verbose=False):
+        """
+        Extract text from documents using Grobid, if chunk_size is < 0 it keeps each paragraph separately
+        """
         if verbose:
             print("File", pdf_file_path)
         filename = Path(pdf_file_path).stem
@@ -189,6 +190,7 @@ class DocumentQAEngine:
         texts = []
         metadatas = []
         ids = []
+
         if chunk_size < 0:
             for passage in structure['passages']:
                 biblio_copy = copy.copy(biblio)
@@ -212,10 +214,25 @@ class DocumentQAEngine:
             metadatas = [biblio for _ in range(len(texts))]
             ids = [id for id, t in enumerate(texts)]
 
+        if "biblio" in include:
+            biblio_metadata = copy.copy(biblio)
+            biblio_metadata['type'] = "biblio"
+            biblio_metadata['section'] = "header"
+            for key in ['title', 'authors', 'publication_year']:
+                if key in biblio_metadata:
+                    texts.append("{}: {}".format(key, biblio_metadata[key]))
+                    metadatas.append(biblio_metadata)
+                    ids.append(key)
+
         return texts, metadatas, ids
 
-    def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1):
-        texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=chunk_size, perc_overlap=perc_overlap)
+    def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1, include_biblio=False):
+        include = ["biblio"] if include_biblio else []
+        texts, metadata, ids = self.get_text_from_document(
+            pdf_path,
+            chunk_size=chunk_size,
+            perc_overlap=perc_overlap,
+            include=include)
         if doc_id:
             hash = doc_id
         else:
@@ -233,7 +250,7 @@ class DocumentQAEngine:
 
         return hash
 
-    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
+    def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1, include_biblio=False):
         input_files = []
         for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
             for file_ in files:
@@ -250,9 +267,12 @@ class DocumentQAEngine:
             if os.path.exists(data_path):
                 print(data_path, "exists. Skipping it ")
                 continue
-
-            texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
-                                                               perc_overlap=perc_overlap)
+            include = ["biblio"] if include_biblio else []
+            texts, metadata, ids = self.get_text_from_document(
+                input_file,
+                chunk_size=chunk_size,
+                perc_overlap=perc_overlap,
+                include=include)
             filename = metadata[0]['filename']
 
             vector_db_document = Chroma.from_texts(texts,
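
In short, get_text_from_document now accepts an include sequence: passing "biblio" appends the header bibliography (title, authors, publication_year) as extra chunks tagged type="biblio" / section="header", so they are embedded alongside the body text. A minimal usage sketch of the two entry points, assuming an already-constructed DocumentQAEngine instance named engine (its constructor is not part of this diff) and an illustrative PDF path:

# Sketch only: `engine` is a pre-built DocumentQAEngine; "paper.pdf" is illustrative.
# High-level path (the one the Streamlit app uses): embed one PDF in memory,
# including the bibliographic chunks introduced by this PR.
doc_id = engine.create_memory_embeddings("paper.pdf",
                                         chunk_size=500,
                                         perc_overlap=0.1,
                                         include_biblio=True)

# Lower-level equivalent: "biblio" in `include` triggers the extra
# title/authors/publication_year chunks appended after the body passages.
texts, metadatas, ids = engine.get_text_from_document("paper.pdf",
                                                      chunk_size=500,
                                                      perc_overlap=0.1,
                                                      include=["biblio"])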
document_qa/grobid_processors.py CHANGED
@@ -171,7 +171,7 @@ class GrobidProcessor(BaseProcessor):
         }
         try:
             year = dateparser.parse(doc_biblio.header.date).year
-            biblio["year"] = year
+            biblio["publication_year"] = year
         except:
             pass
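The rename matters because the new biblio chunks above are built from the 'title', 'authors' and 'publication_year' keys; with the old 'year' key the publication date would silently be skipped. A tiny, runnable illustration of that lookup (the dict contents are made up):

# Illustrative only: biblio as produced by GrobidProcessor after this change.
biblio = {"title": "Some paper", "authors": "A. Author", "publication_year": 2021}

for key in ['title', 'authors', 'publication_year']:
    if key in biblio:
        print("{}: {}".format(key, biblio[key]))
# title: Some paper
# authors: A. Author
# publication_year: 2021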
 
streamlit_app.py CHANGED
@@ -283,7 +283,8 @@ if uploaded_file and not st.session_state.loaded_embeddings:
         # hash = get_file_hash(tmp_file.name)[:10]
         st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
             chunk_size=chunk_size,
-            perc_overlap=0.1)
+            perc_overlap=0.1,
+            include_biblio=True)
     st.session_state['loaded_embeddings'] = True
     st.session_state.messages = []
 
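
The app change only exercises the in-memory path; the same flag is also available on the batch indexer from the engine diff above. A hedged sketch of that path, again assuming a pre-built DocumentQAEngine instance named engine and an illustrative directory of PDFs:

from pathlib import Path

# Sketch only: create_embeddings walks the directory, skips already-indexed
# documents, and with include_biblio=True also embeds the title/authors/
# publication_year chunks together with the body text.
engine.create_embeddings(Path("./pdfs"),
                         chunk_size=500,
                         perc_overlap=0.1,
                         include_biblio=True)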