import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine as p9
import sys
import logging
from contextlib import suppress

sys.path.append('./src')
sys.path.append('.')

from huggingface_hub import HfApi

repo_id = "HUBioDataLab/PROBE"
api = HfApi()

from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe

logger = logging.getLogger(__name__)

# ------------------------------------------------------------------
# Helper functions --------------------------------------------------
# ------------------------------------------------------------------


def add_new_eval(
    human_file,
    skempi_file,
    model_name_textbox: str,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save,
):
    """Validate a submission, run the PROBE evaluation and optionally save it.

    Args:
        human_file: Path of the uploaded Human-dataset representation CSV
            (the gr.File uses type='filepath'), or None if nothing was uploaded.
        skempi_file: Path of the uploaded SKEMPI representation CSV, or None.
        model_name_textbox: Method name; surrounding whitespace is stripped and
            the result is used as the representation name.
        benchmark_types: Selected benchmark types (e.g. "similarity", "affinity").
        similarity_tasks, function_prediction_aspect, function_prediction_dataset,
        family_prediction_dataset: Task options forwarded unchanged to run_probe.
        save: When truthy, results are persisted with save_results.

    Returns:
        0 on success, -1 when validation fails or run_probe raises. User
        feedback is delivered through gr.Info / gr.Warning toasts.
    """
    # Guard against None so the membership tests below cannot raise TypeError.
    benchmark_types = benchmark_types or []
    representation_name = (model_name_textbox or "").strip()

    if not representation_name:
        gr.Warning("Please enter a method name for your submission!")
        return -1
    if not benchmark_types:
        gr.Warning("Please select at least one benchmark type!")
        return -1
    if any(task in benchmark_types for task in ('similarity', 'family', 'function')) and human_file is None:
        gr.Warning("Human representations are required for similarity, family, or function benchmarks!")
        return -1
    if 'affinity' in benchmark_types and skempi_file is None:
        gr.Warning("SKEMPI representations are required for affinity benchmark!")
        return -1

    gr.Info("Your submission is being processed…")

    try:
        results = run_probe(
            benchmark_types,
            representation_name,
            human_file,
            skempi_file,
            similarity_tasks,
            function_prediction_aspect,
            function_prediction_dataset,
            family_prediction_dataset,
        )
    except Exception:
        # Broad catch is deliberate at this UI boundary: a bad upload should
        # surface as a warning instead of crashing the app, but the traceback
        # must still reach the logs so failures can be diagnosed.
        logger.exception("PROBE evaluation failed for method %r", representation_name)
        gr.Warning("Your submission has not been processed. Please check your representation files!")
        return -1

    if save:
        save_results(representation_name, benchmark_types, results)
        gr.Info("Your submission has been processed and results are saved!")
    else:
        gr.Info("Your submission has been processed!")

    return 0


def refresh_data():
    """Restart the Space and re-download the per-benchmark result CSVs.

    Cached CSVs under /tmp are deleted first so download_from_hub fetches
    fresh copies. Returns None; the Refresh button re-renders the table in a
    chained event.
    """
    # NOTE(review): the restart is requested before the local refresh (same
    # order as before); if the restart terminates this process first, the
    # download below may not complete — confirm the intended order.
    api.restart_space(repo_id=repo_id)

    for benchmark_type in ("similarity", "function", "family", "affinity", "leaderboard"):
        # suppress() avoids the exists()/remove() race of a check-then-act.
        with suppress(FileNotFoundError):
            os.remove(f"/tmp/{benchmark_type}_results.csv")

    # The cached "leaderboard" CSV is only cleared; it is not in the download list.
    download_from_hub(["similarity", "function", "family", "affinity"])


# ------- Leaderboard helpers -----------------------------------------------


def _round_numeric(df, ndigits=4):
    """Return a copy of *df* with int/float cells rounded to *ndigits* decimals."""
    def _round_cell(value):
        return round(value, ndigits) if isinstance(value, (int, float)) else value

    # DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map;
    # check on the class so a column literally named "map" cannot fool hasattr.
    if hasattr(pd.DataFrame, "map"):
        return df.map(_round_cell)
    return df.applymap(_round_cell)


def get_rounded_baseline_df(selected_methods, selected_metrics):
    """Return get_baseline_df(...) with numeric cells rounded to 4 decimals.

    Used for the initial table and for every callback so the displayed
    precision stays the same after the user changes a selection.
    """
    return _round_numeric(get_baseline_df(selected_methods, selected_metrics))


