Spaces:
Sleeping
Sleeping
File size: 6,339 Bytes
ff0c7de fbcd930 1d040cb ff0c7de 963c6da 1d040cb fbcd930 adbb181 12f938b fbcd930 12f938b adbb181 fbcd930 12f938b fbcd930 12f938b fbcd930 ff0c7de adbb181 ff0c7de 12f938b fbcd930 12f938b fbcd930 adbb181 fbcd930 ff0c7de adbb181 ff0c7de fbcd930 ff0c7de fbcd930 1f586be 174296d fbcd930 174296d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from st_aggrid import GridOptionsBuilder, AgGrid
import streamlit as st
from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name, validate_categories
from .plot import plot_radar_chart_name, plot_radar_chart_rows
# Score columns shown in the grid; they are scaled by 100 and rounded to two
# decimals for display only (the underlying dataframe is left untouched).
_SCORE_COLUMNS = ["ARC", "HellaSwag", "TruthfulQA", "Winogrande", "GSM8K", "MMLU", "Average"]

# Maximum number of selected models that are plotted and listed by name.
_MAX_PLOTTED_MODELS = 3

_METRICS_REMINDER = """
As a reminder, here are the different metrics:
* ARC
* GSM8K
* TruthfulQA
* Winogrande
* HellaSwag
* MMLU
"""

_ORDERING_WARNING = """
If there are **typos** in the name of the metrics, or the number of metrics
is **different of six**, there will be no effect on the chart and the
default ordering will be used.
"""


def _format_scores(frame):
    """Return a copy of *frame* with the score columns cast to float, scaled by 100 and rounded to 2 decimals.

    Working on an explicit copy keeps the caller's frame unchanged and avoids
    assigning into a slice returned by a truncation helper.
    """
    formatted = frame.copy()
    formatted[_SCORE_COLUMNS] = (formatted[_SCORE_COLUMNS].astype(float) * 100).round(2)
    return formatted


def display_app():
    """Render the Open LLM Leaderboard Viz page.

    Draws the intro text, the sort / search / sidebar controls, an AgGrid
    table of models, and a radar chart of either the selected rows (up to
    three) or, when nothing is selected, the first row of the sorted table.
    """
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name.")
    st.markdown("This displays the top 100 models by default, but you can change that using the number input in the sidebar.")
    st.markdown("By default as well, the maximum number of row you can display is 500, it is due to the problem with st_aggrid component loading.")
    st.markdown("If your model doesn't show up, please search it by its name.")

    dataframe = load_dataframe()
    sort_selection = st.selectbox(label="Sort by:", options=list(dataframe.columns), index=7)
    number_of_row = st.sidebar.number_input(
        "Number of top rows to display", min_value=100, max_value=500, value="min", step=100
    )
    if sort_selection is None:
        sort_selection = "model_name"
    # Names sort alphabetically; every other (score) column sorts descending.
    ascending = sort_selection == "model_name"

    name = st.text_input(label=":mag: Search by name")

    # Sidebar configuration
    selection_mode = st.sidebar.radio(
        label="Selection mode for the rows", options=["single", "multiple"], index=0
    )
    st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
    raw_ordering = st.sidebar.text_input(
        label="Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.",
        placeholder="ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU",
    )
    ordering_metrics = raw_ordering.replace(" ", "").split(",")
    st.sidebar.markdown(_METRICS_REMINDER)
    st.sidebar.markdown(_ORDERING_WARNING)
    valid_categories = validate_categories(ordering_metrics)

    # Search bar: use the matches if there are any, otherwise keep the table
    # loaded above (no need to load it a second time).
    showing_search_results = False
    if name:
        matches = search_by_name(name)
        if len(matches) > 0:
            dataframe = matches
            showing_search_results = True
        else:
            st.warning(f"No model name matches '{name}'; showing the top rows instead.")
    dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending=ascending)

    if showing_search_results:
        dataframe_display = dataframe
    else:
        # Cap the grid at the sidebar row count. This also applies when a
        # search found nothing, so the fallback never sends the whole table
        # to st_aggrid (the intro text notes it struggles beyond 500 rows).
        dataframe_display = show_dataframe_top(number_of_row, dataframe.copy())
    dataframe_display = _format_scores(dataframe_display)

    # Infer basic colDefs from dataframe types
    gb = GridOptionsBuilder.from_dataframe(dataframe_display)
    gb.configure_selection(selection_mode=selection_mode, use_checkbox=True)
    gb.configure_grid_options(domLayout='normal')
    grid_options = gb.build()

    # The middle column only acts as a horizontal spacer.
    column1, _spacer, column2 = st.columns([0.26, 0.05, 0.69], gap="small")
    with column1:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=grid_options,
            height=300,
            width='40%',
        )

    # Fallback shown when no row is selected: first row of the sorted,
    # unscaled table.
    subdata = dataframe.head(1)
    model_name = subdata["model_name"].values[0] if len(subdata) > 0 else ""

    selected_rows = grid_response['selected_rows']
    n_selected = len(selected_rows) if selected_rows is not None else 0
    # Only pass a custom metric order to the plot helpers when it validated.
    category_kwargs = {"categories": ordering_metrics} if valid_categories else {}

    with column2:
        if n_selected > 0:
            figure = plot_radar_chart_rows(rows=selected_rows[:_MAX_PLOTTED_MODELS], **category_kwargs)
            st.plotly_chart(figure, use_container_width=False)
        elif len(subdata) > 0:
            figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name, **category_kwargs)
            st.plotly_chart(figure, use_container_width=True)

    # NOTE(review): the model-name summary is rendered at page level below
    # the columns; confirm this matches the intended layout.
    n_listed = min(n_selected, _MAX_PLOTTED_MODELS)
    if n_listed > 1:
        st.markdown("## Models")
        for i, column in enumerate(st.columns(n_listed)):
            with column:
                st.markdown(f"**Model name:** {selected_rows[i]['model_name']}")
    elif n_listed == 1:
        st.markdown(f"**Model name:** {selected_rows[0]['model_name']}")
    else:
        st.markdown(f"**Model name:** {model_name}")