Jonas Leeb
committed on
Commit
·
b4a0b98
1
Parent(s):
6c71bbc
fixed bug with pca and fixed requirements
Browse files
- app.py +15 -5
- requirements.txt +2 -1
app.py
CHANGED
@@ -132,16 +132,26 @@ class ArxivSearch:
|
|
132 |
pca = PCA(n_components=3)
|
133 |
results_indices = [i[0] for i in self.last_results]
|
134 |
if embedding == "tfidf":
|
135 |
-
|
136 |
-
|
|
|
|
|
|
|
137 |
|
138 |
elif embedding == "word2vec":
|
139 |
-
|
|
|
|
|
|
|
140 |
reduced_results_points = pca.transform(self.word2vec_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
|
141 |
|
142 |
elif embedding == "bert":
|
143 |
-
|
|
|
|
|
|
|
144 |
reduced_results_points = pca.transform(self.bert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
|
|
|
145 |
else:
|
146 |
raise ValueError(f"Unsupported embedding type: {embedding}")
|
147 |
trace = go.Scatter3d(
|
@@ -171,7 +181,7 @@ class ArxivSearch:
|
|
171 |
y=reduced_results_points[:, 1],
|
172 |
z=reduced_results_points[:, 2],
|
173 |
mode='markers',
|
174 |
-
marker=dict(size=3.5, color='orange', opacity=0.
|
175 |
)
|
176 |
fig = go.Figure(data=[trace, results_trace], layout=layout)
|
177 |
else:
|
|
|
132 |
pca = PCA(n_components=3)
|
133 |
results_indices = [i[0] for i in self.last_results]
|
134 |
if embedding == "tfidf":
|
135 |
+
all_indices = list(set(results_indices) | set(range(min(5000, self.tfidf_matrix.shape[0]))))
|
136 |
+
all_data = self.tfidf_matrix[all_indices].toarray()
|
137 |
+
pca.fit(all_data)
|
138 |
+
reduced_data = pca.transform(self.tfidf_matrix[:5000].toarray())
|
139 |
+
reduced_results_points = pca.transform(self.tfidf_matrix[results_indices].toarray()) if len(results_indices) > 0 else np.empty((0, 3))
|
140 |
|
141 |
elif embedding == "word2vec":
|
142 |
+
all_indices = list(set(results_indices) | set(range(min(5000, self.word2vec_embeddings.shape[0]))))
|
143 |
+
all_data = self.word2vec_embeddings[all_indices]
|
144 |
+
pca.fit(all_data)
|
145 |
+
reduced_data = pca.transform(self.word2vec_embeddings[:5000])
|
146 |
reduced_results_points = pca.transform(self.word2vec_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
|
147 |
|
148 |
elif embedding == "bert":
|
149 |
+
all_indices = list(set(results_indices) | set(range(min(5000, self.bert_embeddings.shape[0]))))
|
150 |
+
all_data = self.bert_embeddings[all_indices]
|
151 |
+
pca.fit(all_data)
|
152 |
+
reduced_data = pca.transform(self.bert_embeddings[:5000])
|
153 |
reduced_results_points = pca.transform(self.bert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
|
154 |
+
|
155 |
else:
|
156 |
raise ValueError(f"Unsupported embedding type: {embedding}")
|
157 |
trace = go.Scatter3d(
|
|
|
181 |
y=reduced_results_points[:, 1],
|
182 |
z=reduced_results_points[:, 2],
|
183 |
mode='markers',
|
184 |
+
marker=dict(size=3.5, color='orange', opacity=0.75),
|
185 |
)
|
186 |
fig = go.Figure(data=[trace, results_trace], layout=layout)
|
187 |
else:
|
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ datasets
|
|
5 |
torch
|
6 |
gensim
|
7 |
scikit-learn
|
8 |
-
transformers
|
|
|
|
5 |
torch
|
6 |
gensim
|
7 |
scikit-learn
|
8 |
+
transformers
|
9 |
+
plotly
|