Joschka Strueber commited on
Commit
54b2baf
·
1 Parent(s): 1a7f19c

[Fix] grey plot

Browse files
Files changed (1) hide show
  1. app.py +16 -41
app.py CHANGED
@@ -1,61 +1,36 @@
1
  import gradio as gr
2
- import plotly.graph_objects as go
3
  import numpy as np
4
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
5
 
6
  def create_heatmap(selected_models, selected_dataset):
7
  if not selected_models or not selected_dataset:
8
- return None # Hide the plot if no selection
9
 
10
  # Generate random similarity matrix
11
  size = len(selected_models)
12
  similarities = np.random.rand(size, size)
13
- similarities = (similarities + similarities.T) / 2 # Make symmetric
14
- similarities = np.round(similarities, 2) # Round for clarity
15
 
16
- # Create the heatmap figure
17
- fig = go.Figure(data=go.Heatmap(
18
- z=similarities,
19
- x=selected_models,
20
- y=selected_models,
21
- colorscale='Viridis',
22
- zmin=0, zmax=1,
23
- text=similarities,
24
- hoverinfo="text"
25
- ))
26
-
27
- # Update layout for title, size, margins, etc.
28
  fig.update_layout(
29
  title=f"Similarity Matrix for {selected_dataset}",
30
- xaxis_title="Models",
31
- yaxis_title="Models",
32
- width=800 + 20 * len(selected_models),
33
- height=800 + 20 * len(selected_models),
34
  margin=dict(b=100, l=100)
35
  )
36
-
37
- # Force axes to be categorical and explicitly set the order
38
- fig.update_xaxes(
39
- type="category",
40
- tickangle=45,
41
- categoryorder="array",
42
- categoryarray=selected_models, # Explicitly force ordering to match your list
43
- automargin=True,
44
- showgrid=True,
45
- showticklabels=True
46
- )
47
- fig.update_yaxes(
48
- type="category",
49
- categoryorder="array",
50
- categoryarray=selected_models,
51
- automargin=True,
52
- showgrid=True,
53
- showticklabels=True
54
- )
55
-
56
  return fig
57
 
58
-
59
  def validate_inputs(selected_models, selected_dataset):
60
  if not selected_models:
61
  raise gr.Error("Please select at least one model!")
 
1
  import gradio as gr
2
+ import plotly.express as px
3
  import numpy as np
4
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
5
 
6
  def create_heatmap(selected_models, selected_dataset):
7
  if not selected_models or not selected_dataset:
8
+ return None # Hide plot if inputs are missing
9
 
10
  # Generate random similarity matrix
11
  size = len(selected_models)
12
  similarities = np.random.rand(size, size)
13
+ similarities = (similarities + similarities.T) / 2 # make symmetric
14
+ similarities = np.round(similarities, 2)
15
 
16
+ # Use Plotly Express imshow to create a heatmap
17
+ fig = px.imshow(similarities,
18
+ x=selected_models,
19
+ y=selected_models,
20
+ color_continuous_scale='Viridis',
21
+ zmin=0, zmax=1,
22
+ text_auto=True)
23
+ # Move x-axis labels to top and adjust tick angle for readability
24
+ fig.update_xaxes(side="top", tickangle=45)
25
+ # Update overall layout: title, dimensions, margins
 
 
26
  fig.update_layout(
27
  title=f"Similarity Matrix for {selected_dataset}",
28
+ width=800 + 20 * size,
29
+ height=800 + 20 * size,
 
 
30
  margin=dict(b=100, l=100)
31
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  return fig
33
 
 
34
  def validate_inputs(selected_models, selected_dataset):
35
  if not selected_models:
36
  raise gr.Error("Please select at least one model!")