import ast
import asyncio
import os
import random
from collections import Counter, defaultdict

from datasets import Dataset, load_dataset
from datasets_.util import _get_dataset_config_names, _load_dataset
from langcodes import Language, standardize_tag
from models import get_google_supported_languages, translate_google
from rich import print
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio


def print_counts(slug, subjects_dev, subjects_test):
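    """Print the number of test/dev categories and samples for one dataset."""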
    print(
        f"{slug:<25} {len(list(set(subjects_test))):>3} test categories, {len(subjects_test):>6} samples, {len(list(set(subjects_dev))):>3} dev categories, {len(subjects_dev):>6} dev samples"
    )


def print_datasets_analysis():
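    """Compare category counts, sample counts, and language coverage across candidate MMLU datasets."""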
    print("Category counts and sample counts per dataset:")
    slug1 = "masakhane/afrimmlu"
    ds1 = _load_dataset(slug1, "eng")
    print_counts(slug1, ds1["dev"]["subject"], ds1["test"]["subject"])
    langs1 = _get_dataset_config_names(slug1)
    langs1 = [standardize_tag(a, macro=True) for a in langs1]

    slug2 = "openai/MMMLU"  # has no dev set, but all of its languages are also present in Global-MMLU
    ds2 = _load_dataset(slug2, "FR_FR")
    print_counts(slug2, [], ds2["test"]["Subject"])
    langs2 = _get_dataset_config_names(slug2)
    langs2 = [a.split("_")[0].lower() for a in langs2]
    langs2.remove("default")

    slug3 = "CohereForAI/Global-MMLU"
    ds3 = _load_dataset(slug3, "en")
    print_counts(slug3, ds3["dev"]["subject"], ds3["test"]["subject"])
    langs3 = _get_dataset_config_names(slug3)
    langs3 = [standardize_tag(a, macro=True) for a in langs3]

    slug4 = "lighteval/okapi_mmlu"
    ds4 = _load_dataset(slug4, "ar", trust_remote_code=True)
    print_counts(
        slug4,
        [a.split("/")[0] for a in ds4["dev"]["id"]],
        [a.split("/")[0] for a in ds4["test"]["id"]],
    )
    langs4 = _get_dataset_config_names(slug4)

    slug5 = "Eurolingua/mmlux"
    subsets = _get_dataset_config_names(slug5)
    subjects = set(a.rsplit("_", 1)[0] for a in subsets)
    rows_test = [
        _load_dataset(slug5, subset)["test"]["id"]
        for subset in subsets
        if "_DA" in subset
    ]
    rows_test = [a.split("/")[0] for ids in rows_test for a in ids]
    rows_dev = [
        _load_dataset(slug5, subset)["dev"]["id"]
        for subset in subsets
        if "_DA" in subset
    ]
    rows_dev = [a.split("/")[0] for ids in rows_dev for a in ids]
    print_counts(slug5, rows_dev, rows_test)
    langs5 = list(set(a.rsplit("_", 1)[1].split("-")[0].lower() for a in subsets))

    langs = langs1 + langs2 + langs3 + langs4 + langs5
    lang_datasets = defaultdict(list)
    for slug, langs_list in [
        (slug1, langs1),
        (slug2, langs2),
        (slug3, langs3),
        (slug4, langs4),
        (slug5, langs5),
    ]:
        for lang in langs_list:
            lname = Language.get(lang).display_name()
            lang_datasets[lname].append(slug)
    print("Datasets per language:")
    print(sorted(lang_datasets.items()))
    print(len(set(langs)))

    print("Datasets per language for languages that are not in Global-MMLU:")
    print(
        sorted(
            (lang, datasets)
            for lang, datasets in lang_datasets.items()
            if slug3 not in datasets
        )
    )
    print(
        Counter(
            dataset
            for ds_list in lang_datasets.values()
            for dataset in ds_list
            if slug3 not in ds_list
        )
    )
    print(list(set(ds1["test"]["subject"])))


# based on this analysis:
# - we drop the OpenAI dataset, since it does not have a dev set, and since every language it covers is also present in Global-MMLU
# - we stick to the 5 categories of the AfriMMLU dataset, since it is the most restricted dataset and its 5 categories are present in all the other datasets, which is good for comparability

# AfriMMLU is human-translated, but has only 5 task categories
# Global-MMLU is mixed: the 15 languages that are also present in Global-MMLU-Lite are human-translated (mostly taken from MMMLU); the rest are translated using Google Translate
# Okapi-MMLU is translated using ChatGPT (version unclear)
# MMLUX is translated using DeepL
# Therefore, the priority is: AfriMMLU, Global-MMLU, MMLUX, Okapi-MMLU

# print_datasets_analysis()


def parse_choices(row):
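    """Parse the stringified "choices" column into a Python list (AfriMMLU stores it as a string)."""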
    if not isinstance(row["choices"], list):
        row["choices"] = eval(row["choices"])
    return row


def add_choices(row):
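    """Collect Global-MMLU's option_a..option_d columns into a single "choices" list."""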
    row["choices"] = [
        row["option_a"],
        row["option_b"],
        row["option_c"],
        row["option_d"],
    ]
    return row


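# language tags available in each dataset (mapping standardized BCP-47 tags to config names where needed)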
tags_afrimmlu = {
    standardize_tag(a, macro=True): a
    for a in _get_dataset_config_names("masakhane/afrimmlu")
}
tags_global_mmlu = {
    standardize_tag(a, macro=True): a
    for a in _get_dataset_config_names("CohereForAI/Global-MMLU")
}
tags_okapi = _get_dataset_config_names("lighteval/okapi_mmlu")
tags_mmlux = set(
    a.rsplit("_", 1)[1].split("-")[0].lower()
    for a in _get_dataset_config_names("Eurolingua/mmlux", trust_remote_code=True)
)
tags_mmlu_autotranslated = _get_dataset_config_names("fair-forward/mmlu-autotranslated")

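# the 5 task categories shared by all datasets (taken from AfriMMLU's dev split)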
categories = sorted(
    set(_load_dataset("masakhane/afrimmlu", "eng")["dev"]["subject"])
)


async def load_mmlu(language_bcp_47, nr):
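    """
    Return a single MMLU task for the given language and item index.

    The category is picked round-robin from the shared categories; the task is
    taken from the highest-priority source that covers the language
    (AfriMMLU > Global-MMLU > fair-forward/mmlu-autotranslated > on-the-fly
    Google translation). Returns (dataset slug, task row, origin), where origin
    is "human" or "machine".
    """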
    print(f"Loading MMLU data for {language_bcp_47}...")
    category = categories[nr % len(categories)]
    if language_bcp_47 in tags_afrimmlu.keys():
        ds = _load_dataset("masakhane/afrimmlu", tags_afrimmlu[language_bcp_47])
        ds = ds.map(parse_choices)
        task = ds["test"].filter(lambda x: x["subject"] == category)[nr]
        return "masakhane/afrimmlu", task, "human"
    elif language_bcp_47 in tags_global_mmlu.keys():
        ds = _load_dataset("CohereForAI/Global-MMLU", tags_global_mmlu[language_bcp_47])
        ds = ds.map(add_choices)
        task = ds["test"].filter(lambda x: x["subject"] == category)[nr]
        return "CohereForAI/Global-MMLU", task, "human"
    elif language_bcp_47 in tags_mmlu_autotranslated:
        ds = _load_dataset("fair-forward/mmlu-autotranslated", language_bcp_47)
        task = ds["test"].filter(lambda x: x["subject"] == category)[nr]
        return "fair-forward/mmlu-autotranslated", task, "machine"
    else:
        # Try on-the-fly translation for missing languages
        return await load_mmlu_translated(language_bcp_47, nr)


async def load_mmlu_translated(language_bcp_47, nr):
    """
    Load MMLU data with on-the-fly Google translation for languages 
    without native MMLU translations.
    """
    # Check if Google Translate supports this language
    supported_languages = get_google_supported_languages()
    if language_bcp_47 not in supported_languages:
        return None, None, None
    
    print(f"πŸ”„ Translating MMLU data to {language_bcp_47} on-the-fly...")
    
    try:
        # Load English MMLU data
        category = categories[nr % len(categories)]
        ds = _load_dataset("masakhane/afrimmlu", "eng")
        ds = ds.map(parse_choices)
        task = ds["test"].filter(lambda x: x["subject"] == category)[nr]
        
        # Translate question and choices
        question_translated = await translate_google(task["question"], "en", language_bcp_47)
        choices_translated = []
        for choice in task["choices"]:
            choice_translated = await translate_google(choice, "en", language_bcp_47)
            choices_translated.append(choice_translated)
        
        # Create translated task
        translated_task = {
            "question": question_translated,
            "choices": choices_translated,
            "answer": task["answer"],  # Keep original answer index
            "subject": task["subject"]
        }
        
        return f"mmlu-translated-{language_bcp_47}", translated_task, "machine"
        
    except Exception as e:
        print(f"❌ Translation failed for {language_bcp_47}: {e}")
        return None, None, None


def translate_mmlu(languages):
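    """
    Machine-translate the English AfriMMLU samples for languages without a
    human-translated MMLU and push them to fair-forward/mmlu-autotranslated on
    the Hugging Face Hub, plus a local JSON copy per language and split.
    """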
    human_translated = [*tags_afrimmlu.keys(), *tags_global_mmlu.keys()]
    untranslated = [
        lang
        for lang in languages["bcp_47"].values[:100]
        if lang not in human_translated and lang in get_google_supported_languages()
    ]
    n_samples = 10

    slug = "fair-forward/mmlu-autotranslated"
    for lang in tqdm(untranslated):
        # skip languages whose config already exists on the hub
        try:
            ds_lang = load_dataset(slug, lang)
        except Exception:
            print(f"Translating {lang}...")
            for split in ["dev", "test"]:
                ds = _load_dataset("masakhane/afrimmlu", "eng", split=split)
                samples = []
                for category in categories:
                    if split == "dev":
                        samples.extend(ds.filter(lambda x: x["subject"] == category))
                    else:
                        for i in range(n_samples):
                            task = ds.filter(lambda x: x["subject"] == category)[i]
                            samples.append(task)
                questions_tr = [
                    translate_google(s["question"], "en", lang) for s in samples
                ]
                questions_tr = asyncio.run(tqdm_asyncio.gather(*questions_tr))
                choices_texts_concatenated = []
                for s in samples:
                    for choice in ast.literal_eval(s["choices"]):
                        choices_texts_concatenated.append(choice)
                choices_tr = [
                    translate_google(c, "en", lang) for c in choices_texts_concatenated
                ]
                choices_tr = asyncio.run(tqdm_asyncio.gather(*choices_tr))
                # group into chunks of 4
                choices_tr = [
                    choices_tr[i : i + 4] for i in range(0, len(choices_tr), 4)
                ]

                ds_lang = Dataset.from_dict(
                    {
                        "subject": [s["subject"] for s in samples],
                        "question": questions_tr,
                        "choices": choices_tr,
                        "answer": [s["answer"] for s in samples],
                    }
                )
                ds_lang.push_to_hub(
                    slug,
                    split=split,
                    config_name=lang,
                    token=os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
                )
                ds_lang.to_json(
                    f"data/translations/mmlu/{lang}_{split}.json",
                    lines=False,
                    force_ascii=False,
                    indent=2,
                )