nickmuchi commited on
Commit
35f456f
·
1 Parent(s): d33afe3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -52,17 +52,22 @@ def extract_text_from_file(file):
52
 
53
  # To read file as string:
54
  file_text = stringio.read()
 
 
55
 
56
  # read pdf file
57
  elif file.type == "application/pdf":
58
  pdfReader = PdfFileReader(file)
59
  count = pdfReader.numPages
60
  all_text = ""
 
61
 
62
  for i in range(count):
63
  page = pdfReader.getPage(i)
64
  all_text += page.extractText()
65
  file_text = all_text
 
 
66
 
67
  # read docx file
68
  elif (
@@ -70,8 +75,8 @@ def extract_text_from_file(file):
70
  == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
71
  ):
72
  file_text = docx2txt.process(file)
73
-
74
- return file_text
75
 
76
  def preprocess_plain_text(text,window_size=3):
77
 
@@ -171,6 +176,10 @@ def search_func(query, top_k=2):
171
  if url_text:
172
 
173
  st.write(f"Document Header: {title}")
 
 
 
 
174
 
175
  ##### BM25 search (lexical search) #####
176
  bm25_scores = bm25.get_scores(bm25_tokenizer(query))
@@ -178,7 +187,7 @@ def search_func(query, top_k=2):
178
  bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
179
  bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
180
 
181
- st.write(f"Top-{top_k} lexical search (BM25) hits")
182
  for hit in bm25_hits[0:top_k]:
183
  st.write("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))
184
 
@@ -200,14 +209,14 @@ def search_func(query, top_k=2):
200
 
201
  # Output of top-3 hits from bi-encoder
202
  st.markdown("\n-------------------------\n")
203
- st.write(f"Top-{top_k} Bi-Encoder Retrieval hits")
204
  hits = sorted(hits, key=lambda x: x['score'], reverse=True)
205
  for hit in hits[0:top_k]:
206
  st.write("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))
207
 
208
  # Output of top-3 hits from re-ranker
209
  st.markdown("\n-------------------------\n")
210
- st.write(f"Top-{top_k} Cross-Encoder Re-ranker hits")
211
  hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
212
  for hit in hits[0:top_k]:
213
  st.write("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
@@ -284,7 +293,8 @@ if validators.url(url_text):
284
 
285
  elif upload_doc:
286
 
287
- passages = preprocess_plain_text(extract_text_from_file(upload_doc),window_size=window_size)
 
288
 
289
  search = st.button("Search")
290
 
 
52
 
53
  # To read file as string:
54
  file_text = stringio.read()
55
+
56
+ return file_text, None
57
 
58
  # read pdf file
59
  elif file.type == "application/pdf":
60
  pdfReader = PdfFileReader(file)
61
  count = pdfReader.numPages
62
  all_text = ""
63
+ pdf_title = pdfReader.getDocumentInfo().title
64
 
65
  for i in range(count):
66
  page = pdfReader.getPage(i)
67
  all_text += page.extractText()
68
  file_text = all_text
69
+
70
+ return file_text, pdf_title
71
 
72
  # read docx file
73
  elif (
 
75
  == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
76
  ):
77
  file_text = docx2txt.process(file)
78
+
79
+ return file_text, None
80
 
81
  def preprocess_plain_text(text,window_size=3):
82
 
 
176
  if url_text:
177
 
178
  st.write(f"Document Header: {title}")
179
+
180
+ elif pdf_title:
181
+
182
+ st.write(f"Document Header: {pdf_title}")
183
 
184
  ##### BM25 search (lexical search) #####
185
  bm25_scores = bm25.get_scores(bm25_tokenizer(query))
 
187
  bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
188
  bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
189
 
190
+ st.subheader(f"Top-{top_k} lexical search (BM25) hits")
191
  for hit in bm25_hits[0:top_k]:
192
  st.write("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))
193
 
 
209
 
210
  # Output of top-3 hits from bi-encoder
211
  st.markdown("\n-------------------------\n")
212
+ st.subheader(f"Top-{top_k} Bi-Encoder Retrieval hits")
213
  hits = sorted(hits, key=lambda x: x['score'], reverse=True)
214
  for hit in hits[0:top_k]:
215
  st.write("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))
216
 
217
  # Output of top-3 hits from re-ranker
218
  st.markdown("\n-------------------------\n")
219
+ st.subheader(f"Top-{top_k} Cross-Encoder Re-ranker hits")
220
  hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
221
  for hit in hits[0:top_k]:
222
  st.write("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
 
293
 
294
  elif upload_doc:
295
 
296
+ text, pdf_title = extract_text_from_file(upload_doc)
297
+ passages = preprocess_plain_text(text,window_size=window_size)
298
 
299
  search = st.button("Search")
300