Joschka Strueber commited on
Commit
3b16cfa
·
1 Parent(s): 54b2baf

[Fix] plotly heatmap

Browse files
Files changed (3) hide show
  1. app.py +25 -22
  2. src/heatmap.html +0 -0
  3. src/test.py +14 -0
app.py CHANGED
@@ -1,33 +1,38 @@
1
  import gradio as gr
2
- import plotly.express as px
3
  import numpy as np
4
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
5
 
6
  def create_heatmap(selected_models, selected_dataset):
 
7
  if not selected_models or not selected_dataset:
8
- return None # Hide plot if inputs are missing
9
 
10
- # Generate random similarity matrix
11
  size = len(selected_models)
12
  similarities = np.random.rand(size, size)
13
- similarities = (similarities + similarities.T) / 2 # make symmetric
14
  similarities = np.round(similarities, 2)
15
 
16
- # Use Plotly Express imshow to create a heatmap
17
- fig = px.imshow(similarities,
18
- x=selected_models,
19
- y=selected_models,
20
- color_continuous_scale='Viridis',
21
- zmin=0, zmax=1,
22
- text_auto=True)
23
- # Move x-axis labels to top and adjust tick angle for readability
24
- fig.update_xaxes(side="top", tickangle=45)
25
- # Update overall layout: title, dimensions, margins
 
 
26
  fig.update_layout(
27
  title=f"Similarity Matrix for {selected_dataset}",
28
- width=800 + 20 * size,
29
- height=800 + 20 * size,
30
- margin=dict(b=100, l=100)
 
 
31
  )
32
  return fig
33
 
@@ -36,9 +41,7 @@ def validate_inputs(selected_models, selected_dataset):
36
  raise gr.Error("Please select at least one model!")
37
  if not selected_dataset:
38
  raise gr.Error("Please select a dataset!")
39
-
40
 
41
- # Gradio interface setup
42
  with gr.Blocks(title="LLM Similarity Analyzer") as demo:
43
  gr.Markdown("## Model Similarity Comparison Tool")
44
 
@@ -50,7 +53,6 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
50
  interactive=True,
51
  info="Leaderboard benchmark datasets"
52
  )
53
-
54
  model_dropdown = gr.Dropdown(
55
  choices=get_leaderboard_models_cached(),
56
  label="Select Models",
@@ -61,9 +63,10 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
61
  )
62
 
63
  generate_btn = gr.Button("Generate Heatmap", variant="primary")
 
64
  heatmap = gr.Plot(label="Similarity Heatmap", visible=True)
65
 
66
- # Use a single output (the figure)
67
  generate_btn.click(
68
  fn=validate_inputs,
69
  inputs=[model_dropdown, dataset_dropdown],
@@ -74,7 +77,7 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
74
  outputs=heatmap
75
  )
76
 
77
- # Clear button: clear selections and the plot
78
  clear_btn = gr.Button("Clear Selection")
79
  clear_btn.click(
80
  lambda: [None, None, None],
 
1
  import gradio as gr
2
+ import plotly.graph_objects as go
3
  import numpy as np
4
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
5
 
6
  def create_heatmap(selected_models, selected_dataset):
7
+ # Return nothing if no inputs are provided
8
  if not selected_models or not selected_dataset:
9
+ return None
10
 
11
+ # Generate a random symmetric similarity matrix
12
  size = len(selected_models)
13
  similarities = np.random.rand(size, size)
14
+ similarities = (similarities + similarities.T) / 2
15
  similarities = np.round(similarities, 2)
16
 
17
+ # Create a heatmap trace using go.Heatmap; we set x and y to the model names.
18
+ fig = go.Figure(data=go.Heatmap(
19
+ z=similarities,
20
+ x=selected_models,
21
+ y=selected_models,
22
+ colorscale="Viridis",
23
+ zmin=0, zmax=1,
24
+ text=similarities,
25
+ hoverinfo="text"
26
+ ))
27
+
28
+ # Update layout: add title, axis titles, set fixed dimensions and margins
29
  fig.update_layout(
30
  title=f"Similarity Matrix for {selected_dataset}",
31
+ xaxis_title="Models",
32
+ yaxis_title="Models",
33
+ width=800,
34
+ height=800,
35
+ margin=dict(l=100, r=100, t=100, b=100)
36
  )
37
  return fig
38
 
 
41
  raise gr.Error("Please select at least one model!")
42
  if not selected_dataset:
43
  raise gr.Error("Please select a dataset!")
 
44
 
 
45
  with gr.Blocks(title="LLM Similarity Analyzer") as demo:
46
  gr.Markdown("## Model Similarity Comparison Tool")
47
 
 
53
  interactive=True,
54
  info="Leaderboard benchmark datasets"
55
  )
 
56
  model_dropdown = gr.Dropdown(
57
  choices=get_leaderboard_models_cached(),
58
  label="Select Models",
 
63
  )
64
 
65
  generate_btn = gr.Button("Generate Heatmap", variant="primary")
66
+ # Initialize the Plot component without a figure (it will be updated)
67
  heatmap = gr.Plot(label="Similarity Heatmap", visible=True)
68
 
69
+ # First validate inputs, then create the heatmap; note that we use a single output.
70
  generate_btn.click(
71
  fn=validate_inputs,
72
  inputs=[model_dropdown, dataset_dropdown],
 
77
  outputs=heatmap
78
  )
79
 
80
+ # Clear button to reset selections and clear the plot
81
  clear_btn = gr.Button("Clear Selection")
82
  clear_btn.click(
83
  lambda: [None, None, None],
src/heatmap.html ADDED
The diff for this file is too large to render. See raw diff
 
src/test.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import plotly.graph_objects as go
2
+ import numpy as np
3
+
4
+ models = ["model1", "model2", "model3"]
5
+ size = len(models)
6
+ sim = np.random.rand(size, size)
7
+ sim = (sim + sim.T) / 2
8
+ sim = np.round(sim, 2)
9
+ fig = go.Figure(data=go.Heatmap(z=sim, x=models, y=models, colorscale="Viridis"))
10
+ fig.update_layout(title="Test Heatmap", xaxis_title="Models", yaxis_title="Models", width=800, height=800)
11
+ fig.show()
12
+
13
+ # Save fig
14
+ fig.write_html("heatmap.html")