|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""CLUE finetuning/evaluation."""
|
|
|
from megatron import get_args |
|
from megatron import print_rank_0 |
|
from megatron import get_tokenizer |
|
from megatron import mpu |
|
from megatron.model.classification import Classification |
|
from tasks.eval_utils import accuracy_func_provider |
|
from tasks.finetune_utils import finetune |
|
|
|
|
|
def clue_classification(num_classes, Dataset,
                        name_from_datapath_func):
    """Finetune and evaluate a classification model on a CLUE task.

    Args:
        num_classes: number of output classes for the classification head.
        Dataset: dataset class for the task; it is constructed here as
            ``Dataset(name, datapaths, tokenizer, seq_length)``.
        name_from_datapath_func: callable mapping a single evaluation data
            path to the dataset name used for that evaluation dataset.
    """

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        tokenizer = get_tokenizer()

        train_dataset = Dataset('training', args.train_data,
                                tokenizer, args.seq_length)
        valid_dataset = Dataset('validation', args.valid_data,
                                tokenizer, args.seq_length)

        return train_dataset, valid_dataset

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        print_rank_0('building classification model for {} ...'.format(
            args.task))
        model = Classification(num_classes=num_classes, num_tokentypes=2,
                               pre_process=pre_process,
                               post_process=post_process)

        return model

    def metrics_func_provider():
        """Provide metrics callback function."""

        def single_dataset_provider(datapath):
            """Build an evaluation dataset from one data path."""
            args = get_args()
            tokenizer = get_tokenizer()
            name = name_from_datapath_func(datapath)
            # Wrap the single path in a list before handing it to Dataset.
            return Dataset(name, [datapath], tokenizer, args.seq_length)

        return accuracy_func_provider(single_dataset_provider)

    # Finetune/evaluate.
    finetune(train_valid_datasets_provider, model_provider,
             end_of_epoch_callback_provider=metrics_func_provider)
|
|
|
|
|
def main():
    """Select the CLUE task named by ``args.task`` and run finetuning on it.

    Raises:
        NotImplementedError: if ``args.task`` is not one of the supported
            CLUE tasks.
    """
    import importlib

    args = get_args()

    # task -> (num_classes, dataset module, dataset class name, dataset name).
    # Modules are imported lazily below so only the selected task's dataset
    # module is loaded.
    task_specs = {
        'AFQMC': (2, 'tasks.clue.afqmc', 'AFQMCDataset', 'afqmc'),
        'CSL': (2, 'tasks.clue.csl', 'CSLDataset', 'csl'),
        'IFLYTEK': (119, 'tasks.clue.iflytek', 'IFLYTEKDataset', 'iflytek'),
        'OCNLI': (3, 'tasks.clue.ocnli', 'OCNLIDataset', 'ocnli'),
        'TNEWS': (15, 'tasks.clue.tnews', 'TNEWSDataset', 'tnews'),
        'WSC': (2, 'tasks.clue.wsc', 'WSCDataset', 'wsc'),
        'CMNLI': (3, 'tasks.clue.cmnli', 'CMNLIDataset', 'cmnli'),
        'ZC': (2, 'tasks.clue.zc', 'ZCDataset', 'zc'),
    }

    spec = task_specs.get(args.task)
    if spec is None:
        raise NotImplementedError('CLUE task {} is not implemented.'.format(
            args.task))
    num_classes, module_name, class_name, dataset_name = spec

    dataset_cls = getattr(importlib.import_module(module_name), class_name)

    def name_from_datapath(datapath):
        # Every evaluation file of a given task gets the same dataset name.
        return dataset_name

    clue_classification(num_classes, dataset_cls, name_from_datapath)
|
|