File size: 3,209 Bytes
2f2195a
3b16cfa
e1a6930
cdeeed6
2f2195a
88e5618
4adb140
c8f741c
 
53d5dd8
d8f2ec7
88e5618
e1a6930
 
3b16cfa
54b2baf
88e5618
3b16cfa
 
 
 
 
 
 
 
 
4adb140
e1a6930
53d5dd8
3b16cfa
 
 
 
 
e1a6930
4adb140
88e5618
4adb140
 
 
 
 
 
 
 
 
 
 
 
 
 
88e5618
 
53d5dd8
 
 
 
 
 
 
e1a6930
 
88e5618
e1a6930
cca1790
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88e5618
 
cca1790
 
 
 
 
 
 
 
60ded99
cca1790
 
 
 
88e5618
723cce8
cca1790
 
 
88e5618
4adb140
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
import plotly.graph_objects as go
import numpy as np
from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets

# Optionally, force Plotly's default renderer to "iframe".
# NOTE(review): Plotly renderers govern fig.show(); this file renders figures
# via fig.to_html(), which presumably does not consult this setting — confirm
# whether this line has any effect before relying on it.
import plotly.io as pio
pio.renderers.default = "iframe"

def create_heatmap(selected_models, selected_dataset):
    if not selected_models or not selected_dataset:
        return ""  # Return empty HTML if no input
    size = len(selected_models)
    similarities = np.random.rand(size, size)
    similarities = (similarities + similarities.T) / 2
    similarities = np.round(similarities, 2)
    
    fig = go.Figure(data=go.Heatmap(
        z=similarities,
        x=selected_models,
        y=selected_models,
        colorscale="Viridis",
        zmin=0, zmax=1,
        text=similarities,
        hoverinfo="text"
    ))
    
    fig.update_layout(
        title=f"Similarity Matrix for {selected_dataset}",
        xaxis_title="Models",
        yaxis_title="Models",
        width=800,
        height=800,
        margin=dict(l=100, r=100, t=100, b=100)
    )
    
    # Force categorical ordering with explicit tick settings.
    fig.update_xaxes(
        type="category",
        categoryorder="array",
        categoryarray=selected_models,
        tickangle=45,
        automargin=True
    )
    fig.update_yaxes(
        type="category",
        categoryorder="array",
        categoryarray=selected_models,
        automargin=True
    )
    
    # Convert the figure to an HTML string that includes Plotly.js via CDN.
    return fig.to_html(full_html=False, include_plotlyjs="cdn")

def validate_inputs(selected_models, selected_dataset):
    """Ensure both selectors have a value before the heatmap is built.

    Checks are applied in order (models first, then dataset); the first
    empty selection aborts with a user-visible ``gr.Error``.

    Returns:
        None when both inputs are non-empty.
    """
    required = (
        (selected_models, "Please select at least one model!"),
        (selected_dataset, "Please select a dataset!"),
    )
    for value, message in required:
        if not value:
            raise gr.Error(message)

with gr.Blocks(title="LLM Similarity Analyzer") as demo:
    gr.Markdown("## Model Similarity Comparison Tool")

    with gr.Row():
        # Dropdown choices are fetched once, when this module builds the UI.
        dataset_dropdown = gr.Dropdown(
            choices=get_leaderboard_datasets(),
            label="Select Dataset",
            filterable=True,
            interactive=True,
            info="Leaderboard benchmark datasets",
        )
        model_dropdown = gr.Dropdown(
            choices=get_leaderboard_models_cached(),
            label="Select Models",
            multiselect=True,
            filterable=True,
            allow_custom_value=False,
            info="Search and select multiple models",
        )

    generate_btn = gr.Button("Generate Heatmap", variant="primary")
    # Use an HTML component instead of gr.Plot.
    heatmap = gr.HTML(label="Similarity Heatmap", visible=True)

    # Validate first, then render. `.success` (unlike `.then`) fires only when
    # the previous step did not raise, so a gr.Error from validate_inputs
    # stops the chain instead of running create_heatmap anyway.
    generate_btn.click(
        fn=validate_inputs,
        inputs=[model_dropdown, dataset_dropdown],
        queue=False,
    ).success(
        fn=create_heatmap,
        inputs=[model_dropdown, dataset_dropdown],
        outputs=heatmap,
    )

    clear_btn = gr.Button("Clear Selection")
    # Reset both selectors and blank out the rendered heatmap.
    clear_btn.click(
        lambda: [None, None, ""],
        outputs=[model_dropdown, dataset_dropdown, heatmap],
    )

if __name__ == "__main__":
    # Server-side rendering is switched off for deployment on Spaces.
    launch_options = {"ssr_mode": False}
    demo.launch(**launch_options)