def update_metrics(selected_benchmarks):
    """Return the metric columns that belong to the selected benchmark types.

    Duplicates are dropped while keeping first-seen order (benchmark selection
    order, then column order), so the result is deterministic.
    """
    return list(dict.fromkeys(
        metric
        for benchmark in (selected_benchmarks or [])
        for metric in benchmark_metric_mapping.get(benchmark, [])
    ))


def update_leaderboard(selected_methods, selected_metrics):
    """Return the (rounded) leaderboard table for the given methods/metrics."""
    return get_rounded_baseline_df(selected_methods, selected_metrics)


# ------- Visualisation helpers ---------------------------------------------


def generate_plot(benchmark_type, methods_selected, x_metric, y_metric, aspect, dataset, single_metric):
    """Forward the Visualization-tab selections to benchmark_plot and return its result.

    The return value feeds a gr.Image (presumably an image file path). When no
    benchmark type is chosen, a warning is shown and None clears the image.
    """
    if not benchmark_type:
        gr.Warning("Please select a benchmark type first!")
        return None

    plot_path = benchmark_plot(
        benchmark_type,
        methods_selected,
        x_metric,
        y_metric,
        aspect,
        dataset,
        single_metric,
    )
    return plot_path


# ---------------------------------------------------------------------------
# Custom CSS for frozen first column and clearer table styles
# ---------------------------------------------------------------------------
CUSTOM_CSS = """
/* freeze first column */
#leaderboard-table table tr th:first-child,
#leaderboard-table table tr td:first-child {
    position: sticky;
    left: 0;
    background: white;
    z-index: 2;
}

/* striped rows for readability */
#leaderboard-table table tr:nth-child(odd) {
    background: #fafafa;
}

/* centre numeric cells */
#leaderboard-table td:not(:first-child) {
    text-align: center;
}

/* scrollable and taller table */
#leaderboard-table .dataframe-wrap {
    max-height: 1200px;
    overflow-y: auto;
    overflow-x: auto;
}
"""

# ---------------------------------------------------------------------------
# UI definition
# ---------------------------------------------------------------------------
block = gr.Blocks(css=CUSTOM_CSS)

