File size: 2,163 Bytes
2f2195a
e1a6930
 
cdeeed6
2f2195a
53d5dd8
8901fb0
d8f2ec7
53d5dd8
e1a6930
d8f2ec7
e1a6930
 
d8f2ec7
e1a6930
53d5dd8
e1a6930
 
 
 
53d5dd8
e1a6930
 
 
53d5dd8
 
 
e1a6930
 
d8f2ec7
 
53d5dd8
 
8901fb0
53d5dd8
 
 
 
 
e1a6930
 
 
 
8901fb0
 
 
 
 
 
 
 
 
 
 
 
 
7fa11aa
8901fb0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import plotly.graph_objects as go
import numpy as np
from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets

def create_heatmap(selected_models, selected_dataset):
    """Build the model-similarity heatmap update for the ``gr.Plot`` output.

    The similarity values are currently placeholders: a random symmetric
    matrix with entries in [0, 1), not real model similarities.

    Args:
        selected_models: Model names chosen in the dropdown; used as both the
            x and y axis labels.
        selected_dataset: Dataset name; shown in the plot title.

    Returns:
        A Gradio update for the plot component. It is hidden when either
        input is empty. Otherwise it is visible and holds the heatmap figure.
    """
    print(f"Creating heatmap with models: {selected_models} and dataset: {selected_dataset}")
    if not selected_models or not selected_dataset:
        # gr.update works on Gradio 3.x and 4.x alike.
        return gr.update(visible=False)

    # Placeholder data: average a random matrix with its transpose so that
    # sim[i][j] == sim[j][i], as a similarity matrix should be.
    size = len(selected_models)
    similarities = np.random.rand(size, size)
    similarities = (similarities + similarities.T) / 2

    fig = go.Figure(data=go.Heatmap(
        z=similarities,
        x=selected_models,
        y=selected_models,
        colorscale='Viridis'
    ))

    fig.update_layout(
        title=f"Similarity Matrix for {selected_dataset}",
        width=800,
        height=800
    )

    # gr.Plot.update() was removed in Gradio 4.0; gr.update() sets both the
    # figure and the visibility on every supported version.
    return gr.update(value=fig, visible=True)

def validate_inputs(selected_models, selected_dataset):
    """Reject an incomplete selection before any plotting happens.

    The checks run in order: models first, then dataset. The first empty
    value raises ``gr.Error`` with its message. When both are non-empty,
    the function returns ``None``.
    """
    print(f"Validating inputs: models={selected_models}, dataset={selected_dataset}")
    required = (
        (selected_models, "Please select at least one model!"),
        (selected_dataset, "Please select a dataset!"),
    )
    for value, complaint in required:
        if not value:
            raise gr.Error(complaint)

with gr.Blocks(title="LLM Similarity Analyzer") as demo:
    gr.Markdown("## Model Similarity Comparison Tool")

    with gr.Row():
        model_selector = gr.Dropdown(label="Select Models", choices=get_leaderboard_models_cached(), multiselect=True)
        dataset_selector = gr.Dropdown(label="Select Dataset", choices=get_leaderboard_datasets())
        heatmap_output = gr.Plot(visible=False)

    def on_submit(selected_models, selected_dataset):
        """Validate the selection, then return the heatmap update for ``heatmap_output``.

        A ``gr.Error`` raised by ``validate_inputs`` is intentionally not
        caught. Gradio shows a raised ``gr.Error`` message to the user. The
        event's only output is a ``gr.Plot``, so returning a ``gr.Markdown``
        from here could not display the message.
        """
        validate_inputs(selected_models, selected_dataset)
        return create_heatmap(selected_models, selected_dataset)

    submit_button = gr.Button("Generate Heatmap")
    submit_button.click(on_submit, inputs=[model_selector, dataset_selector], outputs=heatmap_output)

demo.launch()