Adding multiple models plotting and sidebar
Browse files- src/display.py +8 -6
- src/plot.py +14 -6
    	
        src/display.py
    CHANGED
    
    | @@ -12,14 +12,14 @@ def display_app(): | |
| 12 | 
             
                st.markdown("# Open LLM Leaderboard Viz")
         | 
| 13 | 
             
                st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
         | 
| 14 | 
             
                st.markdown("To select a model, click on the checkbox beside its name.")
         | 
| 15 | 
            -
                st.markdown("This displays the top 100 models by default, but you can change that using the number input  | 
| 16 | 
            -
                st.markdown("By  | 
| 17 | 
             
                st.markdown("If your model doesn't show up, please search it by its name.")
         | 
| 18 |  | 
| 19 | 
             
                dataframe = load_dataframe()
         | 
| 20 |  | 
| 21 | 
             
                sort_selection = st.selectbox(label = "Sort by:", options = list(dataframe.columns), index = 7)
         | 
| 22 | 
            -
                number_of_row = st.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)
         | 
| 23 | 
             
                ascending = True
         | 
| 24 |  | 
| 25 | 
             
                if sort_selection is None:
         | 
| @@ -32,6 +32,8 @@ def display_app(): | |
| 32 |  | 
| 33 |  | 
| 34 | 
             
                name = st.text_input(label = ":mag: Search by name")
         | 
|  | |
|  | |
| 35 | 
             
                len_name_input = len(name) 
         | 
| 36 | 
             
                if len_name_input > 0:
         | 
| 37 | 
             
                    dataframe_by_search = search_by_name(name)
         | 
| @@ -55,7 +57,7 @@ def display_app(): | |
| 55 |  | 
| 56 | 
             
                #Infer basic colDefs from dataframe types
         | 
| 57 | 
             
                gb = GridOptionsBuilder.from_dataframe(dataframe_display)
         | 
| 58 | 
            -
                gb.configure_selection(selection_mode =  | 
| 59 | 
             
                gb.configure_grid_options(domLayout='normal')
         | 
| 60 | 
             
                gridOptions = gb.build()
         | 
| 61 |  | 
| @@ -77,9 +79,9 @@ def display_app(): | |
| 77 |  | 
| 78 | 
             
                with column2:
         | 
| 79 | 
             
                    if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
         | 
| 80 | 
            -
                        figure = plot_radar_chart_rows(rows=grid_response['selected_rows'])
         | 
| 81 | 
             
                        #figure = plot_radar_chart_name(dataframe= dataframe, model_name=grid_response['selected_rows'][0]["model_name"])
         | 
| 82 | 
            -
                        st.plotly_chart(figure, use_container_width= | 
| 83 | 
             
                    else:
         | 
| 84 | 
             
                        if len(subdata)>0:
         | 
| 85 | 
             
                            figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
         | 
|  | |
| 12 | 
             
                st.markdown("# Open LLM Leaderboard Viz")
         | 
| 13 | 
             
                st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
         | 
| 14 | 
             
                st.markdown("To select a model, click on the checkbox beside its name.")
         | 
| 15 | 
            +
                st.markdown("This displays the top 100 models by default, but you can change that using the number input in the sidebar.") 
         | 
| 16 | 
            +
                st.markdown("By default as well, the maximum number of row you can display is 500, it is due to the problem with st_aggrid component loading.")
         | 
| 17 | 
             
                st.markdown("If your model doesn't show up, please search it by its name.")
         | 
| 18 |  | 
| 19 | 
             
                dataframe = load_dataframe()
         | 
| 20 |  | 
| 21 | 
             
                sort_selection = st.selectbox(label = "Sort by:", options = list(dataframe.columns), index = 7)
         | 
| 22 | 
            +
                number_of_row = st.sidebar.number_input("Number of top rows to display", min_value=100, max_value=500, value="min", step=100)
         | 
| 23 | 
             
                ascending = True
         | 
| 24 |  | 
| 25 | 
             
                if sort_selection is None:
         | 
|  | |
| 32 |  | 
| 33 |  | 
| 34 | 
             
                name = st.text_input(label = ":mag: Search by name")
         | 
| 35 | 
            +
                selection_mode = st.sidebar.radio(label= "Selection mode for the rows", options = ["single", "multiple"], index=0)
         | 
| 36 | 
            +
                st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
         | 
| 37 | 
             
                len_name_input = len(name) 
         | 
| 38 | 
             
                if len_name_input > 0:
         | 
| 39 | 
             
                    dataframe_by_search = search_by_name(name)
         | 
|  | |
| 57 |  | 
| 58 | 
             
                #Infer basic colDefs from dataframe types
         | 
| 59 | 
             
                gb = GridOptionsBuilder.from_dataframe(dataframe_display)
         | 
| 60 | 
            +
                gb.configure_selection(selection_mode = selection_mode, use_checkbox=True)
         | 
| 61 | 
             
                gb.configure_grid_options(domLayout='normal')
         | 
| 62 | 
             
                gridOptions = gb.build()
         | 
| 63 |  | 
|  | |
| 79 |  | 
| 80 | 
             
                with column2:
         | 
| 81 | 
             
                    if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
         | 
| 82 | 
            +
                        figure = plot_radar_chart_rows(rows=grid_response['selected_rows'][:3])
         | 
| 83 | 
             
                        #figure = plot_radar_chart_name(dataframe= dataframe, model_name=grid_response['selected_rows'][0]["model_name"])
         | 
| 84 | 
            +
                        st.plotly_chart(figure, use_container_width=False)
         | 
| 85 | 
             
                    else:
         | 
| 86 | 
             
                        if len(subdata)>0:
         | 
| 87 | 
             
                            figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
         | 
    	
        src/plot.py
    CHANGED
    
    | @@ -7,6 +7,8 @@ import pandas as pd | |
| 7 | 
             
            fillcolor = "#FFD21E"
         | 
| 8 | 
             
            line_color = "#FF9D00"
         | 
| 9 |  | 
|  | |
|  | |
| 10 | 
             
            # opacity of the plot
         | 
| 11 | 
             
            opacity = 0.75
         | 
| 12 |  | 
| @@ -109,7 +111,7 @@ def plot_radar_chart_name(dataframe: pd.DataFrame, model_name: str, categories: | |
| 109 |  | 
| 110 |  | 
| 111 | 
             
            #@st.cache_data
         | 
| 112 | 
            -
            def plot_radar_chart_rows(rows: object, columns:list = columns, categories: list = categories,  | 
| 113 | 
             
                """
         | 
| 114 | 
             
                plot the results of the model selected by the checkbox
         | 
| 115 |  | 
| @@ -124,22 +126,27 @@ def plot_radar_chart_rows(rows: object, columns:list = columns, categories: list | |
| 124 | 
             
                dataset = pd.DataFrame(rows, columns=columns)
         | 
| 125 | 
             
                data = dataset[categories].to_numpy()
         | 
| 126 | 
             
                data  = data.astype(float)
         | 
|  | |
|  | |
|  | |
|  | |
| 127 |  | 
| 128 | 
             
                # add data to close the area of the radar chart
         | 
| 129 | 
             
                data = np.append(data, data[:,0].reshape((-1,1)), axis=1)
         | 
| 130 | 
             
                categories_theta = categories.copy()
         | 
| 131 | 
             
                categories_theta.append(categories[0])
         | 
| 132 |  | 
| 133 | 
            -
                 | 
| 134 | 
             
                for i in range(len(dataset)):
         | 
| 135 | 
            -
             | 
|  | |
| 136 | 
             
                  fig.add_trace(go.Scatterpolar(
         | 
| 137 | 
             
                        r=data[i,:],
         | 
| 138 | 
             
                        theta=categories_theta,
         | 
| 139 | 
             
                        fill='toself',
         | 
| 140 | 
            -
                        fillcolor =  | 
| 141 | 
             
                        opacity = opacity,
         | 
| 142 | 
            -
                        line=dict(color =  | 
| 143 | 
             
                        name= dataset.loc[i,"model_name"]
         | 
| 144 | 
             
                  ))
         | 
| 145 | 
             
                  fig.update_layout(
         | 
| @@ -148,7 +155,8 @@ def plot_radar_chart_rows(rows: object, columns:list = columns, categories: list | |
| 148 | 
             
                        visible=True,
         | 
| 149 | 
             
                        range=[0, 100.]
         | 
| 150 | 
             
                      )),
         | 
| 151 | 
            -
                    showlegend= | 
| 152 | 
             
                  )
         | 
|  | |
| 153 |  | 
| 154 | 
             
                return fig
         | 
|  | |
| 7 | 
             
            fillcolor = "#FFD21E"
         | 
| 8 | 
             
            line_color = "#FF9D00"
         | 
| 9 |  | 
| 10 | 
            +
            fill_color_list = [fillcolor, "#F05998", "#40BAF0"]
         | 
| 11 | 
            +
            line_color_list = [line_color, "#5E233C", "#194A5E"]
         | 
| 12 | 
             
            # opacity of the plot
         | 
| 13 | 
             
            opacity = 0.75
         | 
| 14 |  | 
|  | |
| 111 |  | 
| 112 |  | 
| 113 | 
             
            #@st.cache_data
         | 
| 114 | 
            +
            def plot_radar_chart_rows(rows: object, columns:list = columns, categories: list = categories, fillcolor_list: str = fill_color_list, line_color_list:str = line_color_list):
         | 
| 115 | 
             
                """
         | 
| 116 | 
             
                plot the results of the model selected by the checkbox
         | 
| 117 |  | 
|  | |
| 126 | 
             
                dataset = pd.DataFrame(rows, columns=columns)
         | 
| 127 | 
             
                data = dataset[categories].to_numpy()
         | 
| 128 | 
             
                data  = data.astype(float)
         | 
| 129 | 
            +
                showLegend = False
         | 
| 130 | 
            +
                if len(rows) > 1:
         | 
| 131 | 
            +
                    showLegend = True
         | 
| 132 | 
            +
             | 
| 133 |  | 
| 134 | 
             
                # add data to close the area of the radar chart
         | 
| 135 | 
             
                data = np.append(data, data[:,0].reshape((-1,1)), axis=1)
         | 
| 136 | 
             
                categories_theta = categories.copy()
         | 
| 137 | 
             
                categories_theta.append(categories[0])
         | 
| 138 |  | 
| 139 | 
            +
                opacity = 0.75
         | 
| 140 | 
             
                for i in range(len(dataset)):
         | 
| 141 | 
            +
                  colors = fillcolor_list[i]
         | 
| 142 | 
            +
                  
         | 
| 143 | 
             
                  fig.add_trace(go.Scatterpolar(
         | 
| 144 | 
             
                        r=data[i,:],
         | 
| 145 | 
             
                        theta=categories_theta,
         | 
| 146 | 
             
                        fill='toself',
         | 
| 147 | 
            +
                        fillcolor = colors,
         | 
| 148 | 
             
                        opacity = opacity,
         | 
| 149 | 
            +
                        line=dict(color = line_color_list[i]),
         | 
| 150 | 
             
                        name= dataset.loc[i,"model_name"]
         | 
| 151 | 
             
                  ))
         | 
| 152 | 
             
                  fig.update_layout(
         | 
|  | |
| 155 | 
             
                        visible=True,
         | 
| 156 | 
             
                        range=[0, 100.]
         | 
| 157 | 
             
                      )),
         | 
| 158 | 
            +
                    showlegend=showLegend
         | 
| 159 | 
             
                  )
         | 
| 160 | 
            +
                  opacity -= .2
         | 
| 161 |  | 
| 162 | 
             
                return fig
         | 
