Jonas Leeb committed
Commit · dc760b4
Parent(s): 0fbc2c7

bug fixes and usability improvements

app.py CHANGED
```diff
@@ -57,8 +57,8 @@ class ArxivSearch:
             outputs=self.output_md
         )
         self.embedding_dropdown.change(
-            self.
-            inputs=[self.
+            self.model_switch,
+            inputs=[self.embedding_dropdown],
             outputs=self.output_md
         )
         self.plot_button.click(
```
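The dropdown's change event now calls the new `model_switch` handler instead of re-running a full search directly. For context, a minimal, self-contained sketch of this Gradio wiring pattern (the handler body and layout here are illustrative, not app.py's actual ones; the backend names do appear elsewhere in this commit):

```python
import gradio as gr

def on_model_change(choice):
    # The handler receives the dropdown's current value; its return value
    # is rendered into the Markdown output component.
    return f"Switched embedding to **{choice}**"

with gr.Blocks() as demo:
    embedding_dropdown = gr.Dropdown(
        choices=["tfidf", "word2vec", "bert", "scibert", "sbert"],
        value="tfidf",
        label="Embedding",
    )
    output_md = gr.Markdown()
    # .change() fires on every new selection, mirroring the wiring above.
    embedding_dropdown.change(on_model_change, inputs=[embedding_dropdown], outputs=output_md)

demo.launch()
```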
```diff
@@ -73,11 +73,12 @@ class ArxivSearch:
         )
 
         self.load_data(dataset)
-        self.load_model(
-        self.load_model('
-        self.load_model('
-        self.load_model('
-        self.load_model('
+        self.load_model(embedding)
+        # self.load_model('tfidf')
+        # self.load_model('word2vec')
+        # self.load_model('bert')
+        # self.load_model('scibert')
+        # self.load_model('sbert')
 
         self.iface.launch()
 
```
```diff
@@ -114,7 +115,6 @@ class ArxivSearch:
         self.arxiv_ids.append(arxiv_id)
 
     def plot_dense(self, embedding, pca, results_indices):
-        print(self.query_encoding.shape[0])
         all_indices = list(set(results_indices) | set(range(min(5000, embedding.shape[0]))))
         all_data = embedding[all_indices]
         pca.fit(all_data)
```
```diff
@@ -149,7 +149,9 @@ class ArxivSearch:
             z=reduced_data[:, 2],
             mode='markers',
             marker=dict(size=3.5, color="#ffffff", opacity=0.2),
-            name='All Documents'
+            name='All Documents',
+            text=[f"<br>: {self.arxiv_ids[i] if self.arxiv_ids[i] else self.documents[i].split()[:10]}" for i in range(len(self.documents))],
+            hoverinfo='text'
         )
         layout = go.Layout(
             margin=dict(l=0, r=0, b=0, t=0),
```
```diff
@@ -172,7 +174,9 @@ class ArxivSearch:
             z=reduced_results_points[:, 2],
             mode='markers',
             marker=dict(size=3.5, color='orange', opacity=0.75),
-            name='Results'
+            name='Results',
+            text=[f"<br>Snippet: {self.documents[i][:200]}" for i in results_indices],
+            hoverinfo='text'
         )
         if not self.embedding == "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
             query_trace = go.Scatter3d(
```
```diff
@@ -181,7 +185,9 @@ class ArxivSearch:
                 z=query_point[:, 2],
                 mode='markers',
                 marker=dict(size=5, color='red', opacity=0.8),
-                name='Query'
+                name='Query',
+                text=[f"<br>Query: {self.query}"],
+                hoverinfo='text'
             )
             fig = go.Figure(data=[trace, results_trace, query_trace], layout=layout)
         else:
```
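The three trace changes above all apply the same Plotly idiom: supply one `text` string per point and set `hoverinfo='text'` so the tooltip shows that string instead of raw coordinates. A self-contained sketch with made-up points:

```python
import plotly.graph_objects as go

# Toy data standing in for PCA-reduced document embeddings.
xs, ys, zs = [0.1, 0.4], [0.2, 0.5], [0.3, 0.6]
labels = ["2101.00001", "2101.00002"]  # hypothetical arXiv IDs

trace = go.Scatter3d(
    x=xs, y=ys, z=zs,
    mode='markers',
    marker=dict(size=3.5, color="#ffffff", opacity=0.2),
    name='All Documents',
    text=[f"<br>ID: {label}" for label in labels],  # one hover string per point
    hoverinfo='text',  # show only the custom text on hover
)
fig = go.Figure(data=[trace])
fig.show()
```

One caveat in the first hunk above: its `text` list is built over `range(len(self.documents))`, while the trace only plots the points in `all_indices` (capped at 5000), so hover labels can misalign on corpora larger than that cap.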
```diff
@@ -209,7 +215,7 @@ class ArxivSearch:
         if not tokens:
             return []
         vectors = np.array([self.wv_model[word] for word in tokens])
-        query_vec =
+        query_vec = np.mean(vectors, axis=0).reshape(1, -1)
         self.query_encoding = query_vec
         sims = cosine_similarity(query_vec, self.word2vec_embeddings).flatten()
         top_indices = sims.argsort()[::-1][:top_n]
```
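The completed line mean-pools the word vectors into a single query embedding and reshapes it to the 2-D `(1, dim)` form that `cosine_similarity` expects. The same idea in isolation, with a toy vocabulary standing in for the real `wv_model` (all values here are illustrative):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Toy 4-dimensional "word vectors" and document embeddings.
word_vectors = {
    "quantum": np.array([1.0, 0.0, 0.0, 0.0]),
    "gravity": np.array([0.0, 1.0, 0.0, 0.0]),
}
doc_embeddings = np.array([
    [0.9, 0.1, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
])

tokens = ["quantum", "gravity"]
vectors = np.array([word_vectors[w] for w in tokens])
query_vec = np.mean(vectors, axis=0).reshape(1, -1)  # average, then (1, dim)
sims = cosine_similarity(query_vec, doc_embeddings).flatten()
top_indices = sims.argsort()[::-1][:2]  # two most similar documents
print(top_indices, sims[top_indices])   # -> [2 0] with their scores
```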
```diff
@@ -219,7 +225,6 @@ class ArxivSearch:
         with torch.no_grad():
             inputs = self.tokenizer((query+' ')*2, return_tensors="pt", truncation=True, max_length=512, padding='max_length')
             outputs = self.model(**inputs)
-            # query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
             query_vec = outputs.last_hidden_state[:, 0, :].numpy()
 
         self.query_encoding = query_vec
```
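What remains after the deleted comment is plain CLS-pooling: the query vector is the final hidden state of the first ([CLS]) token. A runnable sketch of that step using the stock `bert-base-uncased` checkpoint (an assumption; app.py's actual checkpoint may differ):

```python
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "bert-base-uncased"  # assumed; not necessarily the app's model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

query = "graph neural networks for molecules"
with torch.no_grad():
    inputs = tokenizer(query, return_tensors="pt", truncation=True, max_length=512)
    outputs = model(**inputs)

# last_hidden_state: (batch, seq_len, hidden); [:, 0, :] selects the [CLS] token.
query_vec = outputs.last_hidden_state[:, 0, :].numpy()
print(query_vec.shape)  # (1, 768) for bert-base
```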
```diff
@@ -251,6 +256,38 @@ class ArxivSearch:
         top_indices = top_k_indices[final_scores.argsort()[::-1][:top_n]]
         print(f"sim, top_indices: {final_scores}, {top_indices}")
         return [(top_k_indices[i], final_scores[i]) for i in final_scores.argsort()[::-1][:top_n]]
+
+    def model_switch(self, embedding, progress=gr.Progress()):
+        if self.embedding != embedding:
+            old_embedding = self.embedding
+            print(f"Switching model to {embedding}")
+            self.load_model(embedding)
+            print(f"Loaded {embedding} model")
+            self.embedding = embedding
+            if old_embedding == "tfidf":
+                del self.tfidf_matrix
+                del self.feature_names
+            if old_embedding == "word2vec":
+                del self.word2vec_embeddings
+                del self.wv_model
+            if old_embedding == "bert":
+                del self.bert_embeddings
+                del self.tokenizer
+                del self.model
+            if old_embedding == "scibert":
+                del self.scibert_embeddings
+                del self.sci_tokenizer
+                del self.sci_model
+            if old_embedding == "sbert":
+                del self.sbert_model
+                del self.sbert_embedding
+                del self.cross_encoder
+            print("old embedding removed")
+            if hasattr(self, "query") and self.query:
+                return self.search_function(self.query, self.embedding)
+            else:
+                return ""  # Or a message like "Model switched. Please enter a query."
+        return gr.update()  # No change if embedding is the same
 
     def load_model(self, embedding):
         self.embedding = embedding
```
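`model_switch` keeps only one backend resident by `del`-ing the outgoing model's attributes, though it loads the new model before freeing the old one, so both briefly coexist in memory. Also, `del` only drops the Python reference; a sketch of a fuller teardown that additionally nudges the allocator (the `gc` and CUDA calls are additions here, not part of the commit, and the helper is hypothetical):

```python
import gc

import torch

def release_backend(obj, attr_names):
    """Hypothetical helper: drop model attributes and reclaim memory."""
    for name in attr_names:
        if hasattr(obj, name):
            delattr(obj, name)    # same effect as `del obj.<name>`, but tolerant
    gc.collect()                  # collect the now-unreferenced objects
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached GPU blocks back to the driver

# Usage, mirroring the "bert" branch above:
# release_backend(self, ["bert_embeddings", "tokenizer", "model"])
```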
```diff
@@ -291,8 +328,9 @@ class ArxivSearch:
     def set_embedding(self, embedding):
         self.embedding = embedding
 
-    def search_function(self, query, embedding):
+    def search_function(self, query, embedding, progress=gr.Progress()):
         self.set_embedding(embedding)
+        self.query = query
         query = query.encode().decode('unicode_escape')  # Interpret escape sequences
 
         # Load or switch embedding model here if needed
```
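Two small things happen here: `search_function` now records `self.query`, which is what lets `model_switch` re-run the last search after a swap, and it gains a `progress=gr.Progress()` parameter, Gradio's convention for injecting a progress tracker into an event handler. A minimal sketch of how such a parameter is typically used (the loop and timings are illustrative):

```python
import time

import gradio as gr

def search_function(query, embedding, progress=gr.Progress()):
    # Gradio detects the gr.Progress() default and injects a live tracker.
    for _ in progress.tqdm(range(3), desc="Searching"):
        time.sleep(0.3)  # stand-in for encoding + similarity work
    return f"Results for {query!r} using {embedding}"
```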
|