Jonas Leeb committed on
Commit
b4a0b98
·
1 Parent(s): 6c71bbc

Fixed PCA bug (PCA is now fit on the search-result points together with the first 5000 rows) and fixed requirements (added plotly)

Browse files
Files changed (2) hide show
  1. app.py +15 -5
  2. requirements.txt +2 -1
app.py CHANGED
@@ -132,16 +132,26 @@ class ArxivSearch:
132
  pca = PCA(n_components=3)
133
  results_indices = [i[0] for i in self.last_results]
134
  if embedding == "tfidf":
135
- reduced_data = pca.fit_transform(self.tfidf_matrix[:5000].toarray())
136
- reduced_results_points = pca.transform(self.tfidf_matrix[results_indices].toarray()) if len(results_indices) > 0 else np.empty((0, 3))
 
 
 
137
 
138
  elif embedding == "word2vec":
139
- reduced_data = pca.fit_transform(self.word2vec_embeddings[:5000])
 
 
 
140
  reduced_results_points = pca.transform(self.word2vec_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
141
 
142
  elif embedding == "bert":
143
- reduced_data = pca.fit_transform(self.bert_embeddings[:5000])
 
 
 
144
  reduced_results_points = pca.transform(self.bert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
 
145
  else:
146
  raise ValueError(f"Unsupported embedding type: {embedding}")
147
  trace = go.Scatter3d(
@@ -171,7 +181,7 @@ class ArxivSearch:
171
  y=reduced_results_points[:, 1],
172
  z=reduced_results_points[:, 2],
173
  mode='markers',
174
- marker=dict(size=3.5, color='orange', opacity=0.9),
175
  )
176
  fig = go.Figure(data=[trace, results_trace], layout=layout)
177
  else:
 
132
  pca = PCA(n_components=3)
133
  results_indices = [i[0] for i in self.last_results]
134
  if embedding == "tfidf":
135
+ all_indices = list(set(results_indices) | set(range(min(5000, self.tfidf_matrix.shape[0]))))
136
+ all_data = self.tfidf_matrix[all_indices].toarray()
137
+ pca.fit(all_data)
138
+ reduced_data = pca.transform(self.tfidf_matrix[:5000].toarray())
139
+ reduced_results_points = pca.transform(self.tfidf_matrix[results_indices].toarray()) if len(results_indices) > 0 else np.empty((0, 3))
140
 
141
  elif embedding == "word2vec":
142
+ all_indices = list(set(results_indices) | set(range(min(5000, self.word2vec_embeddings.shape[0]))))
143
+ all_data = self.word2vec_embeddings[all_indices]
144
+ pca.fit(all_data)
145
+ reduced_data = pca.transform(self.word2vec_embeddings[:5000])
146
  reduced_results_points = pca.transform(self.word2vec_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
147
 
148
  elif embedding == "bert":
149
+ all_indices = list(set(results_indices) | set(range(min(5000, self.bert_embeddings.shape[0]))))
150
+ all_data = self.bert_embeddings[all_indices]
151
+ pca.fit(all_data)
152
+ reduced_data = pca.transform(self.bert_embeddings[:5000])
153
  reduced_results_points = pca.transform(self.bert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
154
+
155
  else:
156
  raise ValueError(f"Unsupported embedding type: {embedding}")
157
  trace = go.Scatter3d(
 
181
  y=reduced_results_points[:, 1],
182
  z=reduced_results_points[:, 2],
183
  mode='markers',
184
+ marker=dict(size=3.5, color='orange', opacity=0.75),
185
  )
186
  fig = go.Figure(data=[trace, results_trace], layout=layout)
187
  else:
requirements.txt CHANGED
@@ -5,4 +5,5 @@ datasets
5
  torch
6
  gensim
7
  scikit-learn
8
- transformers
 
 
5
  torch
6
  gensim
7
  scikit-learn
8
+ transformers
9
+ plotly