Joschka Strueber committed on
Commit
465a95b
·
1 Parent(s): 874e761

[Add] heatmap plot with seaborn instead of plotly

Browse files
Files changed (3) hide show
  1. app.py +64 -72
  2. app_heatmap.py +0 -103
  3. app_simple.py +106 -0
app.py CHANGED
@@ -1,106 +1,98 @@
1
  import gradio as gr
2
- import matplotlib.pyplot as plt
3
  import numpy as np
 
 
4
  from io import BytesIO
5
  from PIL import Image
6
-
7
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
8
- from src.similarity import compute_similarity
9
 
10
- # Set the backend to 'Agg' for non-GUI environments (optional)
11
- import matplotlib
12
- matplotlib.use('Agg')
13
 
14
-
15
- def generate_plot():
16
- # Generate data
17
- x = np.linspace(0, 10, 100)
18
- y = np.sin(x)
19
 
20
- # Create figure
21
- fig, ax = plt.subplots()
22
- ax.plot(x, y)
23
- ax.set_title("Sine Wave")
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # Save figure to a BytesIO buffer
 
 
 
 
 
 
 
 
26
  buf = BytesIO()
27
- fig.savefig(buf, format="png", bbox_inches="tight", facecolor="white", dpi=100)
28
- plt.close(fig) # Close the figure to free memory
29
 
30
- # Convert buffer to PIL Image
31
  buf.seek(0)
32
  img = Image.open(buf).convert("RGB")
33
  return img
34
 
35
-
36
- def validate_inputs(selected_model_a, selected_model_b, selected_dataset):
37
- if not selected_model_a:
38
- raise gr.Error("Please select Model A!")
39
- if not selected_model_b:
40
- raise gr.Error("Please select Model B!")
41
  if not selected_dataset:
42
  raise gr.Error("Please select a dataset!")
43
 
44
- def display_similarity(model_a, model_b, dataset):
45
- # Assuming compute_similarity returns a float or a string
46
- similarity_score = compute_similarity(model_a, model_b, dataset)
47
- return f"The similarity between {model_a} and {model_b} on {dataset} is: {similarity_score}"
48
-
49
  with gr.Blocks(title="LLM Similarity Analyzer") as demo:
50
  gr.Markdown("## Model Similarity Comparison Tool")
51
 
52
- dataset_dropdown = gr.Dropdown(
53
- choices=get_leaderboard_datasets(),
54
- label="Select Dataset",
55
- filterable=True,
56
- interactive=True,
57
- info="Leaderboard benchmark datasets"
58
- )
59
-
60
- model_a_dropdown = gr.Dropdown(
61
- choices=get_leaderboard_models_cached(),
62
- label="Select Model A",
63
- filterable=True,
64
- allow_custom_value=False,
65
- info="Search and select models"
66
- )
67
-
68
- model_b_dropdown = gr.Dropdown(
69
- choices=get_leaderboard_models_cached(),
70
- label="Select Model B",
71
- filterable=True,
72
- allow_custom_value=False,
73
- info="Search and select models"
74
- )
75
-
76
- generate_btn = gr.Button("Compute Similarity", variant="primary")
77
 
78
- # Textbox to display the similarity result
79
- similarity_output = gr.Textbox(
80
- label="Similarity Result",
81
- interactive=False
82
- )
83
 
84
  generate_btn.click(
85
  fn=validate_inputs,
86
- inputs=[model_a_dropdown, model_b_dropdown, dataset_dropdown],
87
  queue=False
88
  ).then(
89
- fn=display_similarity,
90
- inputs=[model_a_dropdown, model_b_dropdown, dataset_dropdown],
91
- outputs=similarity_output
92
  )
93
 
94
  clear_btn = gr.Button("Clear Selection")
95
  clear_btn.click(
96
- lambda: [None, None, None, ""],
97
- outputs=[model_a_dropdown, model_b_dropdown, dataset_dropdown, similarity_output]
98
  )
99
 
