Spaces:
Running
Running
File size: 3,209 Bytes
2f2195a 3b16cfa e1a6930 cdeeed6 2f2195a 88e5618 4adb140 c8f741c 53d5dd8 d8f2ec7 88e5618 e1a6930 3b16cfa 54b2baf 88e5618 3b16cfa 4adb140 e1a6930 53d5dd8 3b16cfa e1a6930 4adb140 88e5618 4adb140 88e5618 53d5dd8 e1a6930 88e5618 e1a6930 cca1790 88e5618 cca1790 60ded99 cca1790 88e5618 723cce8 cca1790 88e5618 4adb140 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import gradio as gr
import plotly.graph_objects as go
import numpy as np
from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
# Optionally, force a renderer (may or may not help)
import plotly.io as pio
# NOTE(review): the default renderer appears to matter only for figure display
# via fig.show(); create_heatmap serializes figures itself with fig.to_html(),
# so this setting is likely a no-op for this app -- confirm before removing.
pio.renderers.default = "iframe"
def create_heatmap(selected_models, selected_dataset):
    """Render a model-vs-model similarity heatmap as an embeddable HTML string.

    Args:
        selected_models: Model names chosen in the UI; used, in the given
            order, as both the x- and y-axis categories.
        selected_dataset: Dataset name; only used in the figure title.

    Returns:
        An HTML ``<iframe>`` whose ``srcdoc`` holds a self-contained Plotly
        page, or ``""`` when either input is empty.
    """
    import html  # local import keeps this block self-contained

    if not selected_models or not selected_dataset:
        return ""  # Return empty HTML if no input

    # TODO: placeholder data -- these are random numbers, not real scores.
    size = len(selected_models)
    similarities = np.random.rand(size, size)
    # Symmetrize so that sim(a, b) == sim(b, a), then round for display.
    similarities = (similarities + similarities.T) / 2
    similarities = np.round(similarities, 2)

    fig = go.Figure(data=go.Heatmap(
        z=similarities,
        x=selected_models,
        y=selected_models,
        colorscale="Viridis",
        zmin=0, zmax=1,
        text=similarities,
        hoverinfo="text",
    ))

    fig.update_layout(
        title=f"Similarity Matrix for {selected_dataset}",
        xaxis_title="Models",
        yaxis_title="Models",
        width=800,
        height=800,
        margin=dict(l=100, r=100, t=100, b=100),
    )

    # Force categorical ordering with explicit tick settings, so the axes
    # follow the user's selection order instead of being re-sorted.
    fig.update_xaxes(
        type="category",
        categoryorder="array",
        categoryarray=selected_models,
        tickangle=45,
        automargin=True,
    )
    fig.update_yaxes(
        type="category",
        categoryorder="array",
        categoryarray=selected_models,
        automargin=True,
    )

    # gr.HTML inserts its value via innerHTML, and browsers do not execute
    # <script> tags inserted that way, so the Plotly bootstrap emitted by
    # fig.to_html() would never run and nothing would be drawn. A complete
    # page inside an iframe's srcdoc is a separate document whose scripts
    # (including the CDN-loaded plotly.js) do execute.
    page = fig.to_html(full_html=True, include_plotlyjs="cdn")
    return (
        f'<iframe srcdoc="{html.escape(page, quote=True)}" '
        'style="width: 100%; height: 850px; border: none;"></iframe>'
    )
def validate_inputs(selected_models, selected_dataset):
    """Ensure the user picked models and a dataset before plotting.

    Checks run in order (models first, then dataset); the first empty
    selection aborts the event with a user-visible ``gr.Error``.

    Raises:
        gr.Error: If no model or no dataset is selected.
    """
    required = (
        (selected_models, "Please select at least one model!"),
        (selected_dataset, "Please select a dataset!"),
    )
    for value, message in required:
        if not value:
            raise gr.Error(message)
with gr.Blocks(title="LLM Similarity Analyzer") as demo:
    gr.Markdown("## Model Similarity Comparison Tool")

    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=get_leaderboard_datasets(),
            label="Select Dataset",
            filterable=True,
            interactive=True,
            info="Leaderboard benchmark datasets",
        )
        model_dropdown = gr.Dropdown(
            choices=get_leaderboard_models_cached(),
            label="Select Models",
            multiselect=True,
            filterable=True,
            allow_custom_value=False,
            info="Search and select multiple models",
        )

    generate_btn = gr.Button("Generate Heatmap", variant="primary")
    # create_heatmap returns an HTML string, so an HTML component is used
    # instead of gr.Plot.
    heatmap = gr.HTML(label="Similarity Heatmap", visible=True)

    # Validate first; `.success` (unlike `.then`) only runs the heatmap step
    # when validate_inputs completed without raising gr.Error.
    generate_btn.click(
        fn=validate_inputs,
        inputs=[model_dropdown, dataset_dropdown],
        queue=False,
    ).success(
        fn=create_heatmap,
        inputs=[model_dropdown, dataset_dropdown],
        outputs=heatmap,
    )

    clear_btn = gr.Button("Clear Selection")
    # Reset both dropdowns and blank out the rendered heatmap.
    clear_btn.click(
        lambda: [None, None, ""],
        outputs=[model_dropdown, dataset_dropdown, heatmap],
    )
if __name__ == "__main__":
    # Script entry point: launch the `demo` Blocks app.
    # On Spaces, disable server-side rendering.
    demo.launch(ssr_mode=False)
|