Jonas Leeb committed
Commit dc760b4 · Parent: 0fbc2c7

bug fixes and usability improvements

Files changed (1): app.py (+52 −14)
app.py CHANGED
@@ -57,8 +57,8 @@ class ArxivSearch:
             outputs=self.output_md
         )
         self.embedding_dropdown.change(
-            self.search_function,
-            inputs=[self.query_box, self.embedding_dropdown],
+            self.model_switch,
+            inputs=[self.embedding_dropdown],
             outputs=self.output_md
         )
         self.plot_button.click(
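This hunk rewires the dropdown's change event from a full re-search to a dedicated model switch. A minimal, self-contained sketch of the pattern (component names match the diff; the handler body here is a placeholder, not the commit's code):

import gradio as gr

with gr.Blocks() as iface:
    query_box = gr.Textbox(label="Query")
    embedding_dropdown = gr.Dropdown(
        choices=["tfidf", "word2vec", "bert", "scibert", "sbert"],
        value="tfidf", label="Embedding",
    )
    output_md = gr.Markdown()

    # The change event now passes only the dropdown value, not the query box.
    def model_switch(embedding):
        return f"Switched to {embedding}"  # placeholder body

    embedding_dropdown.change(model_switch, inputs=[embedding_dropdown], outputs=output_md)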
@@ -73,11 +73,12 @@ class ArxivSearch:
         )

         self.load_data(dataset)
-        self.load_model('tfidf')
-        self.load_model('word2vec')
-        self.load_model('bert')
-        self.load_model('scibert')
-        self.load_model('sbert')
+        self.load_model(embedding)
+        # self.load_model('tfidf')
+        # self.load_model('word2vec')
+        # self.load_model('bert')
+        # self.load_model('scibert')
+        # self.load_model('sbert')

         self.iface.launch()

@@ -114,7 +115,6 @@ class ArxivSearch:
         self.arxiv_ids.append(arxiv_id)

     def plot_dense(self, embedding, pca, results_indices):
-        print(self.query_encoding.shape[0])
         all_indices = list(set(results_indices) | set(range(min(5000, embedding.shape[0]))))
         all_data = embedding[all_indices]
         pca.fit(all_data)
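The fit above runs on the result points plus at most the first 5,000 documents rather than the whole matrix. A standalone sketch of that subsampling, assuming a 3-component PCA to match the 3-D plot (random data as a stand-in for the real embeddings):

import numpy as np
from sklearn.decomposition import PCA

embedding = np.random.rand(20000, 384)    # stand-in for the document embedding matrix
results_indices = [42, 99, 12345]         # stand-in search results
pca = PCA(n_components=3)

# Fit the projection on the results plus a capped sample of documents.
all_indices = list(set(results_indices) | set(range(min(5000, embedding.shape[0]))))
pca.fit(embedding[all_indices])
reduced = pca.transform(embedding[all_indices])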
@@ -149,7 +149,9 @@ class ArxivSearch:
             z=reduced_data[:, 2],
             mode='markers',
             marker=dict(size=3.5, color="#ffffff", opacity=0.2),
-            name='All Documents'
+            name='All Documents',
+            text=[f"<br>: {self.arxiv_ids[i] if self.arxiv_ids[i] else self.documents[i].split()[:10]}" for i in range(len(self.documents))],
+            hoverinfo='text'
         )
         layout = go.Layout(
             margin=dict(l=0, r=0, b=0, t=0),
@@ -172,7 +174,9 @@ class ArxivSearch:
             z=reduced_results_points[:, 2],
             mode='markers',
             marker=dict(size=3.5, color='orange', opacity=0.75),
-            name='Results'
+            name='Results',
+            text=[f"<br>Snippet: {self.documents[i][:200]}" for i in results_indices],
+            hoverinfo='text'
         )
         if not self.embedding == "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
             query_trace = go.Scatter3d(
@@ -181,7 +185,9 @@ class ArxivSearch:
                 z=query_point[:, 2],
                 mode='markers',
                 marker=dict(size=5, color='red', opacity=0.8),
-                name='Query'
+                name='Query',
+                text=[f"<br>Query: {self.query}"],
+                hoverinfo='text'
             )
             fig = go.Figure(data=[trace, results_trace, query_trace], layout=layout)
         else:
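The three hunks above attach hover labels to each trace in the same way. The pattern in isolation, with placeholder points and labels:

import plotly.graph_objects as go

# One text entry per point; hoverinfo='text' shows only the label on hover,
# suppressing the raw x/y/z coordinates.
results_trace = go.Scatter3d(
    x=[0.0, 1.0], y=[0.0, 1.0], z=[0.0, 1.0],
    mode='markers',
    marker=dict(size=3.5, color='orange', opacity=0.75),
    name='Results',
    text=["<br>Snippet: first document...", "<br>Snippet: second document..."],
    hoverinfo='text',
)
go.Figure(data=[results_trace]).show()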
@@ -209,7 +215,7 @@ class ArxivSearch:
         if not tokens:
             return []
         vectors = np.array([self.wv_model[word] for word in tokens])
-        query_vec = normalize(np.mean(vectors, axis=0).reshape(1, -1))
+        query_vec = np.mean(vectors, axis=0).reshape(1, -1)
         self.query_encoding = query_vec
         sims = cosine_similarity(query_vec, self.word2vec_embeddings).flatten()
         top_indices = sims.argsort()[::-1][:top_n]
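Dropping normalize() here does not change the ranking, since cosine similarity is invariant to the query vector's scale; it only changes the raw coordinates stored in self.query_encoding for plotting. A quick check of that invariance, using the same sklearn helpers the code does:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize

v = np.array([[3.0, 4.0]])                # un-normalized query vector
m = np.array([[1.0, 0.0], [0.6, 0.8]])    # two document vectors
# Scaling the query leaves cosine similarity unchanged.
assert np.allclose(cosine_similarity(v, m), cosine_similarity(normalize(v), m))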
@@ -219,7 +225,6 @@ class ArxivSearch:
         with torch.no_grad():
             inputs = self.tokenizer((query+' ')*2, return_tensors="pt", truncation=True, max_length=512, padding='max_length')
             outputs = self.model(**inputs)
-            # query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
             query_vec = outputs.last_hidden_state[:, 0, :].numpy()

         self.query_encoding = query_vec
@@ -251,6 +256,38 @@ class ArxivSearch:
         top_indices = top_k_indices[final_scores.argsort()[::-1][:top_n]]
         print(f"sim, top_indices: {final_scores}, {top_indices}")
         return [(top_k_indices[i], final_scores[i]) for i in final_scores.argsort()[::-1][:top_n]]
+
+    def model_switch(self, embedding, progress=gr.Progress()):
+        if self.embedding != embedding:
+            old_embedding = self.embedding
+            print(f"Switching model to {embedding}")
+            self.load_model(embedding)
+            print(f"Loaded {embedding} model")
+            self.embedding = embedding
+            if old_embedding == "tfidf":
+                del self.tfidf_matrix
+                del self.feature_names
+            if old_embedding == "word2vec":
+                del self.word2vec_embeddings
+                del self.wv_model
+            if old_embedding == "bert":
+                del self.bert_embeddings
+                del self.tokenizer
+                del self.model
+            if old_embedding == "scibert":
+                del self.scibert_embeddings
+                del self.sci_tokenizer
+                del self.sci_model
+            if old_embedding == "sbert":
+                del self.sbert_model
+                del self.sbert_embedding
+                del self.cross_encoder
+            print(f"old embedding removed")
+            if hasattr(self, "query") and self.query:
+                return self.search_function(self.query, self.embedding)
+            else:
+                return ""  # Or a message like "Model switched. Please enter a query."
+        return gr.update()  # No change if embedding is the same

     def load_model(self, embedding):
         self.embedding = embedding
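A side note on the del calls in model_switch: deleting an attribute only drops the reference. If freeing memory promptly is the goal, the usual complement is an explicit garbage collection, plus a CUDA cache flush when the old model lived on the GPU. This follow-up is an assumption about intent, not part of the commit:

import gc
import torch

gc.collect()                      # reclaim the dropped model objects
if torch.cuda.is_available():
    torch.cuda.empty_cache()      # release cached GPU memory, if any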
@@ -291,8 +328,9 @@ class ArxivSearch:
     def set_embedding(self, embedding):
         self.embedding = embedding

-    def search_function(self, query, embedding):
+    def search_function(self, query, embedding, progress=gr.Progress()):
         self.set_embedding(embedding)
+        self.query = query
         query = query.encode().decode('unicode_escape') # Interpret escape sequences

         # Load or switch embedding model here if needed
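Both handlers now accept progress=gr.Progress(), which Gradio injects automatically when the function is used as an event handler. The commit adds the parameter without calling it; if stage reporting is wanted later, the tracker is callable, roughly like this (the progress calls below are illustrative, not in the commit):

import gradio as gr

def search_function(query, embedding, progress=gr.Progress()):
    progress(0.1, desc=f"Encoding query with {embedding}")
    # ... encode the query, rank documents ...
    progress(0.9, desc="Formatting results")
    return "results markdown"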
 