with block:
    gr.Markdown(LEADERBOARD_INTRODUCTION)

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # ------------------------------------------------------------------
        # 1️⃣ Leaderboard tab
        # ------------------------------------------------------------------
        with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
            # workflow figure at top
            gr.Image(
                value="./src/data/PROBE_workflow_figure.jpg",
                show_label=False,
                height=1000,
                container=False,
            )
            gr.Markdown(
                "## For detailed explanations of the metrics and benchmarks, please refer to the 📝 About tab.",
                elem_classes="leaderboard-note",
            )

            leaderboard = get_baseline_df(None, None)
            method_names = leaderboard['Method'].unique().tolist()
            metric_names = [column for column in leaderboard.columns if column != 'Method']

            # Metric columns are grouped by their name prefix.
            benchmark_metric_mapping = {
                "similarity": [m for m in metric_names if m.startswith('sim_')],
                "function": [m for m in metric_names if m.startswith('func')],
                "family": [m for m in metric_names if m.startswith('fam_')],
                "affinity": [m for m in metric_names if m.startswith('aff_')],
            }

            leaderboard_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select Methods",
                value=method_names,
                interactive=True,
            )
            benchmark_type_selector_lb = gr.CheckboxGroup(
                choices=list(benchmark_metric_mapping.keys()),
                label="Select Benchmark Types",
                value=None,
                interactive=True,
            )
            leaderboard_metric_selector = gr.CheckboxGroup(
                choices=metric_names,
                label="Select Metrics",
                value=None,
                interactive=True,
            )

            baseline_value = get_rounded_baseline_df(method_names, metric_names)
            baseline_header = ["Method"] + metric_names
            baseline_datatype = ['markdown'] + ['number'] * len(metric_names)

            with gr.Row(show_progress=True, variant='panel'):
                data_component = gr.Dataframe(
                    value=baseline_value,
                    headers=baseline_header,
                    type="pandas",
                    datatype=baseline_datatype,
                    interactive=False,
                    elem_id="leaderboard-table",
                    pinned_columns=1,
                    max_height=1000,
                )

            # callbacks — all table updates go through the rounding helper so
            # precision matches the initial render.
            leaderboard_method_selector.change(
                get_rounded_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component,
            )
            benchmark_type_selector_lb.change(
                update_metrics,
                inputs=[benchmark_type_selector_lb],
                outputs=leaderboard_metric_selector,
            )
            leaderboard_metric_selector.change(
                get_rounded_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component,
            )

        # ------------------------------------------------------------------
        # 2️⃣ Visualisation tab
        # ------------------------------------------------------------------
        with gr.TabItem("📊 Visualization", elem_id="probe-benchmark-tab-visualization", id=2):
            gr.Markdown(
                "## **Interactive Visualizations**\n"
                "Choose a benchmark type; context‑specific options will appear.",
                elem_classes="markdown-text",
            )

            vis_benchmark_type_selector = gr.Dropdown(
                choices=list(benchmark_specific_metrics.keys()),
                label="Benchmark Type",
                value=None,
            )

            with gr.Row():
                vis_x_metric_selector = gr.Dropdown(choices=[], label="X‑axis Metric", visible=False)
                vis_y_metric_selector = gr.Dropdown(choices=[], label="Y‑axis Metric", visible=False)
                vis_aspect_type_selector = gr.Dropdown(choices=[], label="Aspect", visible=False)
                vis_dataset_selector = gr.Dropdown(choices=[], label="Dataset", visible=False)
                vis_single_metric_selector = gr.Dropdown(choices=[], label="Metric", visible=False)

            vis_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Methods",
                value=method_names,
                interactive=True,
            )

            plot_button = gr.Button("Plot")

            with gr.Row(show_progress=True, variant='panel'):
                plot_output = gr.Image(label="Plot")

            plot_explanation = gr.Markdown(visible=False)

            # callbacks
            vis_benchmark_type_selector.change(
                update_metric_choices,
                inputs=[vis_benchmark_type_selector],
                outputs=[
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
            )

            plot_button.click(
                generate_plot,
                inputs=[
                    vis_benchmark_type_selector,
                    vis_method_selector,
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
                outputs=[plot_output],
            )

        # ------------------------------------------------------------------
        # 3️⃣ About tab
        # ------------------------------------------------------------------
        with gr.TabItem("📝 About", elem_id="probe-benchmark-tab-about", id=3):
            with gr.Row():
                gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Image(
                    value="./src/data/PROBE_workflow_figure.jpg",
                    label="PROBE Workflow Figure",
                    elem_classes="about-image",
                )

        # ------------------------------------------------------------------
        # 4️⃣ Submit tab
        # ------------------------------------------------------------------
        with gr.TabItem("🚀 Submit here! ", elem_id="probe-benchmark-tab-submit", id=4):
            with gr.Row():
                gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")

            with gr.Row():
                with gr.Column():
                    model_name_textbox = gr.Textbox(label="Method name")
                    benchmark_types = gr.CheckboxGroup(choices=TASK_INFO, label="Benchmark Types", interactive=True)
                    similarity_tasks = gr.CheckboxGroup(choices=similarity_tasks_options, label="Similarity Tasks", interactive=True)
                    function_prediction_aspect = gr.Radio(choices=function_prediction_aspect_options, label="Function Prediction Aspects", interactive=True)
                    family_prediction_dataset = gr.CheckboxGroup(choices=family_prediction_dataset_options, label="Family Prediction Datasets", interactive=True)
                    # Hidden: the function-prediction dataset is fixed to this value.
                    function_dataset = gr.Textbox(label="Function Prediction Datasets", visible=False, value="All_Data_Sets")
                    save_checkbox = gr.Checkbox(label="Save results for leaderboard and visualization", value=True)

            with gr.Row():
                human_file = gr.File(label="Representation file (CSV) for Human dataset", file_count="single", type='filepath')
                skempi_file = gr.File(label="Representation file (CSV) for SKEMPI dataset", file_count="single", type='filepath')

            submit_button = gr.Button("Submit Eval")
            submission_result = gr.Markdown()

            submit_button.click(
                add_new_eval,
                inputs=[
                    human_file,
                    skempi_file,
                    model_name_textbox,
                    benchmark_types,
                    similarity_tasks,
                    function_prediction_aspect,
                    function_dataset,
                    family_prediction_dataset,
                    save_checkbox,
                ],
            )

    # global refresh + citation ---------------------------------------------
    with gr.Row():
        data_run = gr.Button("Refresh")
        # refresh_data returns None; wiring it straight to the table used to
        # blank it. Re-render with the current selections in a chained step.
        data_run.click(refresh_data).then(
            get_rounded_baseline_df,
            inputs=[leaderboard_method_selector, leaderboard_metric_selector],
            outputs=data_component,
        )

    with gr.Accordion("Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            show_copy_button=True,
        )

# ---------------------------------------------------------------------------
# Launch only when run as a script (e.g. `python app.py` on the Space), so
# importing this module does not start a server as a side effect.
if __name__ == "__main__":
    block.launch()