saattrupdan committed on
Commit
1d11c02
·
1 Parent(s): 64071e4

feat: Optimise colour mapping for visible models only

Browse files
Files changed (1) hide show
  1. app.py +54 -36
app.py CHANGED
@@ -127,6 +127,8 @@ paper](https://aclanthology.org/2023.nodalida-1.20):
127
 
128
  UPDATE_FREQUENCY_MINUTES = 5
129
  MIN_COLOUR_DISTANCE_BETWEEN_MODELS = 200
 
 
130
 
131
 
132
  class Task(BaseModel):
@@ -170,12 +172,14 @@ INFORMATION_EXTRACTION = Task(name="information extraction", metric="micro_f1_no
170
  ALL_TASKS = [obj for obj in globals().values() if isinstance(obj, Task)]
171
 
172
  DANISH = Language(code="da", name="Danish")
173
- NORWEGIAN = Language(code="no", name="Norwegian")
174
- SWEDISH = Language(code="sv", name="Swedish")
175
- ICELANDIC = Language(code="is", name="Icelandic")
176
- GERMAN = Language(code="de", name="German")
177
  DUTCH = Language(code="nl", name="Dutch")
178
  ENGLISH = Language(code="en", name="English")
 
 
 
 
 
 
179
  ALL_LANGUAGES = {
180
  obj.name: obj for obj in globals().values() if isinstance(obj, Language)
181
  }
@@ -187,6 +191,9 @@ DATASETS = [
187
  Dataset(name="sb10k", language=GERMAN, task=TEXT_CLASSIFICATION),
188
  Dataset(name="dutch-social", language=DUTCH, task=TEXT_CLASSIFICATION),
189
  Dataset(name="sst5", language=ENGLISH, task=TEXT_CLASSIFICATION),
 
 
 
190
  Dataset(name="suc3", language=SWEDISH, task=INFORMATION_EXTRACTION),
191
  Dataset(name="dansk", language=DANISH, task=INFORMATION_EXTRACTION),
192
  Dataset(name="norne-nb", language=NORWEGIAN, task=INFORMATION_EXTRACTION),
@@ -195,6 +202,9 @@ DATASETS = [
195
  Dataset(name="germeval", language=GERMAN, task=INFORMATION_EXTRACTION),
196
  Dataset(name="conll-nl", language=DUTCH, task=INFORMATION_EXTRACTION),
197
  Dataset(name="conll-en", language=ENGLISH, task=INFORMATION_EXTRACTION),
 
 
 
198
  Dataset(name="scala-sv", language=SWEDISH, task=GRAMMAR),
199
  Dataset(name="scala-da", language=DANISH, task=GRAMMAR),
200
  Dataset(name="scala-nb", language=NORWEGIAN, task=GRAMMAR),
@@ -203,6 +213,9 @@ DATASETS = [
203
  Dataset(name="scala-de", language=GERMAN, task=GRAMMAR),
204
  Dataset(name="scala-nl", language=DUTCH, task=GRAMMAR),
205
  Dataset(name="scala-en", language=ENGLISH, task=GRAMMAR),
 
 
 
206
  Dataset(name="scandiqa-da", language=DANISH, task=READING_COMPREHENSION),
207
  Dataset(name="norquad", language=NORWEGIAN, task=READING_COMPREHENSION),
208
  Dataset(name="scandiqa-sv", language=SWEDISH, task=READING_COMPREHENSION),
@@ -210,6 +223,9 @@ DATASETS = [
210
  Dataset(name="germanquad", language=GERMAN, task=READING_COMPREHENSION),
211
  Dataset(name="squad", language=ENGLISH, task=READING_COMPREHENSION),
212
  Dataset(name="squad-nl", language=DUTCH, task=READING_COMPREHENSION),
 
 
 
213
  Dataset(name="nordjylland-news", language=DANISH, task=SUMMARISATION),
214
  Dataset(name="mlsum", language=GERMAN, task=SUMMARISATION),
215
  Dataset(name="rrn", language=ICELANDIC, task=SUMMARISATION),
@@ -217,6 +233,8 @@ DATASETS = [
217
  Dataset(name="wiki-lingua-nl", language=DUTCH, task=SUMMARISATION),
218
  Dataset(name="swedn", language=SWEDISH, task=SUMMARISATION),
219
  Dataset(name="cnn-dailymail", language=ENGLISH, task=SUMMARISATION),
 
 
220
  Dataset(name="danish-citizen-tests", language=DANISH, task=KNOWLEDGE),
221
  Dataset(name="danske-talemaader", language=DANISH, task=KNOWLEDGE),
222
  Dataset(name="mmlu-no", language=NORWEGIAN, task=KNOWLEDGE),
@@ -225,6 +243,8 @@ DATASETS = [
225
  Dataset(name="mmlu-de", language=GERMAN, task=KNOWLEDGE),
226
  Dataset(name="mmlu-nl", language=DUTCH, task=KNOWLEDGE),
227
  Dataset(name="mmlu", language=ENGLISH, task=KNOWLEDGE),
 
 
228
  Dataset(name="hellaswag-da", language=DANISH, task=REASONING),
229
  Dataset(name="hellaswag-no", language=NORWEGIAN, task=REASONING),
230
  Dataset(name="hellaswag-sv", language=SWEDISH, task=REASONING),
@@ -232,6 +252,7 @@ DATASETS = [
232
  Dataset(name="hellaswag-de", language=GERMAN, task=REASONING),
233
  Dataset(name="hellaswag-nl", language=DUTCH, task=REASONING),
234
  Dataset(name="hellaswag", language=ENGLISH, task=REASONING),
 
235
  ]
236
 
237
 
@@ -254,7 +275,8 @@ def main() -> None:
254
  global colour_mapping
255
  global seed
256
  seed = 4242
257
- update_colour_mapping(results_dfs=results_dfs)
 
258
 
259
  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
260
  gr.Markdown(INTRO_MARKDOWN)
@@ -266,7 +288,7 @@ def main() -> None:
266
  choices=all_languages,
267
  multiselect=True,
268
  label="Languages",
269
- value=["Danish"],
270
  interactive=True,
271
  scale=2,
272
  )
@@ -274,7 +296,7 @@ def main() -> None:
274
  choices=danish_models,
275
  multiselect=True,
276
  label="Models",
277
- value=["gpt-4-0613", "mistralai/Mistral-7B-v0.1"],
278
  interactive=True,
279
  scale=2,
280
  )
@@ -310,11 +332,6 @@ def main() -> None:
310
  interactive=True,
311
  scale=1,
312
  )
313
- update_colours_button = gr.Button(
314
- value="Update colours",
315
- interactive=True,
316
- scale=1,
317
- )
318
  with gr.Row():
319
  plot = gr.Plot(
320
  value=produce_radial_plot(
@@ -339,7 +356,7 @@ def main() -> None:
339
  fn=partial(update_model_ids_dropdown, results_dfs=results_dfs),
340
  inputs=[language_names_dropdown, model_ids_dropdown],
341
  outputs=model_ids_dropdown,
342
- )
343
 
344
  # Update plot when anything changes
345
  update_plot_kwargs = dict(
@@ -357,16 +374,23 @@ def main() -> None:
357
  ],
358
  outputs=plot,
359
  )
360
- language_names_dropdown.change(**update_plot_kwargs)
361
- model_ids_dropdown.change(**update_plot_kwargs)
362
- use_rank_score_checkbox.change(**update_plot_kwargs)
363
- show_scale_checkbox.change(**update_plot_kwargs)
364
- plot_width_slider.change(**update_plot_kwargs)
365
- plot_height_slider.change(**update_plot_kwargs)
366
-
367
- # Update colours when the button is clicked
368
- update_colours_button.click(
369
- fn=partial(update_colour_mapping, results_dfs=results_dfs),
 
 
 
 
 
 
 
370
  ).then(**update_plot_kwargs)
371
 
372
  demo.launch()
@@ -703,29 +727,23 @@ def fetch_results() -> dict[Language, pd.DataFrame]:
703
  return results_dfs
704
 
705
 
706
- def update_colour_mapping(results_dfs: dict[Language, pd.DataFrame]) -> None:
707
  """Get a mapping from model ids to RGB triplets.
708
 
709
  Args:
710
- results_dfs:
711
- The results dataframes for each language.
712
  """
713
  global colour_mapping
714
  global seed
715
  seed += 1
716
 
717
- gr.Info(f"Updating colour mapping...")
718
-
719
- # Get distinct RGB values for all models
720
- all_models = list(
721
- {model_id for df in results_dfs.values() for model_id in df.index}
722
- )
723
- colour_mapping = dict()
724
-
725
  for i in it.count():
726
  min_colour_distance = MIN_COLOUR_DISTANCE_BETWEEN_MODELS - i
727
- retries_left = 10 * len(all_models)
728
- for model_id in all_models:
 
 
729
  random.seed(hash(model_id) + i + seed)
730
  r, g, b = 0, 0, 0
731
  too_bright, similar_to_other_model = True, True
 
127
 
128
  UPDATE_FREQUENCY_MINUTES = 5
129
  MIN_COLOUR_DISTANCE_BETWEEN_MODELS = 200
130
+ DEFAULT_LANGUAGES = ["Danish"]
131
+ DEFAULT_MODELS = ["gpt-4-0613", "mistralai/Mistral-7B-v0.1"]
132
 
133
 
134
  class Task(BaseModel):
 
172
  ALL_TASKS = [obj for obj in globals().values() if isinstance(obj, Task)]
173
 
174
  DANISH = Language(code="da", name="Danish")
 
 
 
 
175
  DUTCH = Language(code="nl", name="Dutch")
176
  ENGLISH = Language(code="en", name="English")
177
+ FAROESE = Language(code="fo", name="Faroese")
178
+ FRENCH = Language(code="fr", name="French")
179
+ GERMAN = Language(code="de", name="German")
180
+ ICELANDIC = Language(code="is", name="Icelandic")
181
+ NORWEGIAN = Language(code="no", name="Norwegian")
182
+ SWEDISH = Language(code="sv", name="Swedish")
183
  ALL_LANGUAGES = {
184
  obj.name: obj for obj in globals().values() if isinstance(obj, Language)
185
  }
 
191
  Dataset(name="sb10k", language=GERMAN, task=TEXT_CLASSIFICATION),
192
  Dataset(name="dutch-social", language=DUTCH, task=TEXT_CLASSIFICATION),
193
  Dataset(name="sst5", language=ENGLISH, task=TEXT_CLASSIFICATION),
194
+ Dataset(name="fosent", language=FAROESE, task=TEXT_CLASSIFICATION),
195
+ Dataset(name="allocine", language=FRENCH, task=TEXT_CLASSIFICATION),
196
+
197
  Dataset(name="suc3", language=SWEDISH, task=INFORMATION_EXTRACTION),
198
  Dataset(name="dansk", language=DANISH, task=INFORMATION_EXTRACTION),
199
  Dataset(name="norne-nb", language=NORWEGIAN, task=INFORMATION_EXTRACTION),
 
202
  Dataset(name="germeval", language=GERMAN, task=INFORMATION_EXTRACTION),
203
  Dataset(name="conll-nl", language=DUTCH, task=INFORMATION_EXTRACTION),
204
  Dataset(name="conll-en", language=ENGLISH, task=INFORMATION_EXTRACTION),
205
+ Dataset(name="fone", language=FAROESE, task=INFORMATION_EXTRACTION),
206
+ Dataset(name="eltec", language=FRENCH, task=INFORMATION_EXTRACTION),
207
+
208
  Dataset(name="scala-sv", language=SWEDISH, task=GRAMMAR),
209
  Dataset(name="scala-da", language=DANISH, task=GRAMMAR),
210
  Dataset(name="scala-nb", language=NORWEGIAN, task=GRAMMAR),
 
213
  Dataset(name="scala-de", language=GERMAN, task=GRAMMAR),
214
  Dataset(name="scala-nl", language=DUTCH, task=GRAMMAR),
215
  Dataset(name="scala-en", language=ENGLISH, task=GRAMMAR),
216
+ Dataset(name="scala-fo", language=FAROESE, task=GRAMMAR),
217
+ Dataset(name="scala-fr", language=FRENCH, task=GRAMMAR),
218
+
219
  Dataset(name="scandiqa-da", language=DANISH, task=READING_COMPREHENSION),
220
  Dataset(name="norquad", language=NORWEGIAN, task=READING_COMPREHENSION),
221
  Dataset(name="scandiqa-sv", language=SWEDISH, task=READING_COMPREHENSION),
 
223
  Dataset(name="germanquad", language=GERMAN, task=READING_COMPREHENSION),
224
  Dataset(name="squad", language=ENGLISH, task=READING_COMPREHENSION),
225
  Dataset(name="squad-nl", language=DUTCH, task=READING_COMPREHENSION),
226
+ Dataset(name="foqa", language=FAROESE, task=READING_COMPREHENSION),
227
+ Dataset(name="fquad", language=FRENCH, task=READING_COMPREHENSION),
228
+
229
  Dataset(name="nordjylland-news", language=DANISH, task=SUMMARISATION),
230
  Dataset(name="mlsum", language=GERMAN, task=SUMMARISATION),
231
  Dataset(name="rrn", language=ICELANDIC, task=SUMMARISATION),
 
233
  Dataset(name="wiki-lingua-nl", language=DUTCH, task=SUMMARISATION),
234
  Dataset(name="swedn", language=SWEDISH, task=SUMMARISATION),
235
  Dataset(name="cnn-dailymail", language=ENGLISH, task=SUMMARISATION),
236
+ Dataset(name="orange-sum", language=FRENCH, task=SUMMARISATION),
237
+
238
  Dataset(name="danish-citizen-tests", language=DANISH, task=KNOWLEDGE),
239
  Dataset(name="danske-talemaader", language=DANISH, task=KNOWLEDGE),
240
  Dataset(name="mmlu-no", language=NORWEGIAN, task=KNOWLEDGE),
 
243
  Dataset(name="mmlu-de", language=GERMAN, task=KNOWLEDGE),
244
  Dataset(name="mmlu-nl", language=DUTCH, task=KNOWLEDGE),
245
  Dataset(name="mmlu", language=ENGLISH, task=KNOWLEDGE),
246
+ Dataset(name="mmlu-fr", language=FRENCH, task=KNOWLEDGE),
247
+
248
  Dataset(name="hellaswag-da", language=DANISH, task=REASONING),
249
  Dataset(name="hellaswag-no", language=NORWEGIAN, task=REASONING),
250
  Dataset(name="hellaswag-sv", language=SWEDISH, task=REASONING),
 
252
  Dataset(name="hellaswag-de", language=GERMAN, task=REASONING),
253
  Dataset(name="hellaswag-nl", language=DUTCH, task=REASONING),
254
  Dataset(name="hellaswag", language=ENGLISH, task=REASONING),
255
+ Dataset(name="hellaswag-fr", language=FRENCH, task=REASONING),
256
  ]
257
 
258
 
 
275
  global colour_mapping
276
  global seed
277
  seed = 4242
278
+ colour_mapping = dict()
279
+ update_colour_mapping(model_ids=DEFAULT_MODELS)
280
 
281
  with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
282
  gr.Markdown(INTRO_MARKDOWN)
 
288
  choices=all_languages,
289
  multiselect=True,
290
  label="Languages",
291
+ value=DEFAULT_LANGUAGES,
292
  interactive=True,
293
  scale=2,
294
  )
 
296
  choices=danish_models,
297
  multiselect=True,
298
  label="Models",
299
+ value=DEFAULT_MODELS,
300
  interactive=True,
301
  scale=2,
302
  )
 
332
  interactive=True,
333
  scale=1,
334
  )
 
 
 
 
 
335
  with gr.Row():
336
  plot = gr.Plot(
337
  value=produce_radial_plot(
 
356
  fn=partial(update_model_ids_dropdown, results_dfs=results_dfs),
357
  inputs=[language_names_dropdown, model_ids_dropdown],
358
  outputs=model_ids_dropdown,
359
+ ).then(fn=update_colour_mapping, inputs=model_ids_dropdown)
360
 
361
  # Update plot when anything changes
362
  update_plot_kwargs = dict(
 
374
  ],
375
  outputs=plot,
376
  )
377
+ language_names_dropdown.change(
378
+ fn=update_colour_mapping, inputs=model_ids_dropdown
379
+ ).then(**update_plot_kwargs)
380
+ model_ids_dropdown.change(
381
+ fn=update_colour_mapping, inputs=model_ids_dropdown
382
+ ).then(**update_plot_kwargs)
383
+ use_rank_score_checkbox.change(
384
+ fn=update_colour_mapping, inputs=model_ids_dropdown
385
+ ).then(**update_plot_kwargs)
386
+ show_scale_checkbox.change(
387
+ fn=update_colour_mapping, inputs=model_ids_dropdown
388
+ ).then(**update_plot_kwargs)
389
+ plot_width_slider.change(
390
+ fn=update_colour_mapping, inputs=model_ids_dropdown
391
+ ).then(**update_plot_kwargs)
392
+ plot_height_slider.change(
393
+ fn=update_colour_mapping, inputs=model_ids_dropdown
394
  ).then(**update_plot_kwargs)
395
 
396
  demo.launch()
 
727
  return results_dfs
728
 
729
 
730
+ def update_colour_mapping(model_ids: list[str]) -> None:
731
  """Get a mapping from model ids to RGB triplets.
732
 
733
  Args:
734
+ model_ids:
735
+ The model ids to update the colour mapping for.
736
  """
737
  global colour_mapping
738
  global seed
739
  seed += 1
740
 
 
 
 
 
 
 
 
 
741
  for i in it.count():
742
  min_colour_distance = MIN_COLOUR_DISTANCE_BETWEEN_MODELS - i
743
+ retries_left = 10 * len(model_ids)
744
+ for model_id in model_ids:
745
+ if model_id in colour_mapping:
746
+ continue
747
  random.seed(hash(model_id) + i + seed)
748
  r, g, b = 0, 0, 0
749
  too_bright, similar_to_other_model = True, True