lm-similarity / app_simple.py
Joschka Strueber
[Add] heatmap plot with seaborn instead of plotly
465a95b
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from io import BytesIO
from PIL import Image
from src.dataloading import get_leaderboard_models_cached, get_leaderboard_datasets
from src.similarity import compute_similarity
# Use matplotlib's non-interactive "Agg" backend so figures can be rendered
# into in-memory image buffers on a server without a display (see
# generate_plot, which saves to BytesIO rather than showing a window).
# NOTE(review): matplotlib.pyplot is already imported above this call; recent
# matplotlib versions allow switching the backend after that import, but
# confirm for the pinned version (or move this above the pyplot import).
import matplotlib
matplotlib.use('Agg')
def generate_plot():
    """Render a demo sine-wave line plot and return it as a PIL image.

    The figure is drawn with matplotlib, serialized to PNG in an in-memory
    buffer, and decoded back into an RGB image that ``gr.Image`` can display.

    Returns:
        PIL.Image.Image: the rendered plot, converted to RGB mode.
    """
    x = np.linspace(0, 10, 100)
    y = np.sin(x)

    fig, ax = plt.subplots()
    try:
        ax.plot(x, y)
        ax.set_title("Sine Wave")
        # Serialize to PNG in memory instead of writing a temporary file.
        with BytesIO() as buf:
            fig.savefig(buf, format="png", bbox_inches="tight", facecolor="white", dpi=100)
            buf.seek(0)
            # Image.open is lazy; convert() decodes the pixels into a new,
            # independent image, so both the source image and the buffer can
            # be closed afterwards.
            with Image.open(buf) as src:
                img = src.convert("RGB")
    finally:
        # Close the figure even if drawing or saving fails, so pyplot does not
        # keep it (and its memory) alive across repeated button clicks.
        plt.close(fig)
    return img
def validate_inputs(selected_model_a, selected_model_b, selected_dataset):
    """Check that every dropdown has a selection before computing similarity.

    Fields are checked in the order Model A, Model B, dataset. The first
    falsy one raises ``gr.Error`` with a message naming it. Returns ``None``
    when all three are set.
    """
    required_fields = (
        (selected_model_a, "Please select Model A!"),
        (selected_model_b, "Please select Model B!"),
        (selected_dataset, "Please select a dataset!"),
    )
    for value, message in required_fields:
        if not value:
            raise gr.Error(message)
def display_similarity(model_a, model_b, dataset):
    """Return the result sentence shown in the "Similarity Result" textbox.

    The score from ``compute_similarity`` is interpolated with default
    f-string formatting. NOTE(review): the score is presumably a float or a
    string; confirm against compute_similarity.
    """
    score = compute_similarity(model_a, model_b, dataset)
    return (
        f"The similarity between {model_a} and {model_b} "
        f"on {dataset} is: {score}"
    )
with gr.Blocks(title="LLM Similarity Analyzer") as demo:
    gr.Markdown("## Model Similarity Comparison Tool")

    dataset_dropdown = gr.Dropdown(
        choices=get_leaderboard_datasets(),
        label="Select Dataset",
        filterable=True,
        interactive=True,
        info="Leaderboard benchmark datasets",
    )

    # Both model pickers offer the same leaderboard list; fetch it once.
    model_choices = get_leaderboard_models_cached()
    model_a_dropdown = gr.Dropdown(
        choices=model_choices,
        label="Select Model A",
        filterable=True,
        allow_custom_value=False,
        info="Search and select models",
    )
    model_b_dropdown = gr.Dropdown(
        choices=model_choices,
        label="Select Model B",
        filterable=True,
        allow_custom_value=False,
        info="Search and select models",
    )

    generate_btn = gr.Button("Compute Similarity", variant="primary")

    # Textbox to display the similarity result
    similarity_output = gr.Textbox(
        label="Similarity Result",
        interactive=False,
    )

    # validate_inputs raises gr.Error when a selection is missing. Chain with
    # .success() rather than .then(): .then() fires even after the previous
    # step failed, which would call compute_similarity with empty inputs.
    generate_btn.click(
        fn=validate_inputs,
        inputs=[model_a_dropdown, model_b_dropdown, dataset_dropdown],
        queue=False,
    ).success(
        fn=display_similarity,
        inputs=[model_a_dropdown, model_b_dropdown, dataset_dropdown],
        outputs=similarity_output,
    )

    # Reset all three dropdowns and the result textbox.
    clear_btn = gr.Button("Clear Selection")
    clear_btn.click(
        lambda: [None, None, None, ""],
        outputs=[model_a_dropdown, model_b_dropdown, dataset_dropdown, similarity_output],
    )

    gr.Markdown("## Matplotlib Plot in Gradio")
    plot_button = gr.Button("Generate Plot")
    plot_output = gr.Image(label="Sine Wave Plot")
    plot_button.click(fn=generate_plot, outputs=plot_output)
def _main():
    """Start the Gradio server for the module-level ``demo`` app."""
    demo.launch()


if __name__ == "__main__":
    _main()