File size: 6,339 Bytes
ff0c7de
fbcd930
1d040cb
ff0c7de
963c6da
1d040cb
fbcd930
 
 
 
 
adbb181
 
12f938b
fbcd930
 
 
12f938b
adbb181
fbcd930
12f938b
fbcd930
 
 
 
 
 
 
12f938b
 
fbcd930
ff0c7de
 
adbb181
 
ff0c7de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12f938b
 
 
 
 
 
fbcd930
 
 
 
 
12f938b
 
 
 
 
 
fbcd930
 
 
 
 
 
adbb181
fbcd930
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff0c7de
 
 
 
 
 
adbb181
ff0c7de
fbcd930
 
ff0c7de
 
 
 
 
 
fbcd930
 
1f586be
 
 
 
 
 
 
174296d
 
fbcd930
174296d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133

from st_aggrid import GridOptionsBuilder, AgGrid
import streamlit as st
from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name, validate_categories
from .plot import plot_radar_chart_name, plot_radar_chart_rows


# Benchmark score columns shown in the grid; converted to float, scaled by 100
# and rounded to two decimals before display.
_SCORE_COLUMNS = ["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K", "MMLU", "Average"]

# Only this many selected models are plotted / listed (matches the sidebar text).
_MAX_SELECTED_MODELS = 3


def _selected_row_count(selected_rows):
    """Return how many grid rows are selected (0 when the grid reports None)."""
    return 0 if selected_rows is None else len(selected_rows)


def _selected_model_name(selected_rows, position):
    """Return the ``model_name`` of the selected row at ``position``.

    NOTE(review): ``selected_rows`` is a list of row dicts in older st_aggrid
    releases and presumably a pandas DataFrame in newer ones; plain ``[i]``
    indexing on a DataFrame looks up a *column*, so use ``.iloc`` when present.
    """
    if hasattr(selected_rows, "iloc"):
        return selected_rows.iloc[position]["model_name"]
    return selected_rows[position]["model_name"]


def display_app():
    """Render the Open LLM Leaderboard visualization page.

    Builds the Streamlit UI in order: introduction text, a "Sort by" selector,
    a name search box, sidebar options (number of rows, selection mode, radar
    chart metric order), an AgGrid table of models, and a radar chart of the
    selected model(s) — or of the first model in the table when none is selected.
    """
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name.")
    st.markdown("This displays the top 100 models by default, but you can change that using the number input in the sidebar.")
    st.markdown("By default as well, the maximum number of row you can display is 500, it is due to the problem with st_aggrid component loading.")
    st.markdown("If your model doesn't show up, please search it by its name.")

    dataframe = load_dataframe()

    sort_selection = st.selectbox(label="Sort by:", options=list(dataframe.columns), index=7)
    number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)

    if sort_selection is None:
        sort_selection = "model_name"
    # Model names sort A->Z; every other column is sorted descending (best first).
    ascending = sort_selection == "model_name"

    name = st.text_input(label=":mag: Search by name")

    # Sidebar configurations
    selection_mode = st.sidebar.radio(label="Selection mode for the rows", options=["single", "multiple"], index=0)
    st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
    ordering_metrics = st.sidebar.text_input(label="Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.",
                                             placeholder="ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU")
    ordering_metrics = ordering_metrics.replace(" ", "").split(",")

    st.sidebar.markdown("""
                        As a reminder, here are the different metrics:
                        * ARC
                        * GSM8K
                        * TruthfulQA
                        * Winogrande
                        * HellaSwag
                        * MMLU
                        """)
    st.sidebar.markdown("""
                        If there are **typos** in the name of the metrics, or the number of metrics 
                        is **different of six**, there will be no effect on the chart and the 
                        default ordering will be used.
                         """)

    valid_categories = validate_categories(ordering_metrics)

    # Search bar: search results replace the table only when something matched.
    # On a miss, keep the already-loaded leaderboard (no reload) and fall through
    # to the top-N cap below, so the grid never receives the whole leaderboard.
    showing_search_results = False
    if name:
        dataframe_by_search = search_by_name(name)
        if len(dataframe_by_search) > 0:
            dataframe = dataframe_by_search
            showing_search_results = True

    dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending=ascending)
    dataframe_display = dataframe.copy()

    if not showing_search_results:
        # Only show the top n rows (st_aggrid struggles with large tables).
        dataframe_display = show_dataframe_top(number_of_row, dataframe_display)

    # Convert the displayed rows' own scores (no dependence on index alignment
    # with the untruncated frame) to percentages with two decimals.
    dataframe_display[_SCORE_COLUMNS] = (dataframe_display[_SCORE_COLUMNS].astype(float) * 100).round(2)

    # Infer basic colDefs from dataframe types
    grid_builder = GridOptionsBuilder.from_dataframe(dataframe_display)
    grid_builder.configure_selection(selection_mode=selection_mode, use_checkbox=True)
    grid_builder.configure_grid_options(domLayout='normal')
    grid_options = grid_builder.build()

    grid_column, _spacer, chart_column = st.columns([0.26, 0.05, 0.69], gap="small")

    with grid_column:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=grid_options,
            height=300,
            width='40%'
        )

    # Default model (first row of the sorted data) used when nothing is selected.
    subdata = dataframe.head(1)
    model_name = subdata["model_name"].values[0] if len(subdata) > 0 else ""

    selected_rows = grid_response['selected_rows']
    n_selected = _selected_row_count(selected_rows)
    # Custom metric ordering is forwarded only when it passed validation.
    chart_kwargs = {"categories": ordering_metrics} if valid_categories else {}

    with chart_column:
        if n_selected > 0:
            figure = plot_radar_chart_rows(rows=selected_rows[:_MAX_SELECTED_MODELS], **chart_kwargs)
            st.plotly_chart(figure, use_container_width=False)
        elif len(subdata) > 0:
            figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name, **chart_kwargs)
            st.plotly_chart(figure, use_container_width=True)

    if n_selected > 1:
        n_col = min(n_selected, _MAX_SELECTED_MODELS)
        st.markdown("## Models")
        for position, column in enumerate(st.columns(n_col)):
            with column:
                st.markdown("**Model name:**   %s" % _selected_model_name(selected_rows, position))
    elif n_selected == 1:
        st.markdown("**Model name:**   %s" % _selected_model_name(selected_rows, 0))
    else:
        st.markdown("**Model name:**   %s" % model_name)