Joschka Strueber committed
Commit 53d5dd8 · 1 Parent(s): e1a6930

[Add] clear button, load the right data, create plot on click

Files changed (2):
  1. app.py +43 -23
  2. src/dataloading.py +8 -12
app.py CHANGED

@@ -6,35 +6,41 @@ import numpy as np
 from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
 
 
-def create_heatmap(selected_models, benchmark):
+def create_heatmap(selected_models, selected_dataset):
     if not selected_models:
-        return gr.update(visible=False)
+        return gr.Plot(visible=False)
 
-    # Generate random similarity matrix
+    # Generate random similarity matrix (replace with actual computation)
     size = len(selected_models)
     similarities = np.random.rand(size, size)
 
-    # Create symmetric matrix (for demo purposes)
+    # Create symmetric matrix
     similarities = (similarities + similarities.T) / 2
 
-    # Create heatmap with Plotly
+    # Create plot
     fig = go.Figure(data=go.Heatmap(
         z=similarities,
         x=selected_models,
         y=selected_models,
-        colorscale='Viridis',
-        hoverongaps=False
+        colorscale='Viridis'
     ))
 
     fig.update_layout(
-        title=f"Model Similarity for {benchmark}",
-        xaxis_title="Models",
-        yaxis_title="Models",
-        height=600,
-        width=800
+        title=f"Similarity Matrix for {selected_dataset}",
+        width=800,
+        height=800
     )
 
-    return fig
+    with gr.Loading():
+        return gr.Plot(value=fig, visible=True)
+
+
+def validate_inputs(selected_models, selected_dataset):
+    if not selected_models:
+        raise gr.Error("Please select at least one model!")
+    if not selected_dataset:
+        raise gr.Error("Please select a dataset!")
+
 
 with gr.Blocks(title="LLM Similarity Analyzer") as demo:
     gr.Markdown("## Model Similarity Comparison Tool")
@@ -42,12 +48,12 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
     # Model selection section
     with gr.Row():
         dataset_dropdown = gr.Dropdown(
-            choices=get_leaderboard_datasets(),
-            label="Select Dataset",
-            filterable=True,
-            interactive=True,
-            info="Leaderboard benchmark datasets"
-        )
+            choices=get_leaderboard_datasets(),
+            label="Select Dataset",
+            filterable=True,
+            interactive=True,
+            info="Leaderboard benchmark datasets"
+        )
 
         model_dropdown = gr.Dropdown(
             choices=get_leaderboard_models_cached(),
@@ -58,20 +64,34 @@ with gr.Blocks(title="LLM Similarity Analyzer") as demo:
             info="Search and select multiple models (click selected models to remove)"
         )
 
+    # Add generate button
+    generate_btn = gr.Button("Generate Heatmap", variant="primary")
+
     # Heatmap display
     heatmap = gr.Plot(
         label="Similarity Heatmap",
         visible=False,
         container=False
     )
-
-    # Interactive updates
-    model_dropdown.input(
+
+    # Button click handler
+    generate_btn.click(
+        fn=validate_inputs,
+        inputs=[model_dropdown, dataset_dropdown],
+        queue=False
+    ).then(
         fn=create_heatmap,
-        inputs=(model_dropdown, dataset_dropdown),
+        inputs=[model_dropdown, dataset_dropdown],
         outputs=heatmap
     )
 
+    clear_btn = gr.Button("Clear Selection")
+    clear_btn.click(
+        lambda: [None, None, gr.Plot(visible=False)],
+        outputs=[model_dropdown, dataset_dropdown, heatmap]
+    )
+
+
 
 if __name__ == "__main__":
     demo.launch()
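For reference, the new event wiring chains two steps: generate_btn.click(...) runs validate_inputs first (with queue=False, so the error toast surfaces immediately), and .then(...) runs create_heatmap only if validation did not raise a gr.Error. A minimal standalone sketch of that pattern, assuming Gradio 4.x; the component choices and function bodies below are placeholders, not the committed code:

    import gradio as gr

    def validate(models, dataset):
        # Raising gr.Error aborts the chain, so the .then() step never runs
        if not models:
            raise gr.Error("Please select at least one model!")
        if not dataset:
            raise gr.Error("Please select a dataset!")

    def render(models, dataset):
        # Placeholder: the real app returns the Plotly heatmap figure here
        return gr.Plot(visible=True)

    with gr.Blocks() as demo:
        models = gr.Dropdown(["model-a", "model-b"], multiselect=True, label="Models")
        dataset = gr.Dropdown(["dataset-1", "dataset-2"], label="Dataset")
        btn = gr.Button("Generate")
        plot = gr.Plot(visible=False)
        # Step 1 validates; step 2 renders only if step 1 succeeded
        btn.click(validate, inputs=[models, dataset], queue=False).then(
            render, inputs=[models, dataset], outputs=plot
        )

    if __name__ == "__main__":
        demo.launch()

Note that create_heatmap still fills the matrix with random values; the "replace with actual computation" comment marks where a real similarity metric would go.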
src/dataloading.py CHANGED

@@ -5,22 +5,18 @@ from functools import lru_cache
 def get_leaderboard_models():
     api = HfApi()
 
-    # List all files in the repository
-    files = api.list_repo_files(
-        repo_id="open-llm-leaderboard/open_llm_leaderboard",
-        repo_type="space"
-    )
+    # List all datasets in the open-llm-leaderboard organization
+    datasets = api.list_datasets(author="open-llm-leaderboard")
 
     models = []
-    for file in files:
-        # Filter files in the "open_llm_leaderboard" directory
-        if file.startswith("open_llm_leaderboard/") and "-details" in file and "__" in file:
-            # Extract provider and model name from filename
-            filename = file.split("/")[-1].replace("-details", "")
-            provider, model = filename.split("__", 1)
+    for dataset in datasets:
+        if dataset.id.endswith("-details"):
+            # Format: "open-llm-leaderboard/<provider>__<model_name>-details"
+            model_part = dataset.id.split("/")[-1].replace("-details", "")
+            provider, model = model_part.split("__", 1)
             models.append(f"{provider}/{model}")
 
-    return sorted(list(set(models)))  # Remove duplicates
+    return sorted(models)
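app.py also imports get_leaderboard_models_cached, which this commit does not touch. Given the "from functools import lru_cache" visible in the hunk context, it is presumably a thin cached wrapper along these lines (an illustrative sketch, not the committed code):

    from functools import lru_cache

    @lru_cache(maxsize=1)
    def get_leaderboard_models_cached():
        # Cache the Hub listing so repeated UI loads don't re-query the API
        return get_leaderboard_models()

With the new listing, each dataset.id has the form "open-llm-leaderboard/<provider>__<model_name>-details", so stripping the "-details" suffix and splitting on "__" recovers the usual provider/model identifier.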