lotrlol committed on
Commit
bf7dfb8
·
1 Parent(s): 9c5ddf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -15
app.py CHANGED
@@ -32,7 +32,7 @@ class DocumentSearch:
32
  # loading faiss index
33
  self.index = faiss.read_index(DocumentSearch.idx_path)
34
  # loading sbert cross_encoder
35
- self.cross_encoder = CrossEncoder(DocumentSearch.cross_enc_path)
36
 
37
  def search(self, query: str, k: int) -> list:
38
  # get vector representation of text query
@@ -43,24 +43,21 @@ class DocumentSearch:
43
  res_docs = [self.docs[i] for i in indeces[0]]
44
  # get scores by index
45
  dists = [dist for dist in distances[0]]
46
-
 
 
47
  # get answers by index
48
- answers = [self.docs[i] for i in indeces[0]]
49
  # prepare inputs for cross encoder
50
- model_inputs = [[query, pairs[0]] for pairs in answers]
51
- urls = [pairs[1] for pairs in answers]
52
  # get similarity score between query and documents
53
- scores = self.cross_encoder.predict(model_inputs, batch_size=1)
54
  # compose results into list of dicts
55
- results = [{'doc': doc[1], 'url': url, 'score': score} for doc, url, score in zip(model_inputs, urls, scores)]
56
 
57
  # return results sorted by similarity scores
58
- return sorted(results, key=lambda x: x['score'], reverse=True)[:k]
59
-
60
-
61
- if __name__ == "__main__":
62
- # get instance of DocumentSearch class
63
- surfer = DocumentSearch()
64
 
65
 
66
  if __name__ == "__main__":
@@ -89,7 +86,7 @@ if __name__ == "__main__":
89
  # set start time
90
  stt = time.time()
91
  # retrieve top 5 documents
92
- results = surfer.search(query, k=1)
93
  # set endtime
94
  ent = time.time()
95
  # measure resulting time
@@ -114,4 +111,4 @@ if __name__ == "__main__":
114
  else:
115
  st.markdown("Typical queries looks like this: _**\"What is flu?\"**_,\
116
  _**\"How to cure breast cancer?\"**_,\
117
- _**\"I have headache, what should I do?\"**_")
 
32
  # loading faiss index
33
  self.index = faiss.read_index(DocumentSearch.idx_path)
34
  # loading sbert cross_encoder
35
+ # self.cross_encoder = CrossEncoder(DocumentSearch.cross_enc_path)
36
 
37
  def search(self, query: str, k: int) -> list:
38
  # get vector representation of text query
 
43
  res_docs = [self.docs[i] for i in indeces[0]]
44
  # get scores by index
45
  dists = [dist for dist in distances[0]]
46
+
47
+ return[{'doc': doc[0], 'url': doc[1], 'score': dist} for doc, dist in zip(res_docs, dists)][:k]
48
+ ##### OLD VERSION WITH CROSS-ENCODER #####
49
  # get answers by index
50
+ #answers = [self.docs[i] for i in indeces[0]]
51
  # prepare inputs for cross encoder
52
+ # model_inputs = [[query, pairs[0]] for pairs in answers]
53
+ # urls = [pairs[1] for pairs in answers]
54
  # get similarity score between query and documents
55
+ # scores = self.cross_encoder.predict(model_inputs, batch_size=1)
56
  # compose results into list of dicts
57
+ # results = [{'doc': doc[1], 'url': url, 'score': score} for doc, url, score in zip(model_inputs, urls, scores)]
58
 
59
  # return results sorted by similarity scores
60
+ # return sorted(results, key=lambda x: x['score'], reverse=True)[:k]
 
 
 
 
 
61
 
62
 
63
  if __name__ == "__main__":
 
86
  # set start time
87
  stt = time.time()
88
  # retrieve top 10 documents
89
+ results = surfer.search(query, k=10)
90
  # set endtime
91
  ent = time.time()
92
  # measure resulting time
 
111
  else:
112
  st.markdown("Typical queries looks like this: _**\"What is flu?\"**_,\
113
  _**\"How to cure breast cancer?\"**_,\
114
+ _**\"I have headache, what should I do?\"**_")