import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine as p9
import sys
sys.path.append('./src')
sys.path.append('.')

from huggingface_hub import HfApi
repo_id = "HUBioDataLab/PROBE"
api = HfApi()

from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe
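
# The star imports above are expected to provide the names used throughout this
# file (inferred from usage below; the exact module of each symbol may differ):
#   - src.about:        LEADERBOARD_INTRODUCTION, LLM_BENCHMARKS_TEXT,
#                       EVALUATION_QUEUE_TEXT, CITATION_BUTTON_TEXT/LABEL,
#                       TASK_INFO and the *_options choice lists
#   - src.saving_utils: save_results, download_from_hub
#   - src.vis_utils:    get_baseline_df, benchmark_plot, benchmark_specific_metrics,
#                       update_metric_choices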

# ------------------------------------------------------------------
# Helper functions are defined at module level so the UI callbacks below can see them
# ------------------------------------------------------------------

def add_new_eval(
    human_file,
    skempi_file,
    model_name_textbox: str,
    revision_name_textbox: str,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save,
):
    """Validate inputs, run evaluation and (optionally) save results."""
    if any(task in benchmark_types for task in ['similarity', 'family', 'function']) and human_file is None:
        gr.Warning("Human representations are required for similarity, family, or function benchmarks!")
        return -1
    
    if 'affinity' in benchmark_types and skempi_file is None:
        gr.Warning("SKEMPI representations are required for affinity benchmark!")
        return -1

    gr.Info("Your submission is being processed…")

    representation_name = model_name_textbox if revision_name_textbox == '' else revision_name_textbox

    try:
        results = run_probe(
            benchmark_types,
            representation_name,
            human_file,
            skempi_file,
            similarity_tasks,
            function_prediction_aspect,
            function_prediction_dataset,
            family_prediction_dataset,
        )
    except Exception:
        gr.Warning("Your submission has not been processed. Please check your representation files!")
        return -1

    if save:
        save_results(representation_name, benchmark_types, results)
        gr.Info("Your submission has been processed and results are saved!")
    else:
        gr.Info("Your submission has been processed!")

    return 0
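
# Illustrative usage (not executed): the same validation/evaluation path can be
# exercised outside the UI, e.g. from a test script. All option values below are
# hypothetical placeholders; valid choices come from src.about.
#
#   status = add_new_eval(
#       human_file="reps/my_model_human.csv",      # hypothetical path
#       skempi_file=None,
#       model_name_textbox="my_model",
#       revision_name_textbox="",
#       benchmark_types=["similarity"],
#       similarity_tasks=similarity_tasks_options[:1],
#       function_prediction_aspect=None,
#       function_prediction_dataset="All_Data_Sets",
#       family_prediction_dataset=[],
#       save=False,
#   )
#   # status is 0 on success, -1 if validation or evaluation failed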


def refresh_data():
    """Restart the space, drop cached result CSVs and pull fresh copies from the HF Hub."""
    api.restart_space(repo_id=repo_id)
    benchmark_types = ["similarity", "function", "family", "affinity", "leaderboard"]

    # Remove any cached result files so the next read pulls fresh copies.
    for benchmark_type in benchmark_types:
        path = f"/tmp/{benchmark_type}_results.csv"
        if os.path.exists(path):
            os.remove(path)

    benchmark_types.remove("leaderboard")
    download_from_hub(benchmark_types)

    # The Refresh button is wired with outputs=[data_component], so return the
    # rebuilt leaderboard table rather than None.
    return get_baseline_df(None, None)


# ------- Leaderboard helpers -------------------------------------------------

def update_metrics(selected_benchmarks):
    """Populate metric selector according to chosen benchmark types."""
    updated_metrics = set()
    for benchmark in selected_benchmarks:
        updated_metrics.update(benchmark_metric_mapping.get(benchmark, []))
    return list(updated_metrics)
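
# For example, update_metrics(["family"]) returns every leaderboard column whose
# name starts with "fam_" (see benchmark_metric_mapping defined in the Leaderboard
# tab below); an empty selection yields an empty list.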


def update_leaderboard(selected_methods, selected_metrics):
    updated_df = get_baseline_df(selected_methods, selected_metrics)
    return updated_df

# ------- Visualisation helpers ----------------------------------------------

def get_plot_explanation(benchmark_type, x_metric, y_metric, aspect, dataset, single_metric):
    """Return a short natural‑language explanation for the produced plot."""
    if benchmark_type == "similarity":
        return (
            f"The scatter plot compares models on **{x_metric}** (x‑axis) and "
            f"**{y_metric}** (y‑axis). Points further to the upper‑right indicate better "
            "performance on both metrics."
        )
    elif benchmark_type == "function":
        return (
            f"The heat‑map shows performance of each model (columns) across GO terms "
            f"for the **{aspect.upper()}** aspect using the **{single_metric}** metric. "
            "Darker squares correspond to stronger performance; hierarchical clustering "
            "groups similar models and tasks together."
        )
    elif benchmark_type == "family":
        return (
            f"The horizontal box‑plots summarise cross‑validation performance on the "
            f"**{dataset}** dataset. Higher median MCC values indicate better family‑"
            "classification accuracy."
        )
    elif benchmark_type == "affinity":
        return (
            f"Each box‑plot shows the distribution of **{single_metric}** scores for every "
            "model when predicting binding affinity changes. Higher values are better."
        )
    return ""


def generate_plot_and_explanation(
    benchmark_type,
    methods_selected,
    x_metric,
    y_metric,
    aspect,
    dataset,
    single_metric,
):
    """Callback wrapper that returns both the image path and a textual explanation."""
    plot_path = benchmark_plot(
        benchmark_type,
        methods_selected,
        x_metric,
        y_metric,
        aspect,
        dataset,
        single_metric,
    )
    explanation = get_plot_explanation(benchmark_type, x_metric, y_metric, aspect, dataset, single_metric)
    # plot_explanation is created with visible=False below; reveal it together with
    # the freshly generated text so the description actually shows up.
    return plot_path, gr.update(value=explanation, visible=True)

# ------------------------------------------------------------------
# UI definition
# ------------------------------------------------------------------
block = gr.Blocks()

with block:
    gr.Markdown(LEADERBOARD_INTRODUCTION)

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # ------------------------------------------------------------------
        # 1️⃣  Leaderboard tab
        # ------------------------------------------------------------------
        with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
            leaderboard = get_baseline_df(None, None)  # baseline leaderboard without filtering

            method_names = leaderboard['Method'].unique().tolist()
            metric_names = leaderboard.columns.tolist()
            metric_names.remove('Method')  # remove non‑metric column

            benchmark_metric_mapping = {
                "similarity": [m for m in metric_names if m.startswith('sim_')],
                "function": [m for m in metric_names if m.startswith('func')],
                "family": [m for m in metric_names if m.startswith('fam_')],
                "affinity": [m for m in metric_names if m.startswith('aff_')],
            }
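            # The mapping above assumes the leaderboard columns follow a prefix
            # convention (sim_*, func*, fam_*, aff_*); columns matching none of the
            # prefixes are simply never offered by the benchmark-type filter.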

            # selectors -----------------------------------------------------
            leaderboard_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select Methods for the Leaderboard",
                value=method_names,
                interactive=True,
            )

            benchmark_type_selector_lb = gr.CheckboxGroup(
                choices=list(benchmark_metric_mapping.keys()),
                label="Select Benchmark Types",
                value=None,
                interactive=True,
            )

            leaderboard_metric_selector = gr.CheckboxGroup(
                choices=metric_names,
                label="Select Metrics for the Leaderboard",
                value=None,
                interactive=True,
            )

            # leaderboard table --------------------------------------------
            baseline_value = get_baseline_df(method_names, metric_names)
            baseline_value = baseline_value.applymap(lambda x: round(x, 4) if isinstance(x, (int, float)) else x)
            baseline_header = ["Method"] + metric_names
            baseline_datatype = ['markdown'] + ['number'] * len(metric_names)

            with gr.Row(show_progress=True, variant='panel'):
                data_component = gr.Dataframe(
                    value=baseline_value,
                    headers=baseline_header,
                    type="pandas",
                    datatype=baseline_datatype,
                    interactive=False,
                    visible=True,
                )

            # callbacks -----------------------------------------------------
            leaderboard_method_selector.change(
                get_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component,
            )

            benchmark_type_selector_lb.change(
                update_metrics,
                inputs=[benchmark_type_selector_lb],
                outputs=leaderboard_metric_selector,
            )

            leaderboard_metric_selector.change(
                get_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component,
            )

        # ------------------------------------------------------------------
        # 2️⃣ Visualisation tab
        # ------------------------------------------------------------------
        with gr.TabItem("📊 Visualization", elem_id="probe-benchmark-tab-visualization", id=2):
            # Intro / instructions
            gr.Markdown(
                """
                ## **Interactive Visualizations**
                Select a benchmark type first; context‑specific options will appear automatically.  
                Once your parameters are set, click **Plot** to generate the figure.

                **How to read the plots**
                * **Similarity (scatter)** – Each point is a model. Points nearer the top‑right perform well on both chosen similarity metrics.
                * **Function prediction (heat‑map)** – Darker squares denote better scores. Rows/columns are clustered to reveal shared structure.
                * **Family / Affinity (boxplots)** – Boxes summarise distribution across folds/targets. Higher medians indicate stronger performance.
                """,
                elem_classes="markdown-text",
            )

            # ------------------------------------------------------------------
            # selectors specific to visualisation
            # ------------------------------------------------------------------
            vis_benchmark_type_selector = gr.Dropdown(
                choices=list(benchmark_specific_metrics.keys()),
                label="Select Benchmark Type",
                value=None,
            )

            with gr.Row():
                vis_x_metric_selector = gr.Dropdown(choices=[], label="Select X‑axis Metric", visible=False)
                vis_y_metric_selector = gr.Dropdown(choices=[], label="Select Y‑axis Metric", visible=False)
                vis_aspect_type_selector = gr.Dropdown(choices=[], label="Select Aspect Type", visible=False)
                vis_dataset_selector = gr.Dropdown(choices=[], label="Select Dataset", visible=False)
                vis_single_metric_selector = gr.Dropdown(choices=[], label="Select Metric", visible=False)

            vis_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select methods to visualize",
                interactive=True,
                value=method_names,
            )

            plot_button = gr.Button("Plot")

            with gr.Row(show_progress=True, variant='panel'):
                plot_output = gr.Image(label="Plot")

            # textual explanation below the image
            plot_explanation = gr.Markdown(visible=False)

            # ------------------------------------------------------------------
            # callbacks for visualisation tab
            # ------------------------------------------------------------------
            vis_benchmark_type_selector.change(
                update_metric_choices,
                inputs=[vis_benchmark_type_selector],
                outputs=[
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
            )

            plot_button.click(
                generate_plot_and_explanation,
                inputs=[
                    vis_benchmark_type_selector,
                    vis_method_selector,
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
                outputs=[plot_output, plot_explanation],
            )

        # ------------------------------------------------------------------
        # 3️⃣  About tab
        # ------------------------------------------------------------------
        with gr.TabItem("📝 About", elem_id="probe-benchmark-tab-table", id=3):
            with gr.Row():
                gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Image(
                    value="./src/data/PROBE_workflow_figure.jpg",
                    label="PROBE Workflow Figure",
                    elem_classes="about-image",
                )

        # ------------------------------------------------------------------
        # 4️⃣  Submit tab
        # ------------------------------------------------------------------
        with gr.TabItem("🚀 Submit here! ", elem_id="probe-benchmark-tab-table", id=4):
            with gr.Row():
                gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")

            with gr.Row():
                gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")

            with gr.Row():
                with gr.Column():
                    model_name_textbox = gr.Textbox(label="Method name")
                    revision_name_textbox = gr.Textbox(label="Revision Method Name")
                    
                    benchmark_types = gr.CheckboxGroup(
                        choices=TASK_INFO,
                        label="Benchmark Types",
                        interactive=True,
                    )
                    similarity_tasks = gr.CheckboxGroup(
                        choices=similarity_tasks_options,
                        label="Similarity Tasks",
                        interactive=True,
                    )
                    
                    function_prediction_aspect = gr.Radio(
                        choices=function_prediction_aspect_options,
                        label="Function Prediction Aspects",
                        interactive=True,
                    )
                    
                    family_prediction_dataset = gr.CheckboxGroup(
                        choices=family_prediction_dataset_options,
                        label="Family Prediction Datasets",
                        interactive=True,
                    )
                    
                    function_dataset = gr.Textbox(
                        label="Function Prediction Datasets",
                        visible=False,
                        value="All_Data_Sets",
                    )

                    save_checkbox = gr.Checkbox(
                        label="Save results for leaderboard and visualization",
                        value=True,
                    )

            with gr.Row():
                human_file = gr.File(label="Representation file (CSV) for Human dataset", file_count="single", type='filepath')
                skempi_file = gr.File(label="Representation file (CSV) for SKEMPI dataset", file_count="single", type='filepath')

            submit_button = gr.Button("Submit Eval")
            submission_result = gr.Markdown()
            submit_button.click(
                add_new_eval,
                inputs=[
                    human_file,
                    skempi_file,
                    model_name_textbox,
                    revision_name_textbox,
                    benchmark_types,
                    similarity_tasks,
                    function_prediction_aspect,
                    function_dataset,
                    family_prediction_dataset,
                    save_checkbox,
                ],
            )

    # ----------------------------------------------------------------------
    # global refresh button & citation accordion
    # ----------------------------------------------------------------------
    with gr.Row():
        data_run = gr.Button("Refresh")
        data_run.click(refresh_data, outputs=[data_component])

    with gr.Accordion("Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            show_copy_button=True,
        )

# -----------------------------------------------------------------------------
block.launch()