Joschka Strueber commited on
Commit
bd28414
·
1 Parent(s): 1b549fb

[Ref, Add] custom css for sizing, move demo utility to its own file

Browse files
Files changed (2) hide show
  1. app.py +7 -93
  2. src/app_util.py +99 -0
app.py CHANGED
@@ -1,94 +1,8 @@
1
- import os
2
  import gradio as gr
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- import seaborn as sns
6
- from io import BytesIO
7
- from PIL import Image
8
- from datasets.exceptions import DatasetNotFoundError
9
 
 
10
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
11
- from src.similarity import load_data_and_compute_similarities
12
 
13
- # Set matplotlib backend for non-GUI environments
14
- plt.switch_backend('Agg')
15
-
16
-
17
- def create_heatmap(selected_models, selected_dataset, selected_metric):
18
- if not selected_models or not selected_dataset:
19
- return None
20
-
21
- # Sort models and get short names
22
- similarities = load_data_and_compute_similarities(selected_models, selected_dataset, selected_metric)
23
-
24
- # Check if similarity matrix contains NaN rows
25
- failed_models = []
26
- for i in range(len(similarities)):
27
- if np.isnan(similarities[i]).all():
28
- failed_models.append(selected_models[i])
29
-
30
- if failed_models:
31
- gr.Warning(f"Failed to load data for models: {', '.join(failed_models)}")
32
-
33
- # Create figure and heatmap using seaborn
34
- plt.figure(figsize=(8, 6))
35
- ax = sns.heatmap(
36
- similarities,
37
- annot=True,
38
- fmt=".2f",
39
- cmap="viridis",
40
- vmin=0,
41
- vmax=1,
42
- xticklabels=selected_models,
43
- yticklabels=selected_models
44
- )
45
-
46
- # Customize plot
47
- plt.title(f"{selected_metric} for {selected_dataset}", fontsize=16)
48
- plt.xlabel("Models", fontsize=14)
49
- plt.ylabel("Models", fontsize=14)
50
- plt.xticks(rotation=45, ha='right')
51
- plt.yticks(rotation=0)
52
- plt.tight_layout()
53
-
54
- # Save to buffer
55
- buf = BytesIO()
56
- plt.savefig(buf, format="png", dpi=100, bbox_inches="tight")
57
- plt.close()
58
-
59
- # Convert to PIL Image
60
- buf.seek(0)
61
- img = Image.open(buf).convert("RGB")
62
- return img
63
-
64
- def validate_inputs(selected_models, selected_dataset):
65
- if not selected_models:
66
- raise gr.Error("Please select at least one model!")
67
- if not selected_dataset:
68
- raise gr.Error("Please select a dataset!")
69
-
70
-
71
- def update_datasets_based_on_models(selected_models, current_dataset):
72
- try:
73
- available_datasets = get_leaderboard_datasets(selected_models) if selected_models else []
74
- if current_dataset in available_datasets:
75
- valid_dataset = current_dataset
76
- elif "mmlu_pro" in available_datasets:
77
- valid_dataset = "mmlu_pro"
78
- else:
79
- valid_dataset = None
80
- return gr.update(
81
- choices=available_datasets,
82
- value=valid_dataset
83
- )
84
- except DatasetNotFoundError as e:
85
- # Extract model name from error message
86
- model_name = e.args[0].split("'")[1]
87
- model_name = model_name.split("/")[-1].replace("__", "/").replace("_details", "")
88
-
89
- # Display a shorter warning
90
- gr.Warning(f"Data for '{model_name}' is gated or unavailable.")
91
- return gr.update(choices=[], value=None)
92
 
93
  links_markdown = """
94
  [📄 Paper](https://arxiv.org/abs/2502.04313)   |  
@@ -104,7 +18,7 @@ metric_init = "CAPA"
104
 
105
 
106
  # Create Gradio interface
107
- with gr.Blocks(title="LLM Similarity Analyzer") as demo:
108
  gr.Markdown("# Model Similarity Comparison Tool")
109
  gr.Markdown(links_markdown)
110
  gr.Markdown('Demo for the recent publication ["Great Models Think Alike and this Undermines AI Oversight"](https://huggingface.co/papers/2502.04313).')
@@ -137,20 +51,20 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
137
  )
138
 
139
  model_dropdown.change(
140
- fn=update_datasets_based_on_models,
141
  inputs=[model_dropdown, dataset_dropdown],
142
  outputs=dataset_dropdown
143
  )
144
 
145
  generate_btn = gr.Button("Generate Heatmap", variant="primary")
146
- heatmap = gr.Image(value=create_heatmap(model_init, dataset_init, metric_init), label="Similarity Heatmap", visible=True)
147
 
148
  generate_btn.click(
149
- fn=validate_inputs,
150
  inputs=[model_dropdown, dataset_dropdown],
151
  queue=False
152
  ).then(
153
- fn=create_heatmap,
154
  inputs=[model_dropdown, dataset_dropdown, metric_dropdown],
155
  outputs=heatmap
156
  )
@@ -170,7 +84,7 @@ biased towards more similar models controlling for the model's capability. (2) G
170
  of weak supervisors (weak-to-strong generalization) is higher when the two models are more different. (3) Concerningly, model \
171
  errors are getting more correlated as capabilities increase.""")
172
  with gr.Row():
