import gradio as gr
import numpy as np
import plotly.graph_objects as go

from src.dataloading import get_leaderboard_datasets, get_leaderboard_models_cached


def create_heatmap(selected_models, selected_dataset):
    """Build a model-by-model similarity heatmap for the chosen dataset.

    NOTE: the similarity values are currently random placeholders (a symmetric
    matrix drawn from ``np.random.rand``), not values computed from model
    outputs, so every call produces a different matrix.

    Args:
        selected_models: Model names picked in the multiselect dropdown. They
            label both axes, in the order given.
        selected_dataset: Dataset name; only used in the figure title.

    Returns:
        A ``plotly.graph_objects.Figure``, or ``None`` when either selection
        is empty (returning ``None`` clears the ``gr.Plot`` output).
    """
    if not selected_models or not selected_dataset:
        return None

    size = len(selected_models)

    # Placeholder data: averaging with the transpose makes sim[i][j] == sim[j][i].
    similarities = np.random.rand(size, size)
    similarities = (similarities + similarities.T) / 2
    similarities = np.round(similarities, 2)  # Round for clarity

    fig = go.Figure(
        data=go.Heatmap(
            z=similarities,
            x=selected_models,
            y=selected_models,
            colorscale="Viridis",
            zmin=0,
            zmax=1,
            # Name the model pair of the hovered cell, not just its value;
            # <extra></extra> suppresses the secondary trace-name box.
            hovertemplate="%{y} vs %{x}<br>Similarity: %{z:.2f}<extra></extra>",
        )
    )

    # Grow the (square) canvas with the number of models so labels stay readable.
    side = 800 + 20 * size
    fig.update_layout(
        title=f"Similarity Matrix for {selected_dataset}",
        xaxis_title="Models",
        yaxis_title="Models",
        width=side,
        height=side,
        margin=dict(b=100, l=100),
    )

    # Force categorical axes ordered exactly like the user's selection, rather
    # than relying on Plotly's axis-type autodetection and default ordering.
    fig.update_xaxes(
        type="category",
        tickangle=45,
        categoryorder="array",
        categoryarray=selected_models,
        automargin=True,
        showgrid=True,
        showticklabels=True,
    )
    fig.update_yaxes(
        type="category",
        categoryorder="array",
        categoryarray=selected_models,
        automargin=True,
        showgrid=True,
        showticklabels=True,
    )
    return fig


def validate_inputs(selected_models, selected_dataset):
    """Raise a user-facing ``gr.Error`` if either selection is missing.

    Args:
        selected_models: Value of the model multiselect dropdown.
        selected_dataset: Value of the dataset dropdown.

    Raises:
        gr.Error: When no model or no dataset has been selected.
    """
    if not selected_models:
        raise gr.Error("Please select at least one model!")
    if not selected_dataset:
        raise gr.Error("Please select a dataset!")


# Gradio interface setup
with gr.Blocks(title="LLM Similarity Analyzer") as demo:
    gr.Markdown("## Model Similarity Comparison Tool")

    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=get_leaderboard_datasets(),
            label="Select Dataset",
            filterable=True,
            interactive=True,
            info="Leaderboard benchmark datasets",
        )
        model_dropdown = gr.Dropdown(
            choices=get_leaderboard_models_cached(),
            label="Select Models",
            multiselect=True,
            filterable=True,
            allow_custom_value=False,
            info="Search and select multiple models",
        )

    generate_btn = gr.Button("Generate Heatmap", variant="primary")
    heatmap = gr.Plot(label="Similarity Heatmap", visible=True)

    # Validate first. ``.success()`` (unlike ``.then()``, which runs even if the
    # previous step raised) only fires when validation completed without a
    # gr.Error, so invalid input stops the chain before plotting.
    generate_btn.click(
        fn=validate_inputs,
        inputs=[model_dropdown, dataset_dropdown],
        queue=False,
    ).success(
        fn=create_heatmap,
        inputs=[model_dropdown, dataset_dropdown],
        outputs=heatmap,  # Use a single output (the figure)
    )

    # Clear button: reset both selections and the plot. The multiselect gets an
    # empty list (its natural "nothing selected" value).
    clear_btn = gr.Button("Clear Selection")
    clear_btn.click(
        fn=lambda: ([], None, None),
        outputs=[model_dropdown, dataset_dropdown, heatmap],
    )


if __name__ == "__main__":
    demo.launch()