Joschka Strueber commited on
Commit
fbb8c61
·
1 Parent(s): afa5fdc

[Fix] categorical axis names

Browse files
Files changed (1) hide show
  1. app.py +23 -24
app.py CHANGED
@@ -5,41 +5,41 @@ from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datas
5
 
6
  def create_heatmap(selected_models, selected_dataset):
7
  if not selected_models or not selected_dataset:
8
- return None # Just return None to hide the plot
9
-
10
  # Generate random similarity matrix
11
  size = len(selected_models)
12
  similarities = np.random.rand(size, size)
13
  similarities = (similarities + similarities.T) / 2 # Make symmetric
14
  similarities = np.round(similarities, 2) # Round for clarity
15
-
16
  # Create the heatmap figure
17
- fig = go.Figure(
18
- data=go.Heatmap(
19
- z=similarities, # ✅ Ensure it's a NumPy array
20
- x=selected_models,
21
- y=selected_models,
22
- colorscale='Viridis',
23
- zmin=0, zmax=1, # Normalize scale
24
- text=similarities, # ✅ Show values in heatmap
25
- hoverinfo="text"
26
- )
27
- )
28
-
29
- # Improve axis readability
30
  fig.update_layout(
31
  title=f"Similarity Matrix for {selected_dataset}",
32
  xaxis_title="Models",
33
  yaxis_title="Models",
34
- xaxis=dict(tickangle=45, automargin=True),
35
- yaxis=dict(automargin=True),
36
  width=800 + 20 * len(selected_models),
37
  height=800 + 20 * len(selected_models),
38
- margin=dict(b=100, l=100), # Add bottom/left margin for labels
39
  )
40
 
41
- return fig # Return only the figure
 
 
42
 
 
43
 
44
  def validate_inputs(selected_models, selected_dataset):
45
  if not selected_models:
@@ -69,9 +69,9 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
69
  )
70
 
71
  generate_btn = gr.Button("Generate Heatmap", variant="primary")
72
- heatmap = gr.Plot(label="Similarity Heatmap", visible=True) # ✅ Ensure visible=True
73
 
74
- # Event handling
75
  generate_btn.click(
76
  fn=validate_inputs,
77
  inputs=[model_dropdown, dataset_dropdown],
@@ -79,7 +79,7 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
79
  ).then(
80
  fn=create_heatmap,
81
  inputs=[model_dropdown, dataset_dropdown],
82
- outputs=heatmap # Only one output (gr.Plot)
83
  )
84
 
85
  # Clear button
@@ -89,6 +89,5 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
89
  outputs=[model_dropdown, dataset_dropdown, heatmap]
90
  )
91
 
92
-
93
  if __name__ == "__main__":
94
  demo.launch()
 
5
 
6
  def create_heatmap(selected_models, selected_dataset):
7
  if not selected_models or not selected_dataset:
8
+ return None # Return None to hide the plot if inputs are missing
9
+
10
  # Generate random similarity matrix
11
  size = len(selected_models)
12
  similarities = np.random.rand(size, size)
13
  similarities = (similarities + similarities.T) / 2 # Make symmetric
14
  similarities = np.round(similarities, 2) # Round for clarity
15
+
16
  # Create the heatmap figure
17
+ fig = go.Figure(data=go.Heatmap(
18
+ z=similarities,
19
+ x=selected_models,
20
+ y=selected_models,
21
+ colorscale='Viridis',
22
+ zmin=0,
23
+ zmax=1,
24
+ text=similarities,
25
+ hoverinfo="text"
26
+ ))
27
+
28
+ # Update layout with axis titles and margins
 
29
  fig.update_layout(
30
  title=f"Similarity Matrix for {selected_dataset}",
31
  xaxis_title="Models",
32
  yaxis_title="Models",
 
 
33
  width=800 + 20 * len(selected_models),
34
  height=800 + 20 * len(selected_models),
35
+ margin=dict(b=100, l=100)
36
  )
37
 
38
+ # Force the axes to use category types so that model names appear
39
+ fig.update_xaxes(type="category", tickangle=45, automargin=True)
40
+ fig.update_yaxes(type="category", automargin=True)
41
 
42
+ return fig # Return only the figure
43
 
44
  def validate_inputs(selected_models, selected_dataset):
45
  if not selected_models:
 
69
  )
70
 
71
  generate_btn = gr.Button("Generate Heatmap", variant="primary")
72
+ heatmap = gr.Plot(label="Similarity Heatmap", visible=True)
73
 
74
+ # Event handling: Validate first then create the heatmap
75
  generate_btn.click(
76
  fn=validate_inputs,
77
  inputs=[model_dropdown, dataset_dropdown],
 
79
  ).then(
80
  fn=create_heatmap,
81
  inputs=[model_dropdown, dataset_dropdown],
82
+ outputs=heatmap # Only one output is needed
83
  )
84
 
85
  # Clear button
 
89
  outputs=[model_dropdown, dataset_dropdown, heatmap]
90
  )
91
 
 
92
  if __name__ == "__main__":
93
  demo.launch()