Spaces:
Runtime error
Runtime error
File size: 4,365 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import argparse
import importlib
import os
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import merge_with_parent
from hydra.core.config_store import ConfigStore
from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa
# Module-level registries, filled in by the `register_task` decorator below.
# Task name -> dataclass given at registration (only for tasks that supplied one).
TASK_DATACLASS_REGISTRY = {}
# Task name -> registered task class.
TASK_REGISTRY = {}
# `__name__` of every registered task class; used to reject duplicate class names.
TASK_CLASS_NAMES = set()
def setup_task(cfg: FairseqDataclass, **kwargs):
    """Resolve the task class described by ``cfg`` and delegate to its ``setup_task``.

    Two configuration styles are handled:

    * legacy: ``cfg.task`` is a string naming the task. The class is looked up
      in ``TASK_REGISTRY``; if a dataclass is registered under the same name,
      ``cfg`` is converted with that dataclass's ``from_namespace``.
    * otherwise: the task name is read from ``cfg._name``. When that name has a
      registered dataclass, ``cfg`` is merged onto a default-constructed
      instance via ``merge_with_parent`` and the class is looked up in
      ``TASK_REGISTRY``.

    Args:
        cfg: the task configuration (see the two styles above).
        **kwargs: forwarded unchanged to the task class's ``setup_task``.

    Returns:
        Whatever the resolved task class's ``setup_task(cfg, **kwargs)`` returns.

    Raises:
        KeyError: legacy style with a ``cfg.task`` name that is not registered.
        AssertionError: no task could be inferred from ``cfg``.
    """
    task = None
    task_name = getattr(cfg, "task", None)

    if isinstance(task_name, str):
        # legacy tasks
        task = TASK_REGISTRY[task_name]
        if task_name in TASK_DATACLASS_REGISTRY:
            dc = TASK_DATACLASS_REGISTRY[task_name]
            cfg = dc.from_namespace(cfg)
    else:
        task_name = getattr(cfg, "_name", None)

        if task_name and task_name in TASK_DATACLASS_REGISTRY:
            dc = TASK_DATACLASS_REGISTRY[task_name]
            cfg = merge_with_parent(dc(), cfg)
            task = TASK_REGISTRY[task_name]

    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped under `python -O`, which would turn this into an opaque
    # AttributeError on None below. The exception type is kept as
    # AssertionError so existing callers see the same error as before.
    if task is None:
        raise AssertionError(
            f"Could not infer task type from {cfg}. Available argparse tasks: {TASK_REGISTRY.keys()}. Available hydra tasks: {TASK_DATACLASS_REGISTRY.keys()}"
        )

    return task.setup_task(cfg, **kwargs)
def register_task(name, dataclass=None):
    """
    New tasks can be added to fairseq with the
    :func:`~fairseq.tasks.register_task` function decorator.

    For example::

        @register_task('classification')
        class ClassificationTask(FairseqTask):
            (...)

    .. note::

        All Tasks must implement the :class:`~fairseq.tasks.FairseqTask`
        interface.

    Args:
        name (str): the name of the task
        dataclass (optional): a :class:`FairseqDataclass` subclass holding the
            task's configuration. When given, it is recorded in
            ``TASK_DATACLASS_REGISTRY`` and a default-constructed instance is
            stored in the hydra ``ConfigStore`` under group ``"task"``.

    Returns:
        A class decorator that registers the decorated class and returns it
        unchanged (apart from the ``__dataclass`` attribute set on it).

    Raises:
        ValueError: (when the decorator is applied) if ``name`` is already
            registered, the class does not extend ``FairseqTask``, a class
            with the same ``__name__`` is already registered, or ``dataclass``
            does not extend ``FairseqDataclass``.
    """

    def register_task_cls(cls):
        if name in TASK_REGISTRY:
            raise ValueError("Cannot register duplicate task ({})".format(name))
        if not issubclass(cls, FairseqTask):
            raise ValueError(
                "Task ({}: {}) must extend FairseqTask".format(name, cls.__name__)
            )
        if cls.__name__ in TASK_CLASS_NAMES:
            raise ValueError(
                "Cannot register task with duplicate class name ({})".format(
                    cls.__name__
                )
            )
        # Validate the dataclass BEFORE touching the registries so that a bad
        # dataclass does not leave the task half-registered (which would make
        # a corrected retry fail with a "duplicate task" error).
        if dataclass is not None and not issubclass(dataclass, FairseqDataclass):
            raise ValueError(
                "Dataclass {} must extend FairseqDataclass".format(dataclass)
            )

        TASK_REGISTRY[name] = cls
        TASK_CLASS_NAMES.add(cls.__name__)

        # This assignment is not inside a class body, so no name mangling
        # applies: the attribute is literally called `__dataclass`.
        cls.__dataclass = dataclass
        if dataclass is not None:
            TASK_DATACLASS_REGISTRY[name] = dataclass

            # Publish a default-constructed config node, tagged with the task
            # name, in the hydra config store.
            cs = ConfigStore.instance()
            node = dataclass()
            node._name = name
            cs.store(name=name, group="task", node=node, provider="fairseq")

        return cls

    return register_task_cls
def get_task(name):
    """Return the task class registered under ``name``.

    Raises:
        KeyError: if no task has been registered under ``name``.
    """
    task_cls = TASK_REGISTRY[name]
    return task_cls
def import_tasks(tasks_dir, namespace):
    """Import every task module or package found directly inside ``tasks_dir``.

    Entries whose names start with ``_`` or ``.`` are skipped, as is anything
    that is neither a ``.py`` file nor a directory. Each remaining entry is
    imported as ``<namespace>.<entry name without .py>``.

    After an import, if a task is registered under the same name as the
    module, an ``argparse`` parser describing that task's arguments is
    published as the module-level global ``<task_name>_parser``.

    Args:
        tasks_dir: directory to scan (not recursive).
        namespace: dotted package name the entries are imported under.
    """
    for entry in os.listdir(tasks_dir):
        entry_path = os.path.join(tasks_dir, entry)

        # Private/hidden entries are never treated as task modules.
        if entry.startswith(("_", ".")):
            continue

        is_py_file = entry.endswith(".py")
        if not (is_py_file or os.path.isdir(entry_path)):
            continue

        # Module name: text before the first ".py" for files, the bare
        # directory name for packages.
        task_name = entry.partition(".py")[0] if is_py_file else entry
        importlib.import_module(namespace + "." + task_name)

        # expose `task_parser` for sphinx
        if task_name not in TASK_REGISTRY:
            continue

        parser = argparse.ArgumentParser(add_help=False)
        name_group = parser.add_argument_group("Task name")
        # fmt: off
        name_group.add_argument('--task', metavar=task_name,
                                help='Enable this task with: ``--task=' + task_name + '``')
        # fmt: on
        extra_group = parser.add_argument_group("Additional command-line arguments")
        TASK_REGISTRY[task_name].add_args(extra_group)
        globals()[task_name + "_parser"] = parser
# automatically import any Python files in the tasks/ directory
# (runs when this module is imported: every eligible file or sub-directory
# next to this file is imported under the "fairseq.tasks" namespace)
tasks_dir = os.path.dirname(__file__)
import_tasks(tasks_dir, "fairseq.tasks")
|