Joschka Strueber commited on
Commit
8901fb0
·
1 Parent(s): d8f2ec7

[Fix] heatmap not generated and deselection causes error

Browse files
Files changed (1) hide show
  1. app.py +16 -39
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
5
 
6
  def create_heatmap(selected_models, selected_dataset):
 
7
  if not selected_models or not selected_dataset:
8
  return gr.Plot(visible=False)
9
 
@@ -30,6 +31,7 @@ def create_heatmap(selected_models, selected_dataset):
30
  return gr.Plot.update(value=fig, visible=True)
31
 
32
  def validate_inputs(selected_models, selected_dataset):
 
33
  if not selected_models:
34
  raise gr.Error("Please select at least one model!")
35
  if not selected_dataset:
@@ -39,43 +41,18 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
39
  gr.Markdown("## Model Similarity Comparison Tool")
40
 
41
  with gr.Row():
42
- dataset_dropdown = gr.Dropdown(
43
- choices=get_leaderboard_datasets(),
44
- label="Select Dataset",
45
- filterable=True,
46
- interactive=True,
47
- info="Leaderboard benchmark datasets"
48
- )
 
 
 
 
 
 
49
 
50
- model_dropdown = gr.Dropdown(
51
- choices=get_leaderboard_models_cached(),
52
- label="Select Models",
53
- multiselect=True,
54
- filterable=True,
55
- allow_custom_value=False,
56
- info="Search and select multiple models"
57
- )
58
-
59
- generate_btn = gr.Button("Generate Heatmap", variant="primary")
60
- heatmap = gr.Plot(label="Similarity Heatmap", visible=False)
61
-
62
- # Event handling
63
- generate_btn.click(
64
- fn=validate_inputs,
65
- inputs=[model_dropdown, dataset_dropdown],
66
- queue=False
67
- ).then(
68
- fn=create_heatmap,
69
- inputs=[model_dropdown, dataset_dropdown],
70
- outputs=heatmap
71
- )
72
-
73
- # Clear button should reset to empty lists
74
- clear_btn = gr.Button("Clear Selection")
75
- clear_btn.click(
76
- lambda: [[], [], gr.Plot.update(visible=False)],
77
- outputs=[model_dropdown, dataset_dropdown, heatmap]
78
- )
79
-
80
- if __name__ == "__main__":
81
- demo.launch()
 
4
  from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
5
 
6
  def create_heatmap(selected_models, selected_dataset):
7
+ print(f"Creating heatmap with models: {selected_models} and dataset: {selected_dataset}")
8
  if not selected_models or not selected_dataset:
9
  return gr.Plot(visible=False)
10
 
 
31
  return gr.Plot.update(value=fig, visible=True)
32
 
33
  def validate_inputs(selected_models, selected_dataset):
34
+ print(f"Validating inputs: models={selected_models}, dataset={selected_dataset}")
35
  if not selected_models:
36
  raise gr.Error("Please select at least one model!")
37
  if not selected_dataset:
 
41
  gr.Markdown("## Model Similarity Comparison Tool")
42
 
43
  with gr.Row():
44
+ model_selector = gr.Dropdown(label="Select Models", choices=get_leaderboard_models_cached(), multiselect=True)
45
+ dataset_selector = gr.Dropdown(label="Select Dataset", choices=get_leaderboard_datasets())
46
+ heatmap_output = gr.Plot(visible=False)
47
+
48
+ def on_submit(selected_models, selected_dataset):
49
+ try:
50
+ validate_inputs(selected_models, selected_dataset)
51
+ return create_heatmap(selected_models, selected_dataset)
52
+ except gr.Error as e:
53
+ return gr.Markdown(str(e))
54
+
55
+ submit_button = gr.Button("Generate Heatmap")
56
+ submit_button.click(on_submit, inputs=[model_selector, dataset_selector], outputs=heatmap_output)
57
 
58
+ demo.launch()