File size: 3,209 Bytes
f3cd231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import html

import gradio as gr
import numpy as np
import plotly.graph_objects as go

from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets

# Set Plotly's default renderer to "iframe".
# NOTE(review): Plotly renderers are presumably only consulted by fig.show();
# this file converts figures with fig.to_html() instead, so this setting may
# have no effect here. Confirm before relying on it or removing it.
import plotly.io as pio
pio.renderers.default = "iframe"

def create_heatmap(selected_models, selected_dataset):
    """Render a model-by-model similarity heatmap as an embeddable HTML string.

    Args:
        selected_models: Model names chosen in the UI. They are used as both
            the x and y axis categories, in the given order.
        selected_dataset: Dataset name. It is only used in the figure title.

    Returns:
        An HTML string holding an ``<iframe>`` that hosts the Plotly figure,
        or ``""`` when either input is empty/None.

    NOTE(review): the similarity values are random placeholders
    (``np.random.rand``). The output therefore changes on every call, and the
    diagonal (a model compared with itself) is not forced to 1. Replace this
    with the real similarity computation for ``selected_dataset``.
    """
    if not selected_models or not selected_dataset:
        return ""  # Nothing selected: render an empty HTML component.

    size = len(selected_models)
    # Placeholder data: average a random matrix with its transpose so that
    # sim(a, b) == sim(b, a), then round for display.
    similarities = np.random.rand(size, size)
    similarities = (similarities + similarities.T) / 2
    similarities = np.round(similarities, 2)

    fig = go.Figure(data=go.Heatmap(
        z=similarities,
        x=selected_models,
        y=selected_models,
        colorscale="Viridis",
        zmin=0, zmax=1,
        # Show which model pair the hovered cell belongs to, not just the value.
        hovertemplate="%{y} vs %{x}: %{z:.2f}<extra></extra>",
    ))

    fig.update_layout(
        title=f"Similarity Matrix for {selected_dataset}",
        xaxis_title="Models",
        yaxis_title="Models",
        width=800,
        height=800,
        margin=dict(l=100, r=100, t=100, b=100)
    )

    # Force categorical axes that keep exactly the order the user selected.
    fig.update_xaxes(
        type="category",
        categoryorder="array",
        categoryarray=selected_models,
        tickangle=45,
        automargin=True
    )
    fig.update_yaxes(
        type="category",
        categoryorder="array",
        categoryarray=selected_models,
        automargin=True
    )

    # A bare to_html() fragment relies on <script> tags: one loads Plotly.js
    # from the CDN and one calls Plotly.newPlot. HTML injected through
    # innerHTML (as an HTML component typically does) never executes its
    # <script> elements, so the figure would stay blank. A complete document
    # placed in an iframe's `srcdoc` gets its own browsing context, where
    # those scripts do run. Escaping keeps the markup a valid attribute value.
    figure_html = fig.to_html(full_html=True, include_plotlyjs="cdn")
    return (
        '<iframe style="width: 100%; height: 850px; border: none;" '
        f'srcdoc="{html.escape(figure_html, quote=True)}"></iframe>'
    )

def validate_inputs(selected_models, selected_dataset):
    """Raise a user-facing ``gr.Error`` when a required selection is missing.

    Models are checked before the dataset. When both are missing, only the
    model message is reported.

    Raises:
        gr.Error: if ``selected_models`` or ``selected_dataset`` is empty/None.
    """
    problem = None
    if not selected_models:
        problem = "Please select at least one model!"
    elif not selected_dataset:
        problem = "Please select a dataset!"

    if problem is not None:
        raise gr.Error(problem)

with gr.Blocks(title="LLM Similarity Analyzer") as demo:
    gr.Markdown("## Model Similarity Comparison Tool")

    with gr.Row():
        # Choices are fetched once, when the UI is built.
        dataset_dropdown = gr.Dropdown(
            choices=get_leaderboard_datasets(),
            label="Select Dataset",
            filterable=True,
            interactive=True,
            info="Leaderboard benchmark datasets"
        )
        model_dropdown = gr.Dropdown(
            choices=get_leaderboard_models_cached(),
            label="Select Models",
            multiselect=True,
            filterable=True,
            allow_custom_value=False,
            info="Search and select multiple models"
        )

    generate_btn = gr.Button("Generate Heatmap", variant="primary")
    # The heatmap arrives as an HTML string built by create_heatmap.
    heatmap = gr.HTML(label="Similarity Heatmap", visible=True)

    # Validate first, then draw. `.success` (unlike `.then`) only fires when
    # the previous step finished without raising. A gr.Error from
    # validate_inputs therefore stops the chain instead of still running
    # create_heatmap.
    generate_btn.click(
        fn=validate_inputs,
        inputs=[model_dropdown, dataset_dropdown],
        queue=False
    ).success(
        fn=create_heatmap,
        inputs=[model_dropdown, dataset_dropdown],
        outputs=heatmap
    )

    clear_btn = gr.Button("Clear Selection")
    # Reset both dropdowns and blank out the heatmap.
    clear_btn.click(
        lambda: [None, None, ""],
        outputs=[model_dropdown, dataset_dropdown, heatmap]
    )

if __name__ == "__main__":
    # Launch with server-side rendering switched off (the original intent
    # was Hugging Face Spaces deployments).
    launch_options = {"ssr_mode": False}
    demo.launch(**launch_options)