Spaces:
Running
Running
File size: 4,024 Bytes
0f7de99 2f2195a 874e761 465a95b 874e761 fc18b54 2f2195a 0f7de99 465a95b 874e761 fc18b54 0f7de99 fc18b54 ffacaaa 465a95b 874e761 cc861f0 36159b1 0f7de99 465a95b 3c1039a 465a95b 0f7de99 465a95b 874e761 465a95b 3c1039a 36159b1 465a95b 874e761 465a95b 874e761 465a95b 874e761 465a95b f3cd231 a48b15f c8f741c f3cd231 4adb140 465a95b 1168f81 465a95b 1168f81 32f9617 465a95b 2cee451 465a95b a48b15f 1168f81 a48b15f ffacaaa cca1790 465a95b cca1790 465a95b cca1790 465a95b ffacaaa 465a95b cca1790 36159b1 465a95b cca1790 465a95b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from io import BytesIO
from PIL import Image
from huggingface_hub import login
from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
from src.similarity import load_data_and_compute_similarities
# Set matplotlib backend for non-GUI environments: figures are only rendered
# to in-memory PNG buffers, never shown in a window.
plt.switch_backend('Agg')

# Log in to the Hugging Face Hub only when a token is supplied through the
# HF_TOKEN environment variable. Calling login(token=None) falls back to an
# interactive prompt, which cannot be answered in a headless deployment.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
def create_heatmap(selected_models, selected_dataset, selected_metric):
    """Render a pairwise model-similarity heatmap and return it as an image.

    Args:
        selected_models: Model identifiers chosen in the UI. They are sorted
            before use so the axis order does not depend on selection order.
        selected_dataset: Dataset on which the similarities are computed.
        selected_metric: Name of the similarity metric to compute; also used
            in the plot title.

    Returns:
        An RGB ``PIL.Image.Image`` of the heatmap, or ``None`` when no models
        or no dataset are selected.
    """
    if not selected_models or not selected_dataset:
        return None

    # Sort models so rows/columns appear in a deterministic order.
    selected_models = sorted(selected_models)
    similarities = load_data_and_compute_similarities(
        selected_models, selected_dataset, selected_metric
    )

    # Work on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", and always close the figure, even if plotting fails,
    # so figures do not accumulate in a long-running server.
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(
            similarities,
            annot=True,
            fmt=".2f",
            cmap="viridis",
            # Colour scale is pinned to [0, 1]; values outside that range are
            # drawn with the end colours.
            # NOTE(review): confirm every metric is bounded to [0, 1].
            vmin=0,
            vmax=1,
            xticklabels=selected_models,
            yticklabels=selected_models,
            ax=ax,
        )

        # Customize plot
        ax.set_title(f"{selected_metric} Similarities for {selected_dataset}", fontsize=16)
        ax.set_xlabel("Models", fontsize=14)
        ax.set_ylabel("Models", fontsize=14)
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        plt.setp(ax.get_yticklabels(), rotation=0)
        fig.tight_layout()

        # Render to an in-memory PNG and decode it into a PIL image.
        with BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
            buf.seek(0)
            # convert() reads the pixel data into a new image, so the buffer
            # can safely be closed when this block exits.
            img = Image.open(buf).convert("RGB")
    finally:
        plt.close(fig)
    return img
def validate_inputs(selected_models, selected_dataset):
    """Check that the user picked models and a dataset before plotting.

    Checks run in order: models first, then dataset. The first missing
    selection is reported to the UI via ``gr.Error``.

    Args:
        selected_models: Models chosen in the multiselect dropdown.
        selected_dataset: Dataset chosen in the dataset dropdown.

    Raises:
        gr.Error: If no model or no dataset is selected.
    """
    required = (
        (selected_models, "Please select at least one model!"),
        (selected_dataset, "Please select a dataset!"),
    )
    for value, message in required:
        if not value:
            raise gr.Error(message)
def update_datasets_based_on_models(selected_models, current_dataset):
    """Refresh the dataset dropdown after the model selection changes.

    Args:
        selected_models: Currently selected model identifiers (may be empty).
        current_dataset: Dataset currently selected in the dropdown, if any.

    Returns:
        A Gradio update that sets the dropdown choices to the datasets
        available for ``selected_models``. ``current_dataset`` stays selected
        only if it is still among those choices; otherwise the value is
        cleared.
    """
    # NOTE(review): with no models selected this offers no datasets at all,
    # whereas the dropdown is initially populated with
    # get_leaderboard_datasets(None); confirm the empty list is intended.
    available_datasets = get_leaderboard_datasets(selected_models) if selected_models else []

    # Keep the current selection only if it is still a valid choice.
    valid_dataset = current_dataset if current_dataset in available_datasets else None

    # gr.Dropdown.update() was removed in Gradio 4; gr.update() is the
    # supported way to return property updates from an event handler.
    return gr.update(
        choices=available_datasets,
        value=valid_dataset
    )
# Build the Gradio UI: dataset/metric/model selectors, a generate button that
# renders the similarity heatmap, and a button to reset the selection.
with gr.Blocks(title="LLM Similarity Analyzer") as demo:
    gr.Markdown("## Model Similarity Comparison Tool")

    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=get_leaderboard_datasets(None),
            label="Select Dataset",
            filterable=True,
            interactive=True,
            allow_custom_value=False,
            info="Open LLM Leaderboard v2 benchmark datasets"
        )
        metric_dropdown = gr.Dropdown(
            choices=["Kappa_p (prob.)", "Kappa_p (det.)", "Error Consistency"],
            label="Select Metric",
            info="Select a similarity metric to compute"
        )

    model_dropdown = gr.Dropdown(
        choices=get_leaderboard_models_cached(),
        label="Select Models",
        multiselect=True,
        filterable=True,
        allow_custom_value=False,
        info="Search and select multiple models"
    )

    # Re-filter the dataset choices whenever the model selection changes.
    model_dropdown.change(
        fn=update_datasets_based_on_models,
        inputs=[model_dropdown, dataset_dropdown],
        outputs=dataset_dropdown
    )

    generate_btn = gr.Button("Generate Heatmap", variant="primary")
    heatmap = gr.Image(label="Similarity Heatmap", visible=True)

    # Validate first; .success() (unlike .then()) only runs create_heatmap
    # when validate_inputs finished without raising gr.Error.
    generate_btn.click(
        fn=validate_inputs,
        inputs=[model_dropdown, dataset_dropdown],
        queue=False
    ).success(
        fn=create_heatmap,
        inputs=[model_dropdown, dataset_dropdown, metric_dropdown],
        outputs=heatmap
    )

    # Reset models, dataset and the rendered heatmap (metric is left as-is).
    clear_btn = gr.Button("Clear Selection")
    clear_btn.click(
        lambda: [[], None, None],
        outputs=[model_dropdown, dataset_dropdown, heatmap]
    )
if __name__ == "__main__":
    # Start the Gradio server with server-side rendering turned off.
    demo.launch(ssr_mode=False)