Jonas Leeb
commited on
Commit
·
2f53a52
1
Parent(s):
7db94f3
added colorcoded points
Browse files
app.py
CHANGED
@@ -156,20 +156,8 @@ class ArxivSearch:
|
|
156 |
raise ValueError(f"Unsupported embedding type: {self.embedding}")
|
157 |
|
158 |
results_scores = [i[1] for i in self.last_results]
|
159 |
-
# Normalize scores to [0, 1] for color mapping
|
160 |
-
if results_scores:
|
161 |
-
min_score = min(results_scores)
|
162 |
-
max_score = max(results_scores)
|
163 |
-
if max_score > min_score:
|
164 |
-
color_scale = [(score - min_score) / (max_score - min_score) for score in results_scores]
|
165 |
-
else:
|
166 |
-
color_scale = [0.5 for _ in results_scores]
|
167 |
-
else:
|
168 |
-
color_scale = []
|
169 |
|
170 |
-
|
171 |
-
cmap = matplotlib.cm.get_cmap('hot')
|
172 |
-
results_colors = [matplotlib.colors.rgb2hex(cmap(val)) for val in color_scale]
|
173 |
|
174 |
trace = go.Scatter3d(
|
175 |
x=reduced_data[:, 0],
|
@@ -183,39 +171,60 @@ class ArxivSearch:
|
|
183 |
text=[f"<br>: {self.arxiv_ids[i] if self.arxiv_ids[i] else self.documents[i].split()[:10]}" for i in range(len(self.documents))],
|
184 |
hoverinfo='text'
|
185 |
)
|
|
|
|
|
|
|
186 |
layout = go.Layout(
|
187 |
margin=dict(l=0, r=0, b=0, t=0),
|
188 |
scene=dict(
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
),
|
196 |
paper_bgcolor='black', # Outside the plotting area
|
197 |
plot_bgcolor='black', # Plotting area
|
198 |
font=dict(color='white'), # Axis and legend text
|
199 |
legend=dict(
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
)
|
207 |
)
|
|
|
208 |
if len(reduced_results_points) > 0:
|
|
|
|
|
|
|
|
|
|
|
209 |
results_trace = go.Scatter3d(
|
210 |
x=reduced_results_points[:, 0],
|
211 |
y=reduced_results_points[:, 1],
|
212 |
z=reduced_results_points[:, 2],
|
213 |
mode='markers',
|
214 |
-
marker=dict(size=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
name='Results',
|
216 |
text=[f"<br>{self.documents[i][:100]}" for i in results_indices],
|
217 |
hoverinfo='text'
|
218 |
)
|
|
|
|
|
|
|
219 |
if not self.embedding == "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
|
220 |
query_trace = go.Scatter3d(
|
221 |
x=query_point[:, 0],
|
@@ -227,11 +236,10 @@ class ArxivSearch:
|
|
227 |
text=[f"<br>Query: {self.query}"],
|
228 |
hoverinfo='text'
|
229 |
)
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
fig = go.Figure(data=[trace], layout=layout)
|
235 |
return fig
|
236 |
|
237 |
def keyword_match_ranking(self, query, top_n=10):
|
|
|
156 |
raise ValueError(f"Unsupported embedding type: {self.embedding}")
|
157 |
|
158 |
results_scores = [i[1] for i in self.last_results]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
+
traces = []
|
|
|
|
|
161 |
|
162 |
trace = go.Scatter3d(
|
163 |
x=reduced_data[:, 0],
|
|
|
171 |
text=[f"<br>: {self.arxiv_ids[i] if self.arxiv_ids[i] else self.documents[i].split()[:10]}" for i in range(len(self.documents))],
|
172 |
hoverinfo='text'
|
173 |
)
|
174 |
+
|
175 |
+
traces.append(trace)
|
176 |
+
|
177 |
layout = go.Layout(
|
178 |
margin=dict(l=0, r=0, b=0, t=0),
|
179 |
scene=dict(
|
180 |
+
xaxis_title='PCA 1',
|
181 |
+
yaxis_title='PCA 2',
|
182 |
+
zaxis_title='PCA 3',
|
183 |
+
xaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
|
184 |
+
yaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
|
185 |
+
zaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
|
186 |
),
|
187 |
paper_bgcolor='black', # Outside the plotting area
|
188 |
plot_bgcolor='black', # Plotting area
|
189 |
font=dict(color='white'), # Axis and legend text
|
190 |
legend=dict(
|
191 |
+
bgcolor='rgba(0,0,0,0)', # Transparent legend background
|
192 |
+
bordercolor='rgba(0,0,0,0)', # No border
|
193 |
+
x=0.01, # Place legend inside plot area (adjust as needed)
|
194 |
+
y=0.99,
|
195 |
+
xanchor='left',
|
196 |
+
yanchor='top'
|
197 |
)
|
198 |
)
|
199 |
+
|
200 |
if len(reduced_results_points) > 0:
|
201 |
+
custom_colorscale = [
|
202 |
+
[0.0, "#00ffea"], # Start color (e.g., bright cyan)
|
203 |
+
[1.0, "#ffea00"], # End color (e.g., bright yellow)
|
204 |
+
]
|
205 |
+
|
206 |
results_trace = go.Scatter3d(
|
207 |
x=reduced_results_points[:, 0],
|
208 |
y=reduced_results_points[:, 1],
|
209 |
z=reduced_results_points[:, 2],
|
210 |
mode='markers',
|
211 |
+
marker=dict(size=4.25,
|
212 |
+
color=results_scores,
|
213 |
+
colorscale=custom_colorscale,
|
214 |
+
opacity=0.99,
|
215 |
+
colorbar=dict(
|
216 |
+
title="Score",
|
217 |
+
bgcolor='rgba(0,0,0,0)', # <-- Transparent background for colorbar
|
218 |
+
bordercolor='rgba(0,0,0,0)' # No border
|
219 |
+
)
|
220 |
+
),
|
221 |
name='Results',
|
222 |
text=[f"<br>{self.documents[i][:100]}" for i in results_indices],
|
223 |
hoverinfo='text'
|
224 |
)
|
225 |
+
|
226 |
+
traces.append(results_trace)
|
227 |
+
|
228 |
if not self.embedding == "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
|
229 |
query_trace = go.Scatter3d(
|
230 |
x=query_point[:, 0],
|
|
|
236 |
text=[f"<br>Query: {self.query}"],
|
237 |
hoverinfo='text'
|
238 |
)
|
239 |
+
traces.append(query_trace)
|
240 |
+
|
241 |
+
fig = go.Figure(data=traces, layout=layout)
|
242 |
+
|
|
|
243 |
return fig
|
244 |
|
245 |
def keyword_match_ranking(self, query, top_n=10):
|