File size: 2,984 Bytes
2f2195a
3b16cfa
e1a6930
cdeeed6
2f2195a
b3306d0
4adb140
c8f741c
 
53d5dd8
d8f2ec7
3b16cfa
1a7f19c
e1a6930
 
3b16cfa
54b2baf
1a7f19c
3b16cfa
 
 
 
 
 
 
 
 
4adb140
e1a6930
53d5dd8
3b16cfa
 
 
 
 
e1a6930
4adb140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3306d0
 
53d5dd8
 
 
 
 
 
 
e1a6930
 
 
cca1790
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fbb8c61
cca1790
 
 
 
 
 
 
 
60ded99
cca1790
 
 
 
b18c9e2
723cce8
cca1790
 
 
b3306d0
4adb140
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import plotly.graph_objects as go
import numpy as np
from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets

# Force Plotly to use the "iframe" renderer as its default.
# NOTE(review): `pio.renderers.default` governs how `fig.show()` / `pio.show()`
# display a figure. The figures built in this file are returned to `gr.Plot`
# rather than shown, so this setting may have no effect on the UI — confirm
# whether it is still needed.
import plotly.io as pio
pio.renderers.default = "iframe"

def _placeholder_similarities(size):
    """Return a random symmetric ``size`` x ``size`` similarity matrix.

    This is placeholder data until real similarity scores are wired in.
    Entries are drawn uniformly from [0, 1) via NumPy's global RNG, so
    ``np.random.seed`` still makes them reproducible. The matrix is
    symmetrised by averaging it with its transpose and rounded to two
    decimals. The diagonal is then pinned to 1.0, because every model is
    trivially identical to itself.

    Args:
        size: Number of models, i.e. rows/columns of the matrix.

    Returns:
        A ``(size, size)`` float ``numpy.ndarray`` with values in [0, 1].
    """
    raw = np.random.rand(size, size)
    matrix = np.round((raw + raw.T) / 2, 2)
    # Self-similarity is 1 by definition; random values here would be misleading.
    np.fill_diagonal(matrix, 1.0)
    return matrix

def create_heatmap(selected_models, selected_dataset):
    """Build a Plotly heatmap of pairwise model similarities for a dataset.

    The similarity values are currently random placeholders produced by
    ``_placeholder_similarities``.

    Args:
        selected_models: Model names chosen in the UI. They are used, in the
            given order, as both the x- and y-axis categories.
        selected_dataset: Dataset name, shown in the figure title.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when either input is
        empty/falsy.
    """
    if not selected_models or not selected_dataset:
        return None

    similarities = _placeholder_similarities(len(selected_models))

    fig = go.Figure(data=go.Heatmap(
        z=similarities,
        x=selected_models,
        y=selected_models,
        colorscale="Viridis",
        zmin=0, zmax=1,
        # Name the compared pair in the tooltip, not just the bare value;
        # <extra></extra> suppresses the secondary trace-name box.
        hovertemplate="%{y} vs %{x}: %{z:.2f}<extra></extra>",
    ))

    fig.update_layout(
        title=f"Similarity Matrix for {selected_dataset}",
        xaxis_title="Models",
        yaxis_title="Models",
        width=800,
        height=800,
        margin=dict(l=100, r=100, t=100, b=100)
    )

    # Categorical axes pinned to the selection order keep rows and columns
    # aligned with each other and with the order the user picked.
    fig.update_xaxes(
        type="category",
        categoryorder="array",
        categoryarray=selected_models,
        tickangle=45,
        automargin=True
    )
    fig.update_yaxes(
        type="category",
        categoryorder="array",
        categoryarray=selected_models,
        automargin=True
    )

    # Return the Plotly figure object directly (gr.Plot renders it).
    return fig

def validate_inputs(selected_models, selected_dataset):
    """Reject an incomplete selection before the heatmap is generated.

    The checks run in order: models first, then the dataset. The first
    failing check raises ``gr.Error`` with a user-facing message.

    Args:
        selected_models: Models picked in the multiselect dropdown.
        selected_dataset: Dataset picked in the dataset dropdown.

    Raises:
        gr.Error: If no model or no dataset is selected.
    """
    required = (
        (selected_models, "Please select at least one model!"),
        (selected_dataset, "Please select a dataset!"),
    )
    for value, message in required:
        if not value:
            raise gr.Error(message)

# UI layout: dataset/model pickers, a generate button wired to
# validation + heatmap creation, and a button that resets everything.
with gr.Blocks(title="LLM Similarity Analyzer") as demo:
    gr.Markdown("## Model Similarity Comparison Tool")
    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=get_leaderboard_datasets(),
            label="Select Dataset",
            filterable=True,
            interactive=True,
            info="Leaderboard benchmark datasets"
        )
        model_dropdown = gr.Dropdown(
            choices=get_leaderboard_models_cached(),
            label="Select Models",
            multiselect=True,
            filterable=True,
            allow_custom_value=False,
            info="Search and select multiple models"
        )

    generate_btn = gr.Button("Generate Heatmap", variant="primary")
    heatmap = gr.Plot(label="Similarity Heatmap", visible=True)

    # Validate first (outside the queue), then build the heatmap only if
    # validation passed. `.success()` skips the follow-up when the previous
    # step raised; `.then()` would run create_heatmap regardless.
    generate_btn.click(
        fn=validate_inputs,
        inputs=[model_dropdown, dataset_dropdown],
        queue=False
    ).success(
        fn=create_heatmap,
        inputs=[model_dropdown, dataset_dropdown],
        outputs=heatmap
    )

    clear_btn = gr.Button("Clear Selection")
    # Reset both dropdowns and blank the plot in one event.
    clear_btn.click(
        lambda: [None, None, None],
        outputs=[model_dropdown, dataset_dropdown, heatmap]
    )

if __name__ == "__main__":
    # Disable SSR to force client-side rendering on Spaces
    demo.launch(ssr_mode=False)