173
- gr.Image(value="data/table_capa.png", label="Comparison of different similarity metrics for multiple-choice questions", interactive=False, scale=1)
174
  gr.Markdown("""
175
  - **Datasets**: [Open LLM Leaderboard v2](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard#/) benchmark datasets \n
176
  - Some datasets are not multiple-choice - for these, the metrics are not applicable. \n
 
 
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
+ import src.app_util as app_util
4
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  links_markdown = """
8
  [📄 Paper](https://arxiv.org/abs/2502.04313)   |  
 
18
 
19
 
20
  # Create Gradio interface
21
+ with gr.Blocks(title="LLM Similarity Analyzer", css=app_util.custom_css) as demo:
22
  gr.Markdown("# Model Similarity Comparison Tool")
23
  gr.Markdown(links_markdown)
24
  gr.Markdown('Demo for the recent publication ["Great Models Think Alike and this Undermines AI Oversight"](https://huggingface.co/papers/2502.04313).')
 
51
  )
52
 
53
  model_dropdown.change(
54
+ fn=app_util.update_datasets_based_on_models,
55
  inputs=[model_dropdown, dataset_dropdown],
56
  outputs=dataset_dropdown
57
  )
58
 
59
  generate_btn = gr.Button("Generate Heatmap", variant="primary")
60
+ heatmap = gr.Image(value=app_util.create_heatmap(model_init, dataset_init, metric_init), label="Similarity Heatmap", visible=True)
61
 
62
  generate_btn.click(
63
+ fn=app_util.validate_inputs,
64
  inputs=[model_dropdown, dataset_dropdown],
65
  queue=False
66
  ).then(
67
+ fn=app_util.create_heatmap,
68
  inputs=[model_dropdown, dataset_dropdown, metric_dropdown],
69
  outputs=heatmap
70
  )
 
84
  of weak supervisors (weak-to-strong generalization) is higher when the two models are more different. (3) Concerningly, model \
85
  errors are getting more correlated as capabilities increase.""")
86
  with gr.Row():
87
+ gr.Image(value="data/table_capa.png", label="Comparison of different similarity metrics for multiple-choice questions", elem_classes="image_container", interactive=False)
88
  gr.Markdown("""
89
  - **Datasets**: [Open LLM Leaderboard v2](https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard#/) benchmark datasets \n
90
  - Some datasets are not multiple-choice - for these, the metrics are not applicable. \n
src/app_util.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ from io import BytesIO
6
+ from PIL import Image
7
+ from datasets.exceptions import DatasetNotFoundError
8
+
9
+ from src.dataloading import get_leaderboard_datasets
10
+ from src.similarity import load_data_and_compute_similarities
11
+
12
+ # Set matplotlib backend for non-GUI environments
13
+ plt.switch_backend('Agg')
14
+
15
+ def create_heatmap(selected_models, selected_dataset, selected_metric):
16
+ if not selected_models or not selected_dataset:
17
+ return None
18
+
19
+ # Sort models and get short names
20
+ similarities = load_data_and_compute_similarities(selected_models, selected_dataset, selected_metric)
21
+
22
+ # Check if similarity matrix contains NaN rows
23
+ failed_models = []
24
+ for i in range(len(similarities)):
25
+ if np.isnan(similarities[i]).all():
26
+ failed_models.append(selected_models[i])
27
+
28
+ if failed_models:
29
+ gr.Warning(f"Failed to load data for models: {', '.join(failed_models)}")
30
+
31
+ # Create figure and heatmap using seaborn
32
+ plt.figure(figsize=(8, 6))
33
+ ax = sns.heatmap(
34
+ similarities,
35
+ annot=True,
36
+ fmt=".2f",
37
+ cmap="viridis",
38
+ vmin=0,
39
+ vmax=1,
40
+ xticklabels=selected_models,
41
+ yticklabels=selected_models
42
+ )
43
+
44
+ # Customize plot
45
+ plt.title(f"{selected_metric} for {selected_dataset}", fontsize=16)
46
+ plt.xlabel("Models", fontsize=14)
47
+ plt.ylabel("Models", fontsize=14)
48
+ plt.xticks(rotation=45, ha='right')
49
+ plt.yticks(rotation=0)
50
+ plt.tight_layout()
51
+
52
+ # Save to buffer
53
+ buf = BytesIO()
54
+ plt.savefig(buf, format="png", dpi=100, bbox_inches="tight")
55
+ plt.close()
56
+
57
+ # Convert to PIL Image
58
+ buf.seek(0)
59
+ img = Image.open(buf).convert("RGB")
60
+ return img
61
+
62
+ def validate_inputs(selected_models, selected_dataset):
63
+ if not selected_models:
64
+ raise gr.Error("Please select at least one model!")
65
+ if not selected_dataset:
66
+ raise gr.Error("Please select a dataset!")
67
+
68
+
69
+ def update_datasets_based_on_models(selected_models, current_dataset):
70
+ try:
71
+ available_datasets = get_leaderboard_datasets(selected_models) if selected_models else []
72
+ if current_dataset in available_datasets:
73
+ valid_dataset = current_dataset
74
+ elif "mmlu_pro" in available_datasets:
75
+ valid_dataset = "mmlu_pro"
76
+ else:
77
+ valid_dataset = None
78
+ return gr.update(
79
+ choices=available_datasets,
80
+ value=valid_dataset
81
+ )
82
+ except DatasetNotFoundError as e:
83
+ # Extract model name from error message
84
+ model_name = e.args[0].split("'")[1]
85
+ model_name = model_name.split("/")[-1].replace("__", "/").replace("_details", "")
86
+
87
+ # Display a shorter warning
88
+ gr.Warning(f"Data for '{model_name}' is gated or unavailable.")
89
+ return gr.update(choices=[], value=None)
90
+
91
+ custom_css = """
92
+ .image-container img {
93
+ width: 80% !important; /* Make it 80% of the parent container */
94
+ height: auto !important; /* Maintain aspect ratio */
95
+ max-width: 800px; /* Optional: Set a max limit */
96
+ display: block;
97
+ margin: auto; /* Center the image */
98
+ }
99
+ """