import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine as p9
import sys
sys.path.append('./src')
sys.path.append('.')

from huggingface_hub import HfApi
repo_id = "HUBioDataLab/PROBE"
api = HfApi()

from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe

# ------------------------------------------------------------------
# Helper functions --------------------------------------------------
# ------------------------------------------------------------------

def add_new_eval(
    human_file,
    skempi_file,
    model_name_textbox: str,
    revision_name_textbox: str,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save,
):
    """Validate inputs, run evaluation and (optionally) save results."""
    if any(task in benchmark_types for task in ['similarity', 'family', 'function']) and human_file is None:
        gr.Warning("Human representations are required for similarity, family, or function benchmarks!")
        return -1
    if 'affinity' in benchmark_types and skempi_file is None:
        gr.Warning("SKEMPI representations are required for affinity benchmark!")
        return -1

    gr.Info("Your submission is being processed…")

    representation_name = model_name_textbox if revision_name_textbox == '' else revision_name_textbox

    try:
        results = run_probe(
            benchmark_types,
            representation_name,
            human_file,
            skempi_file,
            similarity_tasks,
            function_prediction_aspect,
            function_prediction_dataset,
            family_prediction_dataset,
        )
    except Exception as exc:
        # Surface the failure in the Space logs while keeping the user-facing warning generic.
        print(f"run_probe failed: {exc}", file=sys.stderr)
        gr.Warning("Your submission has not been processed. Please check your representation files!")
        return -1

    if save:
        save_results(representation_name, benchmark_types, results)
        gr.Info("Your submission has been processed and results are saved!")
    else:
        gr.Info("Your submission has been processed!")

    return 0

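# Note: add_new_eval's positional parameter order must match the `inputs=[...]`
# list passed to submit_button.click() in the Submit tab below. A purely
# illustrative call (the file paths are hypothetical placeholders; real option
# values come from the lists star-imported from src.about):
#
#   add_new_eval(
#       "human_representations.csv",            # human_file (placeholder path)
#       "skempi_representations.csv",           # skempi_file (placeholder path)
#       "my-model",                             # model_name_textbox
#       "",                                     # revision_name_textbox
#       ["similarity", "affinity"],             # benchmark_types
#       similarity_tasks_options,               # similarity_tasks
#       function_prediction_aspect_options[0],  # function_prediction_aspect
#       "All_Data_Sets",                        # function_prediction_dataset
#       family_prediction_dataset_options,      # family_prediction_dataset
#       False,                                  # save
#   )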

def refresh_data():
    """Restart the Space, clear cached result CSVs, pull fresh copies from the
    HF Hub and return the rebuilt leaderboard for the Refresh button callback."""
    api.restart_space(repo_id=repo_id)
    benchmark_types = ["similarity", "function", "family", "affinity", "leaderboard"]
    for benchmark_type in benchmark_types:
        path = f"/tmp/{benchmark_type}_results.csv"
        if os.path.exists(path):
            os.remove(path)
    benchmark_types.remove("leaderboard")
    download_from_hub(benchmark_types)
    # The Refresh button is wired with outputs=[data_component], so return the
    # refreshed leaderboard dataframe instead of None.
    return get_baseline_df(None, None)

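# `download_from_hub` is star-imported from src.saving_utils and its real
# implementation is not shown in this file. The helper below is only an
# illustrative, unused sketch of the assumed behaviour; the repo_type and
# filename pattern are guesses, not the actual PROBE storage layout.
def _download_results_sketch(benchmark_types, repo=repo_id):
    """Hypothetical example: fetch per-benchmark result CSVs into /tmp."""
    from huggingface_hub import hf_hub_download

    for benchmark_type in benchmark_types:
        # Mirrors the /tmp/{benchmark_type}_results.csv paths used by refresh_data().
        hf_hub_download(
            repo_id=repo,
            repo_type="space",                          # assumption
            filename=f"{benchmark_type}_results.csv",   # assumed filename pattern
            local_dir="/tmp",
        )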

# ------- Leaderboard helpers -----------------------------------------------

def update_metrics(selected_benchmarks):
    """Return the union of metric names mapped to the selected benchmark types."""
    updated_metrics = set()
    for benchmark in selected_benchmarks:
        updated_metrics.update(benchmark_metric_mapping.get(benchmark, []))
    return list(updated_metrics)

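# Illustrative behaviour (actual column names come from the leaderboard CSVs at
# runtime, via the prefix mapping defined in the Leaderboard tab below):
#   update_metrics(["similarity"]) -> all metric columns starting with "sim_"
#   update_metrics([])             -> []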

def update_leaderboard(selected_methods, selected_metrics):
    """Filter the leaderboard table to the selected methods and metrics."""
    return get_baseline_df(selected_methods, selected_metrics)

# ------- Visualisation helpers ---------------------------------------------

def get_plot_explanation(benchmark_type, x_metric, y_metric, aspect, dataset, single_metric):
    """Return a short markdown explanation matching the selected plot type."""
    if benchmark_type == "similarity":
        return (
            f"Scatter plot compares models on **{x_metric}** (x‑axis) and **{y_metric}** (y‑axis). "
            "Upper‑right points indicate jointly strong performance."
        )
    if benchmark_type == "function":
        return (
            f"Heat‑map shows model scores for **{aspect.upper()}** terms with **{single_metric}**. "
            "Darker squares → better predictions."
        )
    if benchmark_type == "family":
        return (
            f"Box‑plots summarise cross‑fold MCC on **{dataset}**; higher medians are better."
        )
    if benchmark_type == "affinity":
        return (
            f"Box‑plots display distribution of **{single_metric}** scores for affinity prediction; higher values are better."
        )
    return ""


def generate_plot_and_explanation(benchmark_type, methods_selected, x_metric, y_metric, aspect, dataset, single_metric):
    """Render the requested benchmark plot and return it together with its explanation."""
    plot_path = benchmark_plot(
        benchmark_type,
        methods_selected,
        x_metric,
        y_metric,
        aspect,
        dataset,
        single_metric,
    )
    explanation = get_plot_explanation(benchmark_type, x_metric, y_metric, aspect, dataset, single_metric)
    # plot_explanation is created with visible=False, so reveal it while updating its text.
    return plot_path, gr.update(value=explanation, visible=True)

# ---------------------------------------------------------------------------
# Custom CSS for frozen first column and clearer table styles
# ---------------------------------------------------------------------------
CUSTOM_CSS = """
/* freeze first column */
#leaderboard-table table tr th:first-child,
#leaderboard-table table tr td:first-child {
  position: sticky;
  left: 0;
  background: white;
  z-index: 2;
}

/* striped rows for readability */
#leaderboard-table table tr:nth-child(odd) {
  background: #fafafa;
}

/* centre numeric cells */
#leaderboard-table td:not(:first-child) {
  text-align: center;
}

/* scrollable and taller table */
#leaderboard-table .dataframe-wrap {
  max-height: 1200px;
  overflow-y: auto;
  overflow-x: auto;
}
"""

# ---------------------------------------------------------------------------
# UI definition
# ---------------------------------------------------------------------------
block = gr.Blocks(css=CUSTOM_CSS)

with block:
    gr.Markdown(LEADERBOARD_INTRODUCTION)

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # ------------------------------------------------------------------
        # 1️⃣  Leaderboard tab
        # ------------------------------------------------------------------
        with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
            # workflow figure at the top of the leaderboard tab
            gr.Image(
                value="./src/data/PROBE_workflow_figure.jpg",
                show_label=False,
                height=1000,
                container=False,
            )

            gr.Markdown(
                "## For detailed explanations of the metrics and benchmarks, please refer to the 📝 About tab.",
                elem_classes="leaderboard-note",
            )

            leaderboard = get_baseline_df(None, None)
            method_names = leaderboard['Method'].unique().tolist()
            metric_names = leaderboard.columns.tolist()
            metric_names.remove('Method')

            benchmark_metric_mapping = {
                "similarity": [m for m in metric_names if m.startswith('sim_')],
                "function":  [m for m in metric_names if m.startswith('func')],
                "family":    [m for m in metric_names if m.startswith('fam_')],
                "affinity":  [m for m in metric_names if m.startswith('aff_')],
            }

            leaderboard_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select Methods",
                value=method_names,
                interactive=True,
            )

            benchmark_type_selector_lb = gr.CheckboxGroup(
                choices=list(benchmark_metric_mapping.keys()),
                label="Select Benchmark Types",
                value=None,
                interactive=True,
            )

            leaderboard_metric_selector = gr.CheckboxGroup(
                choices=metric_names,
                label="Select Metrics",
                value=None,
                interactive=True,
            )

            baseline_value = get_baseline_df(method_names, metric_names)
            baseline_value = baseline_value.applymap(lambda x: round(x, 4) if isinstance(x, (int, float)) else x)
            baseline_header = ["Method"] + metric_names
            baseline_datatype = ['markdown'] + ['number'] * len(metric_names)

            with gr.Row(show_progress=True, variant='panel'):
                data_component = gr.Dataframe(
                    value=baseline_value,
                    headers=baseline_header,
                    type="pandas",
                    datatype=baseline_datatype,
                    interactive=False,
                    elem_id="leaderboard-table",
                    pinned_columns=1,
                    max_height=1000,
                )

            # callbacks
            leaderboard_method_selector.change(
                get_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component,
            )
            benchmark_type_selector_lb.change(
                update_metrics,
                inputs=[benchmark_type_selector_lb],
                outputs=leaderboard_metric_selector,
            )
            leaderboard_metric_selector.change(
                get_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component,
            )

        # ------------------------------------------------------------------
        # 2️⃣  Visualisation tab
        # ------------------------------------------------------------------
        with gr.TabItem("📊 Visualizations", elem_id="probe-benchmark-tab-visualization", id=2):
            gr.Markdown(
                "## **Interactive Visualizations**\n"
                "Choose a benchmark type; context‑specific options will appear. "
                "Click **Plot** and an explanation will follow the figure.",
                elem_classes="markdown-text",
            )
            vis_benchmark_type_selector = gr.Dropdown(
                choices=list(benchmark_specific_metrics.keys()),
                label="Benchmark Type",
                value=None,
            )
            with gr.Row():
                vis_x_metric_selector = gr.Dropdown(choices=[], label="X‑axis Metric", visible=False)
                vis_y_metric_selector = gr.Dropdown(choices=[], label="Y‑axis Metric", visible=False)
                vis_aspect_type_selector = gr.Dropdown(choices=[], label="Aspect", visible=False)
                vis_dataset_selector = gr.Dropdown(choices=[], label="Dataset", visible=False)
                vis_single_metric_selector = gr.Dropdown(choices=[], label="Metric", visible=False)
            vis_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Methods",
                value=method_names,
                interactive=True,
            )
            plot_button = gr.Button("Plot")
            with gr.Row(show_progress=True, variant='panel'):
                plot_output = gr.Image(label="Plot")
            plot_explanation = gr.Markdown(visible=False)
            # callbacks
            vis_benchmark_type_selector.change(
                update_metric_choices,
                inputs=[vis_benchmark_type_selector],
                outputs=[
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
            )
            plot_button.click(
                generate_plot_and_explanation,
                inputs=[
                    vis_benchmark_type_selector,
                    vis_method_selector,
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
                outputs=[plot_output, plot_explanation],
            )

        # ------------------------------------------------------------------
        # 3️⃣  About tab
        # ------------------------------------------------------------------
        with gr.TabItem("📝 About", elem_id="probe-benchmark-tab-table", id=3):
            with gr.Row():
                gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Image(
                    value="./src/data/PROBE_workflow_figure.jpg",
                    label="PROBE Workflow Figure",
                    elem_classes="about-image",
                )

        # ------------------------------------------------------------------
        # 4️⃣  Submit tab
        # ------------------------------------------------------------------
        with gr.TabItem("🚀 Submit here! ", elem_id="probe-benchmark-tab-table", id=4):
            with gr.Row():
                gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")
            with gr.Row():
                with gr.Column():
                    model_name_textbox = gr.Textbox(label="Method name")
                    revision_name_textbox = gr.Textbox(label="Revision Method Name")
                    benchmark_types = gr.CheckboxGroup(choices=TASK_INFO, label="Benchmark Types", interactive=True)
                    similarity_tasks = gr.CheckboxGroup(choices=similarity_tasks_options, label="Similarity Tasks", interactive=True)
                    function_prediction_aspect = gr.Radio(choices=function_prediction_aspect_options, label="Function Prediction Aspects", interactive=True)
                    family_prediction_dataset = gr.CheckboxGroup(choices=family_prediction_dataset_options, label="Family Prediction Datasets", interactive=True)
                    function_dataset = gr.Textbox(label="Function Prediction Datasets", visible=False, value="All_Data_Sets")
                    save_checkbox = gr.Checkbox(label="Save results for leaderboard and visualization", value=True)
            with gr.Row():
                human_file = gr.File(label="Representation file (CSV) for Human dataset", file_count="single", type='filepath')
                skempi_file = gr.File(label="Representation file (CSV) for SKEMPI dataset", file_count="single", type='filepath')
            submit_button = gr.Button("Submit Eval")
            submission_result = gr.Markdown()
            submit_button.click(
                add_new_eval,
                inputs=[
                    human_file,
                    skempi_file,
                    model_name_textbox,
                    revision_name_textbox,
                    benchmark_types,
                    similarity_tasks,
                    function_prediction_aspect,
                    function_dataset,
                    family_prediction_dataset,
                    save_checkbox,
                ],
            )

    # global refresh + citation ---------------------------------------------
    with gr.Row():
        data_run = gr.Button("Refresh")
        data_run.click(refresh_data, outputs=[data_component])

    with gr.Accordion("Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            show_copy_button=True,
        )

# ---------------------------------------------------------------------------
block.launch()