File size: 3,471 Bytes
0d2e03d
fc4805a
61fd7c8
fc4805a
 
 
61fd7c8
fc4805a
 
 
 
 
 
 
 
 
 
 
 
 
 
61fd7c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc4805a
0d2e03d
 
fc4805a
a86b221
fc4805a
61fd7c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc4805a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt

def load_data(version, base_dir="versions"):
    """Load the leaderboard table for one base model.

    Args:
        version: Base-model name; the table is read from
            ``<base_dir>/<version>.csv``.
        base_dir: Directory holding the per-model CSV files. A relative path
            is resolved against the current working directory, exactly as
            ``pd.read_csv`` resolves it. Defaults to ``"versions"``.

    Returns:
        The CSV contents as a ``pandas.DataFrame``.

    Raises:
        FileNotFoundError: If no CSV exists for ``version`` in ``base_dir``.
    """
    file_path = f"{base_dir}/{version}.csv"
    return pd.read_csv(file_path)

def search_leaderboard(df, query):
    """Return the rows of *df* whose ``Method`` contains *query*.

    The query is matched as a literal substring, not as a regular expression,
    so arbitrary search-box input such as ``"("`` or ``"+"`` cannot raise.
    Matching is case-sensitive. Rows with a missing ``Method`` never match.

    Args:
        df: Leaderboard table with a ``Method`` column.
        query: Text typed into the search box; empty or ``None`` means
            "no filter".

    Returns:
        *df* itself when there is no query, otherwise the matching rows.
    """
    if not query:
        return df
    # regex=False: user text is not a pattern; na=False: missing names are
    # treated as non-matches instead of producing an unusable NA mask.
    mask = df["Method"].str.contains(query, regex=False, na=False)
    return df[mask]

def change_version(version):
    """Return the freshly loaded leaderboard for the selected base model."""
    return load_data(version)

def create_plots(df, selected_methods):
    """Plot the ``PPL`` column of each selected method.

    Args:
        df: Leaderboard table with ``Method`` and ``PPL`` columns.
        selected_methods: Method names to draw. An empty or ``None`` selection
            yields a blank figure.

    Returns:
        A ``matplotlib.figure.Figure`` (what ``gr.Plot`` expects). Each method
        is drawn against its row index in *df*.
    """
    # Build the Figure directly rather than through pyplot: pyplot keeps every
    # figure it creates registered until plt.close(), so a server that
    # re-plots on each checkbox change would otherwise accumulate figures.
    from matplotlib.figure import Figure

    fig = Figure()
    if not selected_methods:
        return fig  # blank plot when no method is selected

    ax = fig.subplots()
    for method in selected_methods:
        method_df = df[df["Method"] == method]
        # A marker keeps single-row methods visible; a one-point line draws nothing.
        ax.plot(method_df["PPL"], marker="o", label=method)

    ax.set_xlabel("Index")
    ax.set_ylabel("PPL")
    ax.legend()
    return fig

# Base models offered in both tabs; each must have a matching CSV for load_data.
BASE_MODELS = ["llama", "phi", "stable-lm"]
DEFAULT_BASE_MODEL = "llama"

# Initialize Gradio app
demo = gr.Blocks()

with demo:
    gr.Markdown("## 🥇 TOFU Leaderboard")

    # Loaded once at build time to seed the table and the method checkboxes.
    initial_df = load_data(DEFAULT_BASE_MODEL)

    with gr.Tabs():
        with gr.TabItem("Leaderboard"):
            with gr.Row():
                version_dropdown = gr.Dropdown(
                    choices=BASE_MODELS,
                    label="🔄 Select Base Model",
                    value=DEFAULT_BASE_MODEL,
                )

            with gr.Row():
                search_bar = gr.Textbox(
                    placeholder="Search for methods...",
                    show_label=False,
                )

            leaderboard_table = gr.components.Dataframe(
                value=initial_df,
                interactive=True,
                visible=True,
            )

            def refresh_leaderboard(version, query):
                """Reload the selected model's table and apply the current search."""
                return search_leaderboard(load_data(version), query)

            # Both controls recompute from the source CSV instead of reading
            # the displayed table back, so filters never stack and clearing
            # the search box restores every row.
            version_dropdown.change(
                refresh_leaderboard,
                inputs=[version_dropdown, search_bar],
                outputs=leaderboard_table,
            )

            search_bar.change(
                refresh_leaderboard,
                inputs=[version_dropdown, search_bar],
                outputs=leaderboard_table,
            )

        with gr.TabItem("Plots"):
            version_dropdown_plots = gr.Dropdown(
                choices=BASE_MODELS,
                label="🔄 Select Base Model",
                value=DEFAULT_BASE_MODEL,
            )

            with gr.Row():
                methods_checkbox = gr.CheckboxGroup(
                    label="Select Methods",
                    choices=initial_df["Method"].dropna().unique().tolist(),
                )

            plot_output = gr.Plot()

            def update_method_choices(version):
                """Swap the checkbox choices to the new model's methods and clear the plot."""
                df = load_data(version)
                methods = df["Method"].dropna().unique().tolist()
                # Component updates must be *returned* to Gradio; calling
                # methods_checkbox.update() inside a callback does nothing.
                return gr.update(choices=methods, value=[]), create_plots(df, [])

            version_dropdown_plots.change(
                update_method_choices,
                inputs=version_dropdown_plots,
                outputs=[methods_checkbox, plot_output],
            )

            def plot_selected_methods(version, selected_methods):
                """Plot the selected methods using this tab's base model."""
                return create_plots(load_data(version), selected_methods)

            methods_checkbox.change(
                plot_selected_methods,
                inputs=[version_dropdown_plots, methods_checkbox],
                outputs=plot_output,
            )

# Launch the app
if __name__ == "__main__":
    demo.launch()