100
- gr.Markdown("## Matplotlib Plot in Gradio")
101
- plot_button = gr.Button("Generate Plot")
102
- plot_output = gr.Image(label="Sine Wave Plot")
103
- plot_button.click(fn=generate_plot, outputs=plot_output)
104
-
105
  if __name__ == "__main__":
106
- demo.launch()
 
1
  import gradio as gr
 
2
  import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
  from io import BytesIO
6
  from PIL import Image
 
7
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
 
8
 
9
+ # Set matplotlib backend for non-GUI environments
10
+ plt.switch_backend('Agg')
 
11
 
12
def create_heatmap(selected_models, selected_dataset):
    """Render a model-by-model similarity heatmap as a PIL image.

    Args:
        selected_models: Sequence of model names; it sizes the square matrix
            and labels the ticks on both axes.
        selected_dataset: Dataset name, shown in the plot title.

    Returns:
        An RGB ``PIL.Image.Image`` containing the heatmap, or ``None`` when
        either argument is empty/falsy.
    """
    if not selected_models or not selected_dataset:
        return None

    # NOTE(review): the matrix is filled with random placeholder values that
    # are symmetrised and rounded to two decimals; no real similarity metric
    # is computed in this function.
    size = len(selected_models)
    similarities = np.random.rand(size, size)
    similarities = np.round((similarities + similarities.T) / 2, 2)

    # Draw on an explicit Figure/Axes rather than pyplot's implicit "current
    # figure", so the output never depends on global pyplot state.
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(
            similarities,
            annot=True,
            fmt=".2f",
            cmap="viridis",
            vmin=0,
            vmax=1,
            xticklabels=selected_models,
            yticklabels=selected_models,
            ax=ax,
        )

        ax.set_title(f"Similarity Matrix for {selected_dataset}", fontsize=14)
        ax.set_xlabel("Models")
        ax.set_ylabel("Models")
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        plt.setp(ax.get_yticklabels(), rotation=0)
        fig.tight_layout()

        # Serialise the figure to an in-memory PNG.
        buf = BytesIO()
        fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf).convert("RGB")
51
 
52
def validate_inputs(selected_models, selected_dataset):
    """Raise a ``gr.Error`` if no model or no dataset has been selected.

    The model selection is checked first, then the dataset. Returns ``None``
    when both values are truthy.
    """
    required = (
        (selected_models, "Please select at least one model!"),
        (selected_dataset, "Please select a dataset!"),
    )
    for value, message in required:
        if not value:
            raise gr.Error(message)
57
 
 
 
 
 
 
58
  with gr.Blocks(title="LLM Similarity Analyzer") as demo:
59
  gr.Markdown("## Model Similarity Comparison Tool")
60
 
61
+ with gr.Row():
62
+ dataset_dropdown = gr.Dropdown(
63
+ choices=get_leaderboard_datasets(),
64
+ label="Select Dataset",
65
+ filterable=True,
66
+ interactive=True,
67
+ info="Leaderboard benchmark datasets"
68
+ )
69
+ model_dropdown = gr.Dropdown(
70
+ choices=get_leaderboard_models_cached(),
71
+ label="Select Models",
72
+ multiselect=True,
73
+ filterable=True,
74
+ allow_custom_value=False,
75
+ info="Search and select multiple models"
76
+ )
 
 
 
 
 
 
 
 
 
77
 
78
+ generate_btn = gr.Button("Generate Heatmap", variant="primary")
79
+ heatmap = gr.Image(label="Similarity Heatmap", visible=True)
 
 
 
80
 
81
  generate_btn.click(
82
  fn=validate_inputs,
83
+ inputs=[model_dropdown, dataset_dropdown],
84
  queue=False
85
  ).then(
86
+ fn=create_heatmap,
87
+ inputs=[model_dropdown, dataset_dropdown],
88
+ outputs=heatmap
89
  )
90
 
91
  clear_btn = gr.Button("Clear Selection")
92
  clear_btn.click(
93
+ lambda: [None, None, None],
94
+ outputs=[model_dropdown, dataset_dropdown, heatmap]
95
  )
