lm-similarity / app.py
Joschka Strueber
[Fix] heatmap not generated and deselection causes error
8901fb0
raw
history blame
2.16 kB
import gradio as gr
import plotly.graph_objects as go
import numpy as np
from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
def create_heatmap(selected_models, selected_dataset):
    """Build a heatmap of pairwise model similarities for one dataset.

    The similarity values are currently random placeholders: a uniform
    random ``size x size`` matrix made symmetric by averaging it with its
    transpose.

    Args:
        selected_models: Model names chosen in the UI; used for both axes.
        selected_dataset: Dataset name chosen in the UI; shown in the title.

    Returns:
        A visible ``gr.Plot`` carrying the figure, or a hidden ``gr.Plot``
        when either selection is empty.
    """
    print(f"Creating heatmap with models: {selected_models} and dataset: {selected_dataset}")
    if not selected_models or not selected_dataset:
        return gr.Plot(visible=False)

    # Placeholder data: random values, symmetrised so that
    # similarity(a, b) == similarity(b, a).
    size = len(selected_models)
    similarities = np.random.rand(size, size)
    similarities = (similarities + similarities.T) / 2

    fig = go.Figure(data=go.Heatmap(
        z=similarities,
        x=selected_models,
        y=selected_models,
        colorscale='Viridis'
    ))
    fig.update_layout(
        title=f"Similarity Matrix for {selected_dataset}",
        width=800,
        height=800
    )
    # Return a component instance (the same Gradio 4 style as the early
    # return above). The class-level ``gr.Plot.update`` helper was removed in
    # Gradio 4, so calling it raised AttributeError and no heatmap appeared.
    return gr.Plot(value=fig, visible=True)
def validate_inputs(selected_models, selected_dataset):
    """Ensure both a model selection and a dataset selection were made.

    The checks run in order (models first, then dataset), and the first
    empty/falsy selection raises.

    Raises:
        gr.Error: If ``selected_models`` or ``selected_dataset`` is falsy.
    """
    print(f"Validating inputs: models={selected_models}, dataset={selected_dataset}")
    required = (
        (selected_models, "Please select at least one model!"),
        (selected_dataset, "Please select a dataset!"),
    )
    for value, complaint in required:
        if not value:
            raise gr.Error(complaint)
with gr.Blocks(title="LLM Similarity Analyzer") as demo:
    gr.Markdown("## Model Similarity Comparison Tool")

    with gr.Row():
        model_selector = gr.Dropdown(label="Select Models", choices=get_leaderboard_models_cached(), multiselect=True)
        dataset_selector = gr.Dropdown(label="Select Dataset", choices=get_leaderboard_datasets())

    heatmap_output = gr.Plot(visible=False)

    def on_submit(selected_models, selected_dataset):
        """Validate the selections and return an update for ``heatmap_output``.

        Every return value must be something the ``gr.Plot`` output accepts.
        On invalid input, the validation message is shown as a warning toast
        and the plot is hidden. The previous version returned a ``gr.Markdown``
        component into the Plot output, which the Plot cannot render.
        """
        try:
            validate_inputs(selected_models, selected_dataset)
        except gr.Error as e:
            # ``str(e)`` on gr.Error is repr-quoted; use the raw message.
            gr.Warning(e.message)
            return gr.Plot(visible=False)
        return create_heatmap(selected_models, selected_dataset)

    submit_button = gr.Button("Generate Heatmap")
    submit_button.click(on_submit, inputs=[model_selector, dataset_selector], outputs=heatmap_output)

# Launch only when run as a script, so the module can be imported
# (e.g. by tests or reload tooling) without starting a server.
if __name__ == "__main__":
    demo.launch()