DontPlanToEnd commited on
Commit
fd2eab2
Β·
verified Β·
1 Parent(s): 1701655

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -28
app.py CHANGED
@@ -109,32 +109,36 @@ def load_leaderboard_data(csv_file_path):
109
  return pd.DataFrame(columns=UGI_COLS + WRITING_STYLE_COLS + ANIME_RATING_COLS + ADDITIONAL_COLS)
110
 
111
  # Update the leaderboard table based on the search query and parameter range filters
112
- def update_table(df: pd.DataFrame, query: str, param_ranges: list, columns: list, w10_range: tuple, additional_cols: list) -> pd.DataFrame:
113
  filtered_df = df.copy()
 
 
114
  if param_ranges:
115
- param_mask = pd.Series(False, index=filtered_df.index)
116
  for param_range in param_ranges:
117
  if param_range == '~2':
118
- param_mask |= (filtered_df['Total Params'] < 2.5)
119
  elif param_range == '~4':
120
- param_mask |= ((filtered_df['Total Params'] >= 2.5) & (filtered_df['Total Params'] < 6))
121
  elif param_range == '~8':
122
- param_mask |= ((filtered_df['Total Params'] >= 6) & (filtered_df['Total Params'] < 9.5))
123
  elif param_range == '~13':
124
- param_mask |= ((filtered_df['Total Params'] >= 9.5) & (filtered_df['Total Params'] < 16))
125
  elif param_range == '~20':
126
- param_mask |= ((filtered_df['Total Params'] >= 16) & (filtered_df['Total Params'] < 28))
127
  elif param_range == '~34':
128
- param_mask |= ((filtered_df['Total Params'] >= 28) & (filtered_df['Total Params'] < 40))
129
  elif param_range == '~50':
130
- param_mask |= ((filtered_df['Total Params'] >= 40) & (filtered_df['Total Params'] < 65))
131
  elif param_range == '~70+':
132
- param_mask |= (filtered_df['Total Params'] >= 65)
133
  elif param_range == 'Closed':
134
- param_mask |= filtered_df['Total Params'].isna()
135
- elif param_range == 'Foundation':
136
- param_mask |= (filtered_df['Foundation'] == 1)
137
- filtered_df = filtered_df[param_mask]
 
 
138
 
139
  if query:
140
  filtered_df = filtered_df[filtered_df['Model'].str.contains(query, case=False, na=False)]
@@ -179,15 +183,22 @@ with GraInter:
179
  with gr.Row():
180
  search_bar = gr.Textbox(placeholder=" πŸ” Search for a model...", show_label=False, elem_id="search-bar")
181
  with gr.Row():
182
- with gr.Column(scale=5):
183
  filter_columns_size = gr.CheckboxGroup(
184
  label="Model sizes (in billions of parameters)",
185
- choices=['~2', '~4', '~8', '~13', '~20', '~34', '~50', '~70+', 'Closed', 'Foundation'],
186
  value=[],
187
  interactive=True,
188
  elem_id="filter-columns-size",
189
  )
190
- with gr.Column(scale=2):
 
 
 
 
 
 
 
191
  w10_range = RangeSlider(minimum=0, maximum=10, value=(0, 10), step=0.1, label="W/10 Range")
192
  with gr.Row():
