"""
TED Multi Translation Dataset Class
------------------------------------
"""
import collections | |
import datasets | |
import numpy as np | |
from textattack.datasets import HuggingFaceDataset | |
class TedMultiTranslationDataset(HuggingFaceDataset):
    """Loads examples from the Ted Talk translation dataset using the
    `datasets` package.

    dataset source: http://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/

    Each formatted example is a ``(OrderedDict([("Source", source_text)]),
    target_text)`` pair taken from one entry's parallel translations.

    Args:
        source_lang: Language code whose text becomes the "Source" input.
        target_lang: Language code whose text becomes the target output.
        split: Name of the split to take from the ``ted_multi`` dataset.

    Raises:
        ValueError: If ``source_lang`` or ``target_lang`` does not occur in
            any example of the selected split.
    """

    def __init__(self, source_lang="en", target_lang="de", split="test"):
        # NOTE(review): HuggingFaceDataset.__init__ is not called here; confirm
        # the base class does not rely on attributes it would otherwise set.
        self._dataset = datasets.load_dataset("ted_multi")[split]
        # Each entry is expected to hold parallel "language" and "translation"
        # sequences (that is how _format_raw_example reads them).
        self.examples = self._dataset["translations"]

        # Nothing guarantees that every example carries the same languages, so
        # gather the languages seen anywhere in the split instead of trusting
        # the first example alone (which also crashed on an empty split).
        language_options = set()
        for example in self.examples:
            language_options.update(example["language"])

        # Source is checked before target, as callers may rely on which error
        # surfaces first.
        for role, lang in (("Source", source_lang), ("Target", target_lang)):
            if lang not in language_options:
                raise ValueError(
                    f"{role} language {lang} invalid. Choices: {sorted(language_options)}"
                )

        self.source_lang = source_lang
        self.target_lang = target_lang

    def _format_raw_example(self, raw_example):
        """Turn one raw example into a ``(source_dict, target)`` pair.

        Args:
            raw_example: Mapping with parallel ``"language"`` and
                ``"translation"`` sequences, like the entries of
                ``self.examples``.

        Returns:
            ``(OrderedDict([("Source", source_text)]), target_text)``, where
            each text is the first translation tagged with the configured
            source/target language.

        Raises:
            IndexError: If the example lacks a translation in either language.
        """
        source = self._translation_for(raw_example, self.source_lang)
        target = self._translation_for(raw_example, self.target_lang)
        source_dict = collections.OrderedDict([("Source", source)])
        return (source_dict, target)

    @staticmethod
    def _translation_for(raw_example, lang):
        """Return the first translation in ``raw_example`` tagged with ``lang``."""
        translations = np.array(raw_example["translation"])
        languages = np.array(raw_example["language"])
        matches = translations[languages == lang]
        if matches.size == 0:
            # __init__ only checks that the language occurs somewhere in the
            # split, so this particular example may still lack it. Keep the
            # IndexError type the bare ``[0]`` used to raise, but name the
            # missing language instead of numpy's out-of-bounds message.
            raise IndexError(
                f"Example has no translation for language {lang!r}; "
                f"available: {sorted(set(raw_example['language']))}"
            )
        return matches[0]