Jonas Leeb commited on
Commit
a7b2b6d
·
1 Parent(s): 66113e1

multiple deivces shouldnt interfer as much

Browse files
Files changed (1) hide show
  1. app.py +30 -25
app.py CHANGED
@@ -28,7 +28,7 @@ class ArxivSearch:
28
 
29
  # model selection
30
  self.embedding_dropdown = gr.Dropdown(
31
- choices=["tfidf", "word2vec", "bert", "scibert", "sbert"],
32
  value="bert",
33
  label="Model"
34
  )
@@ -56,9 +56,14 @@ class ArxivSearch:
56
  inputs=[self.query_box, self.embedding_dropdown],
57
  outputs=self.output_md
58
  )
 
 
 
 
 
59
  self.embedding_dropdown.change(
60
- self.model_switch,
61
- inputs=[self.embedding_dropdown],
62
  outputs=self.output_md
63
  )
64
  self.plot_button.click(
@@ -73,12 +78,12 @@ class ArxivSearch:
73
  )
74
 
75
  self.load_data(dataset)
76
- self.load_model(embedding)
77
- # self.load_model('tfidf')
78
- # self.load_model('word2vec')
79
- # self.load_model('bert')
80
  # self.load_model('scibert')
81
- # self.load_model('sbert')
82
 
83
  self.iface.launch()
84
 
@@ -139,8 +144,8 @@ class ArxivSearch:
139
  reduced_data, reduced_results_points, query_point = self.plot_dense(self.bert_embeddings, pca, results_indices)
140
  elif self.embedding == "sbert":
141
  reduced_data, reduced_results_points, query_point = self.plot_dense(self.sbert_embedding, pca, results_indices)
142
- elif self.embedding == "scibert":
143
- reduced_data, reduced_results_points, query_point = self.plot_dense(self.scibert_embeddings, pca, results_indices)
144
  else:
145
  raise ValueError(f"Unsupported embedding type: {self.embedding}")
146
  trace = go.Scatter3d(
@@ -241,17 +246,17 @@ class ArxivSearch:
241
  print(f"sim, top_indices: {sims}, {top_indices}")
242
  return [(i, sims[i]) for i in top_indices]
243
 
244
- def scibert_search(self, query, top_n=10):
245
- with torch.no_grad():
246
- inputs = self.sci_tokenizer(query, return_tensors="pt", truncation=True, padding=True, max_length=512)
247
- outputs = self.sci_model(**inputs)
248
- query_vec = outputs.last_hidden_state[:, 0, :].numpy()
249
 
250
- self.query_encoding = query_vec
251
- sims = cosine_similarity(query_vec, self.scibert_embeddings).flatten()
252
- top_indices = sims.argsort()[::-1][:top_n]
253
- print(f"sim, top_indices: {sims}, {top_indices}")
254
- return [(i, sims[i]) for i in top_indices]
255
 
256
  def sbert_search(self, query, top_n=10):
257
  query_vec = self.sbert_model.encode([query])
@@ -312,11 +317,11 @@ class ArxivSearch:
312
  self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
313
  self.model = BertModel.from_pretrained('bert-base-uncased')
314
  self.model.eval()
315
- elif self.embedding == "scibert":
316
- self.scibert_embeddings = np.load("SciBERT_embeddings/scibert_embedding.npz")["bert_embedding"]
317
- self.sci_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
318
- self.sci_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
319
- self.sci_model.eval()
320
  elif self.embedding == "sbert":
321
  self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")
322
  self.sbert_embedding = np.load("BERT embeddings/sbert_embedding.npz")["sbert_embedding"]
 
28
 
29
  # model selection
30
  self.embedding_dropdown = gr.Dropdown(
31
+ choices=["tfidf", "word2vec", "bert", "sbert"],
32
  value="bert",
33
  label="Model"
34
  )
 
56
  inputs=[self.query_box, self.embedding_dropdown],
57
  outputs=self.output_md
58
  )
59
+ # self.embedding_dropdown.change(
60
+ # self.model_switch,
61
+ # inputs=[self.embedding_dropdown],
62
+ # outputs=self.output_md
63
+ # )
64
  self.embedding_dropdown.change(
65
+ self.search_function,
66
+ inputs=[self.query_box, self.embedding_dropdown],
67
  outputs=self.output_md
68
  )
69
  self.plot_button.click(
 
78
  )
79
 
80
  self.load_data(dataset)
81
+ # self.load_model(embedding)
82
+ self.load_model('tfidf')
83
+ self.load_model('word2vec')
84
+ self.load_model('bert')
85
  # self.load_model('scibert')
86
+ self.load_model('sbert')
87
 
88
  self.iface.launch()
89
 
 
144
  reduced_data, reduced_results_points, query_point = self.plot_dense(self.bert_embeddings, pca, results_indices)
145
  elif self.embedding == "sbert":
146
  reduced_data, reduced_results_points, query_point = self.plot_dense(self.sbert_embedding, pca, results_indices)
147
+ # elif self.embedding == "scibert":
148
+ # reduced_data, reduced_results_points, query_point = self.plot_dense(self.scibert_embeddings, pca, results_indices)
149
  else:
150
  raise ValueError(f"Unsupported embedding type: {self.embedding}")
151
  trace = go.Scatter3d(
 
246
  print(f"sim, top_indices: {sims}, {top_indices}")
247
  return [(i, sims[i]) for i in top_indices]
248
 
249
+ # def scibert_search(self, query, top_n=10):
250
+ # with torch.no_grad():
251
+ # inputs = self.sci_tokenizer(query, return_tensors="pt", truncation=True, padding=True, max_length=512)
252
+ # outputs = self.sci_model(**inputs)
253
+ # query_vec = outputs.last_hidden_state[:, 0, :].numpy()
254
 
255
+ # self.query_encoding = query_vec
256
+ # sims = cosine_similarity(query_vec, self.scibert_embeddings).flatten()
257
+ # top_indices = sims.argsort()[::-1][:top_n]
258
+ # print(f"sim, top_indices: {sims}, {top_indices}")
259
+ # return [(i, sims[i]) for i in top_indices]
260
 
261
  def sbert_search(self, query, top_n=10):
262
  query_vec = self.sbert_model.encode([query])
 
317
  self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
318
  self.model = BertModel.from_pretrained('bert-base-uncased')
319
  self.model.eval()
320
+ # elif self.embedding == "scibert":
321
+ # self.scibert_embeddings = np.load("SciBERT_embeddings/scibert_embedding.npz")["bert_embedding"]
322
+ # self.sci_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
323
+ # self.sci_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
324
+ # self.sci_model.eval()
325
  elif self.embedding == "sbert":
326
  self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")
327
  self.sbert_embedding = np.load("BERT embeddings/sbert_embedding.npz")["sbert_embedding"]