Joschka Strueber committed
Commit e1a6930 · 1 Parent(s): 228927e

[Add] create heatmaps for multiselection

Files changed (2):
  1. app.py +60 -51
  2. src/dataloading.py +5 -4
app.py CHANGED

@@ -1,68 +1,77 @@
 import gradio as gr
-
+import plotly.graph_objects as go
+import numpy as np
 
 
 from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
 
 
-def create_demo():
-    # Fetch data once on startup (cache this in production)
-    models = get_leaderboard_models_cached()
-    datasets = get_leaderboard_datasets()
+def create_heatmap(selected_models, benchmark):
+    if not selected_models:
+        return gr.update(visible=False)
+
+    # Generate random similarity matrix
+    size = len(selected_models)
+    similarities = np.random.rand(size, size)
+
+    # Create symmetric matrix (for demo purposes)
+    similarities = (similarities + similarities.T) / 2
+
+    # Create heatmap with Plotly
+    fig = go.Figure(data=go.Heatmap(
+        z=similarities,
+        x=selected_models,
+        y=selected_models,
+        colorscale='Viridis',
+        hoverongaps=False
+    ))
+
+    fig.update_layout(
+        title=f"Model Similarity for {benchmark}",
+        xaxis_title="Models",
+        yaxis_title="Models",
+        height=600,
+        width=800
+    )
+
+    return fig
 
-    with gr.Blocks(title="LLM Similarity Analyzer") as demo:
-        gr.Markdown("## Compare Models/Datasets from Open LLM Leaderboard")
-
-        with gr.Row():
-            model_dropdown = gr.Dropdown(
-                choices=models,
-                label="Select Model",
-                filterable=True,
-                interactive=True,
-                allow_custom_value=False,
-                info="Search models from Open LLM Leaderboard"
-            )
-
-            dataset_dropdown = gr.Dropdown(
-                choices=datasets,
+with gr.Blocks(title="LLM Similarity Analyzer") as demo:
+    gr.Markdown("## Model Similarity Comparison Tool")
+
+    # Model selection section
+    with gr.Row():
+        dataset_dropdown = gr.Dropdown(
+            choices=get_leaderboard_datasets(),
             label="Select Dataset",
             filterable=True,
             interactive=True,
             info="Leaderboard benchmark datasets"
         )
 
-        # Add your similarity computation and visualization components here
-        # Example placeholder:
-        similarity_output = gr.Textbox(label="Similarity Score")
-        compute_btn = gr.Button("Compute Similarity")
-
-        def compute_similarity(model, dataset):
-            # Replace with your actual similarity metric
-            return f"Similarity between {model} and {dataset}: {0.85:.2f}"
-
-        compute_btn.click(
-            fn=compute_similarity,
-            inputs=[model_dropdown, dataset_dropdown],
-            outputs=similarity_output
+        model_dropdown = gr.Dropdown(
+            choices=get_leaderboard_models_cached(),
+            label="Select Models",
+            multiselect=True,
+            filterable=True,
+            allow_custom_value=False,
+            info="Search and select multiple models (click selected models to remove)"
         )
-
-    return demo
-
-
-def create_demo_with_refresh():
-    demo = create_demo()
 
-    with demo:
-        refresh_btn = gr.Button("Refresh Model List")
-        def refresh_models():
-            return gr.Dropdown(choices=get_leaderboard_models_cached())
-
-        refresh_btn.click(
-            fn=refresh_models,
-            outputs=model_dropdown
-        )
+    # Heatmap display
+    heatmap = gr.Plot(
+        label="Similarity Heatmap",
+        visible=False,
+        container=False
+    )
 
-    return demo
+    # Interactive updates
+    model_dropdown.input(
+        fn=create_heatmap,
+        inputs=(model_dropdown, dataset_dropdown),
+        outputs=heatmap
+    )
+
 
-demo = create_demo_with_refresh()
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()
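Note on the new wiring: `create_heatmap` returns a bare Plotly figure, but the `gr.Plot` component is constructed with `visible=False`, and updating a component's value alone does not change its visibility. A minimal sketch of a variant that toggles visibility as well, assuming Gradio's `gr.update` helper (which returns a property patch for the output component) and keeping the commit's placeholder random matrix:

def create_heatmap(selected_models, benchmark):
    # Hide the plot again when the model selection is emptied
    if not selected_models:
        return gr.update(visible=False)

    # Placeholder similarity scores, symmetrized as in the commit
    size = len(selected_models)
    similarities = np.random.rand(size, size)
    similarities = (similarities + similarities.T) / 2

    fig = go.Figure(data=go.Heatmap(
        z=similarities,
        x=selected_models,
        y=selected_models,
        colorscale="Viridis",
        hoverongaps=False
    ))
    fig.update_layout(title=f"Model Similarity for {benchmark}")

    # Wrap the figure in an update so the hidden component actually appears
    return gr.update(value=fig, visible=True)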
src/dataloading.py CHANGED

@@ -5,16 +5,16 @@ from functools import lru_cache
 def get_leaderboard_models():
     api = HfApi()
 
-    # List all files in the "open_llm_leaderboard" directory of the Space
+    # List all files in the repository
     files = api.list_repo_files(
         repo_id="open-llm-leaderboard/open_llm_leaderboard",
-        repo_type="space",
-        path="open_llm_leaderboard"
+        repo_type="space"
     )
 
     models = []
     for file in files:
-        if "-details" in file and "__" in file:
+        # Filter files in the "open_llm_leaderboard" directory
+        if file.startswith("open_llm_leaderboard/") and "-details" in file and "__" in file:
             # Extract provider and model name from filename
             filename = file.split("/")[-1].replace("-details", "")
             provider, model = filename.split("__", 1)

@@ -23,6 +23,7 @@ def get_leaderboard_models():
     return sorted(list(set(models)))  # Remove duplicates
 
 
+
 @lru_cache(maxsize=1)
 def get_leaderboard_models_cached():
     return get_leaderboard_models()
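For reference, the tightened filter expects details entries under the `open_llm_leaderboard/` prefix following a `<provider>__<model>-details` naming scheme, and the loop recovers the two parts by splitting on the double underscore. A short standalone sketch of that parsing step; the example path is hypothetical, not taken from the actual repo listing:

# Hypothetical path following the <provider>__<model>-details scheme
file = "open_llm_leaderboard/meta-llama__Llama-3.1-8B-Instruct-details"

if file.startswith("open_llm_leaderboard/") and "-details" in file and "__" in file:
    filename = file.split("/")[-1].replace("-details", "")
    provider, model = filename.split("__", 1)
    print(provider, model)  # meta-llama Llama-3.1-8B-Instruct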