193
  additional_columns = gr.CheckboxGroup(
@@ -325,42 +336,56 @@ with GraInter:
325
  **NA:** When models either reply with one number for every anime, give ratings not between 1 and 10, or don't give every anime in the list a rating.
326
  """)
327
 
328
- def update_all_tables(query, param_ranges, w10_range, additional_cols):
329
- ugi_table = update_table(leaderboard_df, query, param_ranges, UGI_COLS, w10_range, additional_cols)
330
 
331
  ws_df = leaderboard_df.sort_values(by='Reg+MyScore πŸ†', ascending=False)
332
- ws_table = update_table(ws_df, query, param_ranges, WRITING_STYLE_COLS, w10_range, additional_cols)
333
 
334
  arp_df = leaderboard_df.sort_values(by='Score πŸ†', ascending=False)
335
  arp_df_na = arp_df[arp_df[['Dif', 'Cor']].isna().any(axis=1)]
336
  arp_df = arp_df[~arp_df[['Dif', 'Cor']].isna().any(axis=1)]
337
 
338
- arp_table = update_table(arp_df, query, param_ranges, ANIME_RATING_COLS, w10_range, additional_cols)
339
- arp_na_table = update_table(arp_df_na, query, param_ranges, ANIME_RATING_COLS, w10_range, additional_cols).fillna('NA')
340
 
341
  return ugi_table, ws_table, arp_table, arp_na_table
342
 
 
 
 
 
 
 
 
 
343
  search_bar.change(
344
  fn=update_all_tables,
345
- inputs=[search_bar, filter_columns_size, w10_range, additional_columns],
346
  outputs=[leaderboard_table_ugi, leaderboard_table_ws, leaderboard_table_arp, leaderboard_table_arp_na]
347
  )
348
 
349
  filter_columns_size.change(
350
  fn=update_all_tables,
351
- inputs=[search_bar, filter_columns_size, w10_range, additional_columns],
352
  outputs=[leaderboard_table_ugi, leaderboard_table_ws, leaderboard_table_arp, leaderboard_table_arp_na]
353
  )
354
-
 
 
 
 
 
 
355
  w10_range.change(
356
  fn=update_all_tables,
357
- inputs=[search_bar, filter_columns_size, w10_range, additional_columns],
358
  outputs=[leaderboard_table_ugi, leaderboard_table_ws, leaderboard_table_arp, leaderboard_table_arp_na]
359
  )
360
-
361
  additional_columns.change(
362
  fn=update_all_tables,
363
- inputs=[search_bar, filter_columns_size, w10_range, additional_columns],
364
  outputs=[leaderboard_table_ugi, leaderboard_table_ws, leaderboard_table_arp, leaderboard_table_arp_na]
365
  )
366
 
 
109
  return pd.DataFrame(columns=UGI_COLS + WRITING_STYLE_COLS + ANIME_RATING_COLS + ADDITIONAL_COLS)
110
 
111
  # Update the leaderboard table based on the search query and parameter range filters
112
+ def update_table(df: pd.DataFrame, query: str, param_ranges: list, is_foundation: bool, columns: list, w10_range: tuple, additional_cols: list) -> pd.DataFrame:
113
  filtered_df = df.copy()
114
+
115
+ # Apply model size filter
116
  if param_ranges:
117
+ size_mask = pd.Series(False, index=filtered_df.index)
118
  for param_range in param_ranges:
119
  if param_range == '~2':
120
+ size_mask |= (filtered_df['Total Params'] < 2.5)
121
  elif param_range == '~4':
122
+ size_mask |= ((filtered_df['Total Params'] >= 2.5) & (filtered_df['Total Params'] < 6))
123
  elif param_range == '~8':
124
+ size_mask |= ((filtered_df['Total Params'] >= 6) & (filtered_df['Total Params'] < 9.5))
125
  elif param_range == '~13':
126
+ size_mask |= ((filtered_df['Total Params'] >= 9.5) & (filtered_df['Total Params'] < 16))
127
  elif param_range == '~20':
128
+ size_mask |= ((filtered_df['Total Params'] >= 16) & (filtered_df['Total Params'] < 28))
129
  elif param_range == '~34':
130
+ size_mask |= ((filtered_df['Total Params'] >= 28) & (filtered_df['Total Params'] < 40))
131
  elif param_range == '~50':
132
+ size_mask |= ((filtered_df['Total Params'] >= 40) & (filtered_df['Total Params'] < 65))
133
  elif param_range == '~70+':
134
+ size_mask |= (filtered_df['Total Params'] >= 65)
135
  elif param_range == 'Closed':
136
+ size_mask |= filtered_df['Total Params'].isna()
137
+ filtered_df = filtered_df[size_mask]
138
+
139
+ # Apply foundation model filter
140
+ if is_foundation:
141
+ filtered_df = filtered_df[filtered_df['Foundation'] == 1]
142
 
143
  if query:
144
  filtered_df = filtered_df[filtered_df['Model'].str.contains(query, case=False, na=False)]
 
183
  with gr.Row():
184
  search_bar = gr.Textbox(placeholder=" πŸ” Search for a model...", show_label=False, elem_id="search-bar")
185
  with gr.Row():
186
+ with gr.Column(scale=7):
187
  filter_columns_size = gr.CheckboxGroup(
188
  label="Model sizes (in billions of parameters)",
189
+ choices=['~2', '~4', '~8', '~13', '~20', '~34', '~50', '~70+', 'Closed'],
190
  value=[],
191
  interactive=True,
192
  elem_id="filter-columns-size",
193
  )
194
+ with gr.Column(min_width=200, scale=0):
195
+ model_type = gr.Checkbox(
196
+ label="Foundation Models Only",
197
+ value=False,
198
+ interactive=True,
199
+ elem_id="model-type",
200
+ )
201
+ with gr.Column(scale=3):
202
  w10_range = RangeSlider(minimum=0, maximum=10, value=(0, 10), step=0.1, label="W/10 Range")
203
  with gr.Row():
204
  additional_columns = gr.CheckboxGroup(
 
336
  **NA:** When models either reply with one number for every anime, give ratings not between 1 and 10, or don't give every anime in the list a rating.
337
  """)
