"""app.py: load models and datasets from the Hub and compute model similarities."""

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
from PIL import Image
from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets

# Set matplotlib backend for non-GUI environments
plt.switch_backend('Agg')

def create_heatmap(selected_models, selected_dataset, selected_metric):
    if not selected_models or not selected_dataset:
        return None

    # Sort models and derive short display names (drop the org prefix)
    selected_models = sorted(selected_models)
    selected_models_short = [model.split("/")[-1] for model in selected_models]

    # Generate a random symmetric similarity matrix as a stand-in for real scores
    size = len(selected_models)
    similarities = np.random.rand(size, size)
    similarities = (similarities + similarities.T) / 2
    similarities = np.round(similarities, 2)

    # Create figure and heatmap using seaborn
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        similarities,
        annot=True,
        fmt=".2f",
        cmap="viridis",
        vmin=0,
        vmax=1,
        xticklabels=selected_models_short,
        yticklabels=selected_models_short,
    )

    # Customize plot
    plt.title(f"{selected_metric} Similarities for {selected_dataset}", fontsize=16)
    plt.xlabel("Models", fontsize=14)
    plt.ylabel("Models", fontsize=14)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    # Save to an in-memory buffer
    buf = BytesIO()
    plt.savefig(buf, format="png", dpi=100, bbox_inches="tight")
    plt.close()

    # Convert to PIL Image
    buf.seek(0)
    img = Image.open(buf).convert("RGB")
    return img
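
# The matrix above is random; a real metric would compare per-sample predictions.
# Below is a minimal sketch of "Error Consistency" (one of the metrics offered in
# the UI), assuming per-sample correctness vectors are available for each model.
# The helper name and its inputs are illustrative, not part of this app's loaders.
def _error_consistency_sketch(correct_a, correct_b):
    """Kappa-style agreement between two models' per-sample correctness (0/1 arrays)."""
    correct_a = np.asarray(correct_a, dtype=float)
    correct_b = np.asarray(correct_b, dtype=float)
    acc_a, acc_b = correct_a.mean(), correct_b.mean()
    # Observed agreement: fraction of samples where both are right or both are wrong
    c_obs = np.mean(correct_a == correct_b)
    # Expected agreement for independent models with these accuracies
    c_exp = acc_a * acc_b + (1 - acc_a) * (1 - acc_b)
    if c_exp == 1.0:  # degenerate case: agreement is trivially perfect
        return 1.0
    return (c_obs - c_exp) / (1 - c_exp)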

def validate_inputs(selected_models, selected_dataset):
    if not selected_models:
        raise gr.Error("Please select at least one model!")
    if not selected_dataset:
        raise gr.Error("Please select a dataset!")

def update_datasets_based_on_models(selected_models, current_dataset):
    # Restrict the dataset choices to those available for the selected models
    available_datasets = get_leaderboard_datasets(selected_models) if selected_models else []
    # Keep the current dataset only if it is still a valid choice
    valid_dataset = current_dataset if current_dataset in available_datasets else None
    return gr.update(choices=available_datasets, value=valid_dataset)
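
# Example (hypothetical model name): selecting ["meta-llama/Llama-3-8B"] narrows
# the dataset dropdown to the benchmarks that model was evaluated on, and clears
# the previous dataset selection if it is no longer among the valid choices.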

with gr.Blocks(title="LLM Similarity Analyzer") as demo:
    gr.Markdown("## Model Similarity Comparison Tool")

    model_dropdown = gr.Dropdown(
        choices=get_leaderboard_models_cached(),
        label="Select Models",
        multiselect=True,
        filterable=True,
        allow_custom_value=False,
        info="Search and select multiple models",
    )

    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=get_leaderboard_datasets(),
            label="Select Dataset",
            filterable=True,
            interactive=True,
            info="Open LLM Leaderboard v2 benchmark datasets",
        )
        metric_dropdown = gr.Dropdown(
            choices=["Kappa_p (prob.)", "Kappa_p (det.)", "Error Consistency"],
            label="Select Metric",
            info="Select a similarity metric to compute",
        )

    # Refresh the dataset choices whenever the model selection changes
    model_dropdown.change(
        fn=update_datasets_based_on_models,
        inputs=[model_dropdown, dataset_dropdown],
        outputs=dataset_dropdown,
    )

    generate_btn = gr.Button("Generate Heatmap", variant="primary")
    heatmap = gr.Image(label="Similarity Heatmap", visible=True)

    # Validate first (fails fast with a gr.Error), then render the heatmap
    generate_btn.click(
        fn=validate_inputs,
        inputs=[model_dropdown, dataset_dropdown],
        queue=False,
    ).then(
        fn=create_heatmap,
        inputs=[model_dropdown, dataset_dropdown, metric_dropdown],
        outputs=heatmap,
    )

    clear_btn = gr.Button("Clear Selection")
    clear_btn.click(
        lambda: [[], None, None],
        outputs=[model_dropdown, dataset_dropdown, heatmap],
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)
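
# To try the app locally (assuming the src.dataloading helpers are importable):
#   python app.py
# Gradio serves the interface at http://localhost:7860 by default.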