Joschka Strueber commited on
Commit
d8f2ec7
·
1 Parent(s): c192c72

[Fix] similarity heatmap creation

Browse files
Files changed (1) hide show
  1. app.py +11 -27
app.py CHANGED
@@ -1,21 +1,16 @@
1
  import gradio as gr
2
  import plotly.graph_objects as go
3
  import numpy as np
4
-
5
-
6
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
7
 
8
-
9
  def create_heatmap(selected_models, selected_dataset):
10
- if not selected_models:
11
  return gr.Plot(visible=False)
12
 
13
- # Generate random similarity matrix (replace with actual computation)
14
  size = len(selected_models)
15
  similarities = np.random.rand(size, size)
16
-
17
- # Create symmetric matrix
18
- similarities = (similarities + similarities.T) / 2
19
 
20
  # Create plot
21
  fig = go.Figure(data=go.Heatmap(
@@ -31,9 +26,8 @@ def create_heatmap(selected_models, selected_dataset):
31
  height=800
32
  )
33
 
34
- with gr.Loading():
35
- return gr.Plot(value=fig, visible=True)
36
-
37
 
38
  def validate_inputs(selected_models, selected_dataset):
39
  if not selected_models:
@@ -41,11 +35,9 @@ def validate_inputs(selected_models, selected_dataset):
41
  if not selected_dataset:
42
  raise gr.Error("Please select a dataset!")
43
 
44
-
45
  with gr.Blocks(title="LLM Similarity Analyzer") as demo:
46
  gr.Markdown("## Model Similarity Comparison Tool")
47
 
48
- # Model selection section
49
  with gr.Row():
50
  dataset_dropdown = gr.Dropdown(
51
  choices=get_leaderboard_datasets(),
@@ -61,20 +53,13 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
61
  multiselect=True,
62
  filterable=True,
63
  allow_custom_value=False,
64
- info="Search and select multiple models (click selected models to remove)"
65
  )
66
 
67
- # Add generate button
68
  generate_btn = gr.Button("Generate Heatmap", variant="primary")
 
69
 
70
- # Heatmap display
71
- heatmap = gr.Plot(
72
- label="Similarity Heatmap",
73
- visible=False,
74
- container=False
75
- )
76
-
77
- # Button click handler
78
  generate_btn.click(
79
  fn=validate_inputs,
80
  inputs=[model_dropdown, dataset_dropdown],
@@ -84,14 +69,13 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
84
  inputs=[model_dropdown, dataset_dropdown],
85
  outputs=heatmap
86
  )
87
-
 
88
  clear_btn = gr.Button("Clear Selection")
89
  clear_btn.click(
90
- lambda: [None, None, gr.Plot(visible=False)],
91
  outputs=[model_dropdown, dataset_dropdown, heatmap]
92
  )
93
 
94
-
95
-
96
  if __name__ == "__main__":
97
  demo.launch()
 
1
  import gradio as gr
2
  import plotly.graph_objects as go
3
  import numpy as np
 
 
4
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
5
 
 
6
  def create_heatmap(selected_models, selected_dataset):
7
+ if not selected_models or not selected_dataset:
8
  return gr.Plot(visible=False)
9
 
10
+ # Generate random similarity matrix
11
  size = len(selected_models)
12
  similarities = np.random.rand(size, size)
13
+ similarities = (similarities + similarities.T) / 2 # Make symmetric
 
 
14
 
15
  # Create plot
16
  fig = go.Figure(data=go.Heatmap(
 
26
  height=800
27
  )
28
 
29
+ # Return both the figure and visibility update
30
+ return gr.Plot.update(value=fig, visible=True)
 
31
 
32
  def validate_inputs(selected_models, selected_dataset):
33
  if not selected_models:
 
35
  if not selected_dataset:
36
  raise gr.Error("Please select a dataset!")
37
 
 
38
  with gr.Blocks(title="LLM Similarity Analyzer") as demo:
39
  gr.Markdown("## Model Similarity Comparison Tool")
40
 
 
41
  with gr.Row():
42
  dataset_dropdown = gr.Dropdown(
43
  choices=get_leaderboard_datasets(),
 
53
  multiselect=True,
54
  filterable=True,
55
  allow_custom_value=False,
56
+ info="Search and select multiple models"
57
  )
58
 
 
59
  generate_btn = gr.Button("Generate Heatmap", variant="primary")
60
+ heatmap = gr.Plot(label="Similarity Heatmap", visible=False)
61
 
62
+ # Event handling
 
 
 
 
 
 
 
63
  generate_btn.click(
64
  fn=validate_inputs,
65
  inputs=[model_dropdown, dataset_dropdown],
 
69
  inputs=[model_dropdown, dataset_dropdown],
70
  outputs=heatmap
71
  )
72
+
73
+ # Clear button should reset to empty lists
74
  clear_btn = gr.Button("Clear Selection")
75
  clear_btn.click(
76
+ lambda: [[], [], gr.Plot.update(visible=False)],
77
  outputs=[model_dropdown, dataset_dropdown, heatmap]
78
  )
79
 
 
 
80
  if __name__ == "__main__":
81
  demo.launch()