96
 
 
 
 
 
 
97
  if __name__ == "__main__":
98
+ demo.launch(ssr_mode=False)
app_heatmap.py DELETED
@@ -1,103 +0,0 @@
1
- import gradio as gr
2
- import plotly.graph_objects as go
3
- import numpy as np
4
- from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
5
-
6
- # Optionally, force a renderer (may or may not help)
7
- import plotly.io as pio
8
- pio.renderers.default = "iframe"
9
-
10
- def create_heatmap(selected_models, selected_dataset):
11
- if not selected_models or not selected_dataset:
12
- return "" # Return empty HTML if no input
13
- size = len(selected_models)
14
- similarities = np.random.rand(size, size)
15
- similarities = (similarities + similarities.T) / 2
16
- similarities = np.round(similarities, 2)
17
-
18
- fig = go.Figure(data=go.Heatmap(
19
- z=similarities,
20
- x=selected_models,
21
- y=selected_models,
22
- colorscale="Viridis",
23
- zmin=0, zmax=1,
24
- text=similarities,
25
- hoverinfo="text"
26
- ))
27
-
28
- fig.update_layout(
29
- title=f"Similarity Matrix for {selected_dataset}",
30
- xaxis_title="Models",
31
- yaxis_title="Models",
32
- width=800,
33
- height=800,
34
- margin=dict(l=100, r=100, t=100, b=100)
35
- )
36
-
37
- # Force categorical ordering with explicit tick settings.
38
- fig.update_xaxes(
39
- type="category",
40
- categoryorder="array",
41
- categoryarray=selected_models,
42
- tickangle=45,
43
- automargin=True
44
- )
45
- fig.update_yaxes(
46
- type="category",
47
- categoryorder="array",
48
- categoryarray=selected_models,
49
- automargin=True
50
- )
51
-
52
- # Convert the figure to an HTML string that includes Plotly.js via CDN.
53
- return fig.to_html(full_html=False, include_plotlyjs="cdn")
54
-
55
- def validate_inputs(selected_models, selected_dataset):
56
- if not selected_models:
57
- raise gr.Error("Please select at least one model!")
58
- if not selected_dataset:
59
- raise gr.Error("Please select a dataset!")
60
-
61
- with gr.Blocks(title="LLM Similarity Analyzer") as demo:
62
- gr.Markdown("## Model Similarity Comparison Tool")
63
-
64
- with gr.Row():
65
- dataset_dropdown = gr.Dropdown(
66
- choices=get_leaderboard_datasets(),
67
- label="Select Dataset",
68
- filterable=True,
69
- interactive=True,
70
- info="Leaderboard benchmark datasets"
71
- )
72
- model_dropdown = gr.Dropdown(
73
- choices=get_leaderboard_models_cached(),
74
- label="Select Models",
75
- multiselect=True,
76
- filterable=True,
77
- allow_custom_value=False,
78
- info="Search and select multiple models"
79
- )
80
-
81
- generate_btn = gr.Button("Generate Heatmap", variant="primary")
82
- # Use an HTML component instead of gr.Plot.
83
- heatmap = gr.HTML(label="Similarity Heatmap", visible=True)
84
-
85
- generate_btn.click(
86
- fn=validate_inputs,
87
- inputs=[model_dropdown, dataset_dropdown],
88
- queue=False
89
- ).then(
90
- fn=create_heatmap,
91
- inputs=[model_dropdown, dataset_dropdown],
92
- outputs=heatmap
93
- )
94
-
95
- clear_btn = gr.Button("Clear Selection")
96
- clear_btn.click(
97
- lambda: [None, None, ""],
98
- outputs=[model_dropdown, dataset_dropdown, heatmap]
99
- )
100
-
101
- if __name__ == "__main__":
102
- # On Spaces, disable server-side rendering.
103
- demo.launch(ssr_mode=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_simple.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from io import BytesIO
5
+ from PIL import Image
6
+
7
+ from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
8
+ from src.similarity import compute_similarity
9
+
10
+ # Set the backend to 'Agg' for non-GUI environments (optional)
11
+ import matplotlib
12
+ matplotlib.use('Agg')
13
+
14
+
15
def generate_plot():
    """Plot a sine wave and return it as an RGB PIL image.

    Returns:
        A ``PIL.Image.Image`` in RGB mode containing the rendered PNG of
        ``sin(x)`` sampled at 100 points on ``[0, 10]``.
    """
    # Generate data
    x = np.linspace(0, 10, 100)
    y = np.sin(x)

    fig, ax = plt.subplots()
    try:
        ax.plot(x, y)
        ax.set_title("Sine Wave")

        # Save figure to an in-memory buffer
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", facecolor="white", dpi=100)
    finally:
        # Close the figure even when plotting/saving fails, to free memory
        plt.close(fig)

    # Convert buffer to PIL Image
    buf.seek(0)
    return Image.open(buf).convert("RGB")
34
+
35
+
36
def validate_inputs(selected_model_a, selected_model_b, selected_dataset):
    """Raise a ``gr.Error`` for the first missing selection.

    Checks Model A, then Model B, then the dataset. Returns ``None`` when all
    three values are truthy.
    """
    required = (
        (selected_model_a, "Please select Model A!"),
        (selected_model_b, "Please select Model B!"),
        (selected_dataset, "Please select a dataset!"),
    )
    for value, message in required:
        if not value:
            raise gr.Error(message)
43
+
44
def display_similarity(model_a, model_b, dataset):
    """Return a sentence reporting the similarity of two models on a dataset.

    The score is whatever ``compute_similarity(model_a, model_b, dataset)``
    returns; it is interpolated into the message unchanged.
    """
    similarity_score = compute_similarity(model_a, model_b, dataset)
    message = f"The similarity between {model_a} and {model_b} on {dataset} is: {similarity_score}"
    return message
48
+
49
# Gradio UI: compare two models on one dataset, plus a demo sine-wave plot.
with gr.Blocks(title="LLM Similarity Analyzer") as demo:
    gr.Markdown("## Model Similarity Comparison Tool")

    dataset_dropdown = gr.Dropdown(
        choices=get_leaderboard_datasets(),
        label="Select Dataset",
        filterable=True,
        interactive=True,
        info="Leaderboard benchmark datasets"
    )

    model_a_dropdown = gr.Dropdown(
        choices=get_leaderboard_models_cached(),
        label="Select Model A",
        filterable=True,
        allow_custom_value=False,
        info="Search and select models"
    )

    model_b_dropdown = gr.Dropdown(
        choices=get_leaderboard_models_cached(),
        label="Select Model B",
        filterable=True,
        allow_custom_value=False,
        info="Search and select models"
    )

    generate_btn = gr.Button("Compute Similarity", variant="primary")

    # Textbox to display the similarity result
    similarity_output = gr.Textbox(
        label="Similarity Result",
        interactive=False
    )

    # Validate first; `.success()` runs the similarity step only when
    # validation did not raise. With `.then()` it would still run with unset
    # selections after a validation error.
    generate_btn.click(
        fn=validate_inputs,
        inputs=[model_a_dropdown, model_b_dropdown, dataset_dropdown],
        queue=False
    ).success(
        fn=display_similarity,
        inputs=[model_a_dropdown, model_b_dropdown, dataset_dropdown],
        outputs=similarity_output
    )

    # Reset the three dropdowns and empty the result textbox.
    clear_btn = gr.Button("Clear Selection")
    clear_btn.click(
        lambda: [None, None, None, ""],
        outputs=[model_a_dropdown, model_b_dropdown, dataset_dropdown, similarity_output]
    )

    gr.Markdown("## Matplotlib Plot in Gradio")
    plot_button = gr.Button("Generate Plot")
    plot_output = gr.Image(label="Sine Wave Plot")
    plot_button.click(fn=generate_plot, outputs=plot_output)

if __name__ == "__main__":
    demo.launch()