338
 
339
+ def update_all_tables(query, param_ranges, is_foundation, w10_range, additional_cols):
340
+ ugi_table = update_table(leaderboard_df, query, param_ranges, is_foundation, UGI_COLS, w10_range, additional_cols)
341
 
342
  ws_df = leaderboard_df.sort_values(by='Reg+MyScore πŸ†', ascending=False)
343
+ ws_table = update_table(ws_df, query, param_ranges, is_foundation, WRITING_STYLE_COLS, w10_range, additional_cols)
344
 
345
  arp_df = leaderboard_df.sort_values(by='Score πŸ†', ascending=False)
346
  arp_df_na = arp_df[arp_df[['Dif', 'Cor']].isna().any(axis=1)]
347
  arp_df = arp_df[~arp_df[['Dif', 'Cor']].isna().any(axis=1)]
348
 
349
+ arp_table = update_table(arp_df, query, param_ranges, is_foundation, ANIME_RATING_COLS, w10_range, additional_cols)
350
+ arp_na_table = update_table(arp_df_na, query, param_ranges, is_foundation, ANIME_RATING_COLS, w10_range, additional_cols).fillna('NA')
351
 
352
  return ugi_table, ws_table, arp_table, arp_na_table
353
 
354
+ # Update the event handlers
355
+ for component in [search_bar, filter_columns_size, model_type, w10_range, additional_columns]:
356
+ component.change(
357
+ fn=update_all_tables,
358
+ inputs=[search_bar, filter_columns_size, model_type, w10_range, additional_columns],
359
+ outputs=[leaderboard_table_ugi, leaderboard_table_ws, leaderboard_table_arp, leaderboard_table_arp_na]
360
+ )
361
+
362
  search_bar.change(
363
  fn=update_all_tables,
364
+ inputs=[search_bar, filter_columns_size, model_type, w10_range, additional_columns],
365
  outputs=[leaderboard_table_ugi, leaderboard_table_ws, leaderboard_table_arp, leaderboard_table_arp_na]
366
  )
367
 
368
  filter_columns_size.change(
369
  fn=update_all_tables,
370
+ inputs=[search_bar, filter_columns_size, model_type, w10_range, additional_columns],
371
  outputs=[leaderboard_table_ugi, leaderboard_table_ws, leaderboard_table_arp, leaderboard_table_arp_na]
372
  )
373
+
374
+ model_type.change(
375
+ fn=update_all_tables,
376
+ inputs=[search_bar, filter_columns_size, model_type, w10_range, additional_columns],
377
+ outputs=[leaderboard_table_ugi, leaderboard_table_ws, leaderboard_table_arp, leaderboard_table_arp_na]
378
+ )
379
+
380
  w10_range.change(
381
  fn=update_all_tables,
382
+ inputs=[search_bar, filter_columns_size, model_type, w10_range, additional_columns],
383
  outputs=[leaderboard_table_ugi, leaderboard_table_ws, leaderboard_table_arp, leaderboard_table_arp_na]
384
  )
385
+
386
  additional_columns.change(
387
  fn=update_all_tables,
388
+ inputs=[search_bar, filter_columns_size, model_type, w10_range, additional_columns],
389
  outputs=[leaderboard_table_ugi, leaderboard_table_ws, leaderboard_table_arp, leaderboard_table_arp_na]
390
  )
391