Jonas Leeb commited on
Commit
2f53a52
·
1 Parent(s): 7db94f3

added colorcoded points

Browse files
Files changed (1) hide show
  1. app.py +39 -31
app.py CHANGED
@@ -156,20 +156,8 @@ class ArxivSearch:
156
  raise ValueError(f"Unsupported embedding type: {self.embedding}")
157
 
158
  results_scores = [i[1] for i in self.last_results]
159
- # Normalize scores to [0, 1] for color mapping
160
- if results_scores:
161
- min_score = min(results_scores)
162
- max_score = max(results_scores)
163
- if max_score > min_score:
164
- color_scale = [(score - min_score) / (max_score - min_score) for score in results_scores]
165
- else:
166
- color_scale = [0.5 for _ in results_scores]
167
- else:
168
- color_scale = []
169
 
170
- # Use a "hot" colorscale (yellow/red for high, black for low)
171
- cmap = matplotlib.cm.get_cmap('hot')
172
- results_colors = [matplotlib.colors.rgb2hex(cmap(val)) for val in color_scale]
173
 
174
  trace = go.Scatter3d(
175
  x=reduced_data[:, 0],
@@ -183,39 +171,60 @@ class ArxivSearch:
183
  text=[f"<br>: {self.arxiv_ids[i] if self.arxiv_ids[i] else self.documents[i].split()[:10]}" for i in range(len(self.documents))],
184
  hoverinfo='text'
185
  )
 
 
 
186
  layout = go.Layout(
187
  margin=dict(l=0, r=0, b=0, t=0),
188
  scene=dict(
189
- xaxis_title='PCA 1',
190
- yaxis_title='PCA 2',
191
- zaxis_title='PCA 3',
192
- xaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
193
- yaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
194
- zaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
195
  ),
196
  paper_bgcolor='black', # Outside the plotting area
197
  plot_bgcolor='black', # Plotting area
198
  font=dict(color='white'), # Axis and legend text
199
  legend=dict(
200
- bgcolor='rgba(0,0,0,0)', # Transparent legend background
201
- bordercolor='rgba(0,0,0,0)', # No border
202
- x=0.01, # Place legend inside plot area (adjust as needed)
203
- y=0.99,
204
- xanchor='left',
205
- yanchor='top'
206
  )
207
  )
 
208
  if len(reduced_results_points) > 0:
 
 
 
 
 
209
  results_trace = go.Scatter3d(
210
  x=reduced_results_points[:, 0],
211
  y=reduced_results_points[:, 1],
212
  z=reduced_results_points[:, 2],
213
  mode='markers',
214
- marker=dict(size=3.5, color='orange', opacity=0.75),
 
 
 
 
 
 
 
 
 
215
  name='Results',
216
  text=[f"<br>{self.documents[i][:100]}" for i in results_indices],
217
  hoverinfo='text'
218
  )
 
 
 
219
  if not self.embedding == "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
220
  query_trace = go.Scatter3d(
221
  x=query_point[:, 0],
@@ -227,11 +236,10 @@ class ArxivSearch:
227
  text=[f"<br>Query: {self.query}"],
228
  hoverinfo='text'
229
  )
230
- fig = go.Figure(data=[trace, results_trace, query_trace], layout=layout)
231
- else:
232
- fig = go.Figure(data=[trace, results_trace], layout=layout)
233
- else:
234
- fig = go.Figure(data=[trace], layout=layout)
235
  return fig
236
 
237
  def keyword_match_ranking(self, query, top_n=10):
 
156
  raise ValueError(f"Unsupported embedding type: {self.embedding}")
157
 
158
  results_scores = [i[1] for i in self.last_results]
 
 
 
 
 
 
 
 
 
 
159
 
160
+ traces = []
 
 
161
 
162
  trace = go.Scatter3d(
163
  x=reduced_data[:, 0],
 
171
  text=[f"<br>: {self.arxiv_ids[i] if self.arxiv_ids[i] else self.documents[i].split()[:10]}" for i in range(len(self.documents))],
172
  hoverinfo='text'
173
  )
174
+
175
+ traces.append(trace)
176
+
177
  layout = go.Layout(
178
  margin=dict(l=0, r=0, b=0, t=0),
179
  scene=dict(
180
+ xaxis_title='PCA 1',
181
+ yaxis_title='PCA 2',
182
+ zaxis_title='PCA 3',
183
+ xaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
184
+ yaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
185
+ zaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
186
  ),
187
  paper_bgcolor='black', # Outside the plotting area
188
  plot_bgcolor='black', # Plotting area
189
  font=dict(color='white'), # Axis and legend text
190
  legend=dict(
191
+ bgcolor='rgba(0,0,0,0)', # Transparent legend background
192
+ bordercolor='rgba(0,0,0,0)', # No border
193
+ x=0.01, # Place legend inside plot area (adjust as needed)
194
+ y=0.99,
195
+ xanchor='left',
196
+ yanchor='top'
197
  )
198
  )
199
+
200
  if len(reduced_results_points) > 0:
201
+ custom_colorscale = [
202
+ [0.0, "#00ffea"], # Start color (e.g., bright cyan)
203
+ [1.0, "#ffea00"], # End color (e.g., bright yellow)
204
+ ]
205
+
206
  results_trace = go.Scatter3d(
207
  x=reduced_results_points[:, 0],
208
  y=reduced_results_points[:, 1],
209
  z=reduced_results_points[:, 2],
210
  mode='markers',
211
+ marker=dict(size=4.25,
212
+ color=results_scores,
213
+ colorscale=custom_colorscale,
214
+ opacity=0.99,
215
+ colorbar=dict(
216
+ title="Score",
217
+ bgcolor='rgba(0,0,0,0)', # <-- Transparent background for colorbar
218
+ bordercolor='rgba(0,0,0,0)' # No border
219
+ )
220
+ ),
221
  name='Results',
222
  text=[f"<br>{self.documents[i][:100]}" for i in results_indices],
223
  hoverinfo='text'
224
  )
225
+
226
+ traces.append(results_trace)
227
+
228
  if not self.embedding == "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
229
  query_trace = go.Scatter3d(
230
  x=query_point[:, 0],
 
236
  text=[f"<br>Query: {self.query}"],
237
  hoverinfo='text'
238
  )
239
+ traces.append(query_trace)
240
+
241
+ fig = go.Figure(data=traces, layout=layout)
242
+
 
243
  return fig
244
 
245
  def keyword_match_ranking(self, query, top_n=10):