import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine as p9
import sys

from huggingface_hub import HfApi

# Make the local modules importable before the src.* imports below.
sys.path.append('./src')
sys.path.append('.')

repo_id = "HUBioDataLab/PROBE"
api = HfApi()

from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe
# ------------------------------------------------------------------
# Helper functions moved / added here so that UI callbacks can see them
# ------------------------------------------------------------------
def add_new_eval(
human_file,
skempi_file,
model_name_textbox: str,
revision_name_textbox: str,
benchmark_types,
similarity_tasks,
function_prediction_aspect,
function_prediction_dataset,
family_prediction_dataset,
save,
):
"""Validate inputs, run evaluation and (optionally) save results."""
if any(task in benchmark_types for task in ['similarity', 'family', 'function']) and human_file is None:
gr.Warning("Human representations are required for similarity, family, or function benchmarks!")
return -1
if 'affinity' in benchmark_types and skempi_file is None:
gr.Warning("SKEMPI representations are required for affinity benchmark!")
return -1
gr.Info("Your submission is being processed…")
representation_name = model_name_textbox if revision_name_textbox == '' else revision_name_textbox
try:
results = run_probe(
benchmark_types,
representation_name,
human_file,
skempi_file,
similarity_tasks,
function_prediction_aspect,
function_prediction_dataset,
family_prediction_dataset,
)
except Exception:
gr.Warning("Your submission has not been processed. Please check your representation files!")
return -1
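    # Optionally persist the results so they appear on the leaderboard and in the visualizations.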
if save:
save_results(representation_name, benchmark_types, results)
gr.Info("Your submission has been processed and results are saved!")
else:
gr.Info("Your submission has been processed!")
return 0
def refresh_data():
    """Restart the space, clear cached CSVs and pull fresh leaderboard results from the HF Hub."""
    api.restart_space(repo_id=repo_id)
    benchmark_types = ["similarity", "function", "family", "affinity", "leaderboard"]
    # Remove any cached result files so fresh copies are fetched below.
    for benchmark_type in benchmark_types:
        path = f"/tmp/{benchmark_type}_results.csv"
        if os.path.exists(path):
            os.remove(path)
    benchmark_types.remove("leaderboard")
    download_from_hub(benchmark_types)
    # Return the refreshed leaderboard so the bound Dataframe component is updated.
    return get_baseline_df(None, None)
# ------- Leaderboard helpers -------------------------------------------------
def update_metrics(selected_benchmarks):
"""Populate metric selector according to chosen benchmark types."""
updated_metrics = set()
for benchmark in selected_benchmarks:
updated_metrics.update(benchmark_metric_mapping.get(benchmark, []))
return list(updated_metrics)
def update_leaderboard(selected_methods, selected_metrics):
updated_df = get_baseline_df(selected_methods, selected_metrics)
return updated_df
# ------- Visualisation helpers ----------------------------------------------
def get_plot_explanation(benchmark_type, x_metric, y_metric, aspect, dataset, single_metric):
"""Return a short natural‑language explanation for the produced plot."""
if benchmark_type == "similarity":
return (
f"The scatter plot compares models on **{x_metric}** (x‑axis) and "
f"**{y_metric}** (y‑axis). Points further to the upper‑right indicate better "
"performance on both metrics."
)
elif benchmark_type == "function":
return (
f"The heat‑map shows performance of each model (columns) across GO terms "
f"for the **{aspect.upper()}** aspect using the **{single_metric}** metric. "
"Darker squares correspond to stronger performance; hierarchical clustering "
"groups similar models and tasks together."
)
elif benchmark_type == "family":
return (
f"The horizontal box‑plots summarise cross‑validation performance on the "
f"**{dataset}** dataset. Higher median MCC values indicate better family‑"
"classification accuracy."
)
elif benchmark_type == "affinity":
return (
f"Each box‑plot shows the distribution of **{single_metric}** scores for every "
"model when predicting binding affinity changes. Higher values are better."
)
return ""
def generate_plot_and_explanation(
benchmark_type,
methods_selected,
x_metric,
y_metric,
aspect,
dataset,
single_metric,
):
"""Callback wrapper that returns both the image path and a textual explanation."""
plot_path = benchmark_plot(
benchmark_type,
methods_selected,
x_metric,
y_metric,
aspect,
dataset,
single_metric,
)
    explanation = get_plot_explanation(benchmark_type, x_metric, y_metric, aspect, dataset, single_metric)
    # The explanation Markdown is created with visible=False, so reveal it along with the new text.
    return plot_path, gr.update(value=explanation, visible=True)
# ------------------------------------------------------------------
# UI definition
# ------------------------------------------------------------------
block = gr.Blocks()
with block:
gr.Markdown(LEADERBOARD_INTRODUCTION)
with gr.Tabs(elem_classes="tab-buttons") as tabs:
# ------------------------------------------------------------------
# 1️⃣ Leaderboard tab
# ------------------------------------------------------------------
with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
leaderboard = get_baseline_df(None, None) # baseline leaderboard without filtering
method_names = leaderboard['Method'].unique().tolist()
metric_names = leaderboard.columns.tolist()
metric_names.remove('Method') # remove non‑metric column
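            # Group leaderboard columns by benchmark type, identified by their column-name prefixes.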
benchmark_metric_mapping = {
"similarity": [m for m in metric_names if m.startswith('sim_')],
"function": [m for m in metric_names if m.startswith('func')],
"family": [m for m in metric_names if m.startswith('fam_')],
"affinity": [m for m in metric_names if m.startswith('aff_')],
}
# selectors -----------------------------------------------------
leaderboard_method_selector = gr.CheckboxGroup(
choices=method_names,
label="Select Methods for the Leaderboard",
value=method_names,
interactive=True,
)
benchmark_type_selector_lb = gr.CheckboxGroup(
choices=list(benchmark_metric_mapping.keys()),
label="Select Benchmark Types",
value=None,
interactive=True,
)
leaderboard_metric_selector = gr.CheckboxGroup(
choices=metric_names,
label="Select Metrics for the Leaderboard",
value=None,
interactive=True,
)
# leaderboard table --------------------------------------------
baseline_value = get_baseline_df(method_names, metric_names)
baseline_value = baseline_value.applymap(lambda x: round(x, 4) if isinstance(x, (int, float)) else x)
baseline_header = ["Method"] + metric_names
baseline_datatype = ['markdown'] + ['number'] * len(metric_names)
with gr.Row(show_progress=True, variant='panel'):
data_component = gr.Dataframe(
value=baseline_value,
headers=baseline_header,
type="pandas",
datatype=baseline_datatype,
interactive=False,
visible=True,
)
# callbacks -----------------------------------------------------
leaderboard_method_selector.change(
get_baseline_df,
inputs=[leaderboard_method_selector, leaderboard_metric_selector],
outputs=data_component,
)
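            # Selecting benchmark types pre-selects the corresponding metric columns in the metric selector.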
            benchmark_type_selector_lb.change(
                update_metrics,
                inputs=[benchmark_type_selector_lb],
                outputs=leaderboard_metric_selector,
            )
leaderboard_metric_selector.change(
get_baseline_df,
inputs=[leaderboard_method_selector, leaderboard_metric_selector],
outputs=data_component,
)
# ------------------------------------------------------------------
# 2️⃣ Visualisation tab
# ------------------------------------------------------------------
with gr.TabItem("📊 Visualization", elem_id="probe-benchmark-tab-visualization", id=2):
# Intro / instructions
gr.Markdown(
"""
            ## **Interactive Visualizations**

            Select a benchmark type first; context‑specific options will appear automatically.
            Once your parameters are set, click **Plot** to generate the figure.

            **How to read the plots**

            * **Similarity (scatter)** – Each point is a model. Points nearer the top‑right perform well on both chosen similarity metrics.
            * **Function prediction (heat‑map)** – Darker squares denote better scores. Rows/columns are clustered to reveal shared structure.
            * **Family / Affinity (boxplots)** – Boxes summarise distribution across folds/targets. Higher medians indicate stronger performance.
""",
elem_classes="markdown-text",
)
# ------------------------------------------------------------------
# selectors specific to visualisation
# ------------------------------------------------------------------
vis_benchmark_type_selector = gr.Dropdown(
choices=list(benchmark_specific_metrics.keys()),
label="Select Benchmark Type",
value=None,
)
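            # Context-specific selectors: they start hidden and are updated by the benchmark-type callback below.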
with gr.Row():
vis_x_metric_selector = gr.Dropdown(choices=[], label="Select X‑axis Metric", visible=False)
vis_y_metric_selector = gr.Dropdown(choices=[], label="Select Y‑axis Metric", visible=False)
vis_aspect_type_selector = gr.Dropdown(choices=[], label="Select Aspect Type", visible=False)
vis_dataset_selector = gr.Dropdown(choices=[], label="Select Dataset", visible=False)
vis_single_metric_selector = gr.Dropdown(choices=[], label="Select Metric", visible=False)
vis_method_selector = gr.CheckboxGroup(
choices=method_names,
label="Select methods to visualize",
interactive=True,
value=method_names,
)
plot_button = gr.Button("Plot")
with gr.Row(show_progress=True, variant='panel'):
plot_output = gr.Image(label="Plot")
# textual explanation below the image
plot_explanation = gr.Markdown(visible=False)
# ------------------------------------------------------------------
# callbacks for visualisation tab
# ------------------------------------------------------------------
vis_benchmark_type_selector.change(
update_metric_choices,
inputs=[vis_benchmark_type_selector],
outputs=[
vis_x_metric_selector,
vis_y_metric_selector,
vis_aspect_type_selector,
vis_dataset_selector,
vis_single_metric_selector,
],
)
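            # Build the figure and reveal its textual explanation when the user clicks Plot.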
plot_button.click(
generate_plot_and_explanation,
inputs=[
vis_benchmark_type_selector,
vis_method_selector,
vis_x_metric_selector,
vis_y_metric_selector,
vis_aspect_type_selector,
vis_dataset_selector,
vis_single_metric_selector,
],
outputs=[plot_output, plot_explanation],
)
# ------------------------------------------------------------------
# 3️⃣ About tab
# ------------------------------------------------------------------
with gr.TabItem("📝 About", elem_id="probe-benchmark-tab-table", id=3):
with gr.Row():
gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
with gr.Row():
gr.Image(
value="./src/data/PROBE_workflow_figure.jpg",
label="PROBE Workflow Figure",
elem_classes="about-image",
)
# ------------------------------------------------------------------
# 4️⃣ Submit tab
# ------------------------------------------------------------------
with gr.TabItem("🚀 Submit here! ", elem_id="probe-benchmark-tab-table", id=4):
with gr.Row():
gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
with gr.Row():
gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")
with gr.Row():
with gr.Column():
model_name_textbox = gr.Textbox(label="Method name")
revision_name_textbox = gr.Textbox(label="Revision Method Name")
benchmark_types = gr.CheckboxGroup(
choices=TASK_INFO,
label="Benchmark Types",
interactive=True,
)
similarity_tasks = gr.CheckboxGroup(
choices=similarity_tasks_options,
label="Similarity Tasks",
interactive=True,
)
function_prediction_aspect = gr.Radio(
choices=function_prediction_aspect_options,
label="Function Prediction Aspects",
interactive=True,
)
family_prediction_dataset = gr.CheckboxGroup(
choices=family_prediction_dataset_options,
label="Family Prediction Datasets",
interactive=True,
)
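                    # The function-prediction dataset is fixed to "All_Data_Sets" and passed via a hidden textbox.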
function_dataset = gr.Textbox(
label="Function Prediction Datasets",
visible=False,
value="All_Data_Sets",
)
save_checkbox = gr.Checkbox(
label="Save results for leaderboard and visualization",
value=True,
)
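            # Representation files: the Human CSV serves the similarity/function/family benchmarks, the SKEMPI CSV serves affinity.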
with gr.Row():
human_file = gr.File(label="Representation file (CSV) for Human dataset", file_count="single", type='filepath')
skempi_file = gr.File(label="Representation file (CSV) for SKEMPI dataset", file_count="single", type='filepath')
submit_button = gr.Button("Submit Eval")
submission_result = gr.Markdown()
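            # Run the PROBE evaluation with the selected options when the form is submitted.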
submit_button.click(
add_new_eval,
inputs=[
human_file,
skempi_file,
model_name_textbox,
revision_name_textbox,
benchmark_types,
similarity_tasks,
function_prediction_aspect,
function_dataset,
family_prediction_dataset,
save_checkbox,
],
)
# ----------------------------------------------------------------------
# global refresh button & citation accordion
# ----------------------------------------------------------------------
with gr.Row():
data_run = gr.Button("Refresh")
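        # Refresh clears cached result CSVs, re-downloads them from the Hub and updates the leaderboard table.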
data_run.click(refresh_data, outputs=[data_component])
with gr.Accordion("Citation", open=False):
citation_button = gr.Textbox(
value=CITATION_BUTTON_TEXT,
label=CITATION_BUTTON_LABEL,
elem_id="citation-button",
show_copy_button=True,
)
# -----------------------------------------------------------------------------
block.launch()