Joschka Strueber commited on
Commit
d35fe98
·
1 Parent(s): f8fdd68

[Fix] plot axes and visibility

Browse files
Files changed (1) hide show
  1. app.py +15 -11
app.py CHANGED
@@ -5,31 +5,35 @@ from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datas
5
 
6
  def create_heatmap(selected_models, selected_dataset):
7
  if not selected_models or not selected_dataset:
8
- return None
9
 
10
  # Generate random similarity matrix
11
  size = len(selected_models)
12
  similarities = np.random.rand(size, size)
13
  similarities = (similarities + similarities.T) / 2 # Make symmetric
14
 
15
- # Explicitly create a Figure object before adding Heatmap
16
- fig = go.Figure()
17
- fig.add_trace(go.Heatmap(
18
  z=similarities,
19
  x=selected_models,
20
  y=selected_models,
21
- colorscale='Viridis'
 
22
  ))
23
-
 
24
  fig.update_layout(
25
  title=f"Similarity Matrix for {selected_dataset}",
26
- width=800,
27
- height=800,
28
  xaxis_title="Models",
29
- yaxis_title="Models"
 
 
 
 
 
30
  )
31
 
32
- return fig
33
 
34
  def validate_inputs(selected_models, selected_dataset):
35
  if not selected_models:
@@ -69,7 +73,7 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
69
  ).then(
70
  fn=create_heatmap,
71
  inputs=[model_dropdown, dataset_dropdown],
72
- outputs=[heatmap]
73
  )
74
 
75
  # Clear button
 
5
 
6
  def create_heatmap(selected_models, selected_dataset):
7
  if not selected_models or not selected_dataset:
8
+ return None, gr.update(visible=False) # Hide if no selection
9
 
10
  # Generate random similarity matrix
11
  size = len(selected_models)
12
  similarities = np.random.rand(size, size)
13
  similarities = (similarities + similarities.T) / 2 # Make symmetric
14
 
15
+ # Create plot
16
+ fig = go.Figure(data=go.Heatmap(
 
17
  z=similarities,
18
  x=selected_models,
19
  y=selected_models,
20
+ colorscale='Viridis',
21
+ hoverongaps=False
22
  ))
23
+
24
+ # Improve axis readability
25
  fig.update_layout(
26
  title=f"Similarity Matrix for {selected_dataset}",
 
 
27
  xaxis_title="Models",
28
+ yaxis_title="Models",
29
+ xaxis=dict(tickangle=45, automargin=True),
30
+ yaxis=dict(automargin=True),
31
+ width=800 + 20*len(selected_models),
32
+ height=800 + 20*len(selected_models),
33
+ margin=dict(b=100, l=100) # Add bottom/left margin for labels
34
  )
35
 
36
+ return fig, gr.update(visible=True)
37
 
38
  def validate_inputs(selected_models, selected_dataset):
39
  if not selected_models:
 
73
  ).then(
74
  fn=create_heatmap,
75
  inputs=[model_dropdown, dataset_dropdown],
76
+ outputs=[heatmap, heatmap.visible]
77
  )
78
 
79
  # Clear button