Joschka Strueber commited on
Commit
60ded99
·
1 Parent(s): 90b2246

[Fix] debugging

Browse files
Files changed (1) hide show
  1. app.py +36 -20
app.py CHANGED
@@ -6,30 +6,26 @@ from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datas
6
  def create_heatmap(selected_models, selected_dataset):
7
  if not selected_models or not selected_dataset:
8
  return None # Return None to hide the plot if inputs are missing
9
-
10
  # Generate random similarity matrix
11
  size = len(selected_models)
12
  similarities = np.random.rand(size, size)
13
  similarities = (similarities + similarities.T) / 2 # Make symmetric
14
  similarities = np.round(similarities, 2) # Round for clarity
15
-
16
- print(f"Generated heatmap with {len(selected_models)} models")
17
- print("Sample coordinates:", selected_models[:2])
18
- print("Sample similarity value:", similarities[0][0])
19
-
20
  # Create the heatmap figure
21
- fig = go.Figure(data=go.Heatmap(
 
22
  z=similarities,
23
  x=selected_models,
24
  y=selected_models,
25
  colorscale='Viridis',
26
- zmin=0,
27
- zmax=1,
28
- text=similarities,
29
  hoverinfo="text"
30
  ))
31
-
32
- # Update layout with axis titles and margins
33
  fig.update_layout(
34
  title=f"Similarity Matrix for {selected_dataset}",
35
  xaxis_title="Models",
@@ -38,19 +34,39 @@ def create_heatmap(selected_models, selected_dataset):
38
  height=800 + 20 * len(selected_models),
39
  margin=dict(b=100, l=100)
40
  )
41
-
42
- # Force the axes to use category types so that model names appear
43
- fig.update_xaxes(type="category", tickangle=45, automargin=True)
44
- fig.update_yaxes(type="category", automargin=True)
45
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  return fig # Return only the figure
47
 
 
48
  def validate_inputs(selected_models, selected_dataset):
49
  if not selected_models:
50
  raise gr.Error("Please select at least one model!")
51
  if not selected_dataset:
52
  raise gr.Error("Please select a dataset!")
 
53
 
 
54
  with gr.Blocks(title="LLM Similarity Analyzer") as demo:
55
  gr.Markdown("## Model Similarity Comparison Tool")
56
 
@@ -75,7 +91,7 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
75
  generate_btn = gr.Button("Generate Heatmap", variant="primary")
76
  heatmap = gr.Plot(label="Similarity Heatmap", visible=True)
77
 
78
- # Event handling: Validate first then create the heatmap
79
  generate_btn.click(
80
  fn=validate_inputs,
81
  inputs=[model_dropdown, dataset_dropdown],
@@ -83,10 +99,10 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
83
  ).then(
84
  fn=create_heatmap,
85
  inputs=[model_dropdown, dataset_dropdown],
86
- outputs=heatmap # Only one output is needed
87
  )
88
 
89
- # Clear button
90
  clear_btn = gr.Button("Clear Selection")
91
  clear_btn.click(
92
  lambda: [None, None, None],
 
6
  def create_heatmap(selected_models, selected_dataset):
7
  if not selected_models or not selected_dataset:
8
  return None # Return None to hide the plot if inputs are missing
9
+
10
  # Generate random similarity matrix
11
  size = len(selected_models)
12
  similarities = np.random.rand(size, size)
13
  similarities = (similarities + similarities.T) / 2 # Make symmetric
14
  similarities = np.round(similarities, 2) # Round for clarity
15
+
 
 
 
 
16
  # Create the heatmap figure
17
+ fig = go.Figure()
18
+ fig.add_trace(go.Heatmap(
19
  z=similarities,
20
  x=selected_models,
21
  y=selected_models,
22
  colorscale='Viridis',
23
+ zmin=0, zmax=1,
24
+ text=similarities, # Values to show on hover
 
25
  hoverinfo="text"
26
  ))
27
+
28
+ # Update layout for overall figure settings
29
  fig.update_layout(
30
  title=f"Similarity Matrix for {selected_dataset}",
31
  xaxis_title="Models",
 
34
  height=800 + 20 * len(selected_models),
35
  margin=dict(b=100, l=100)
36
  )
37
+
38
+ # Force both axes to be categorical by explicitly specifying tick values and text
39
+ fig.update_xaxes(
40
+ type="category",
41
+ tickmode="array",
42
+ tickvals=selected_models,
43
+ ticktext=selected_models,
44
+ tickangle=45,
45
+ automargin=True,
46
+ showgrid=True,
47
+ showticklabels=True
48
+ )
49
+ fig.update_yaxes(
50
+ type="category",
51
+ tickmode="array",
52
+ tickvals=selected_models,
53
+ ticktext=selected_models,
54
+ automargin=True,
55
+ showgrid=True,
56
+ showticklabels=True
57
+ )
58
+
59
  return fig # Return only the figure
60
 
61
+
62
  def validate_inputs(selected_models, selected_dataset):
63
  if not selected_models:
64
  raise gr.Error("Please select at least one model!")
65
  if not selected_dataset:
66
  raise gr.Error("Please select a dataset!")
67
+
68
 
69
+ # Gradio interface setup
70
  with gr.Blocks(title="LLM Similarity Analyzer") as demo:
71
  gr.Markdown("## Model Similarity Comparison Tool")
72
 
 
91
  generate_btn = gr.Button("Generate Heatmap", variant="primary")
92
  heatmap = gr.Plot(label="Similarity Heatmap", visible=True)
93
 
94
+ # Use a single output (the figure)
95
  generate_btn.click(
96
  fn=validate_inputs,
97
  inputs=[model_dropdown, dataset_dropdown],
 
99
  ).then(
100
  fn=create_heatmap,
101
  inputs=[model_dropdown, dataset_dropdown],
102
+ outputs=heatmap
103
  )
104
 
105
+ # Clear button: clear selections and the plot
106
  clear_btn = gr.Button("Clear Selection")
107
  clear_btn.click(
108
  lambda: [None, None, None],