|
|
|
# AFQMC label table: each raw label string maps to itself ("0" and "1").
AFQMC_LABELS = {label: label for label in ("0", "1")}
|
|
|
# CSL label table: each raw label string maps to itself ("0", "1" and "2").
CSL_LABELS = {label: label for label in ("0", "1", "2")}
|
|
|
# IFLYTEK label table: identity mapping over the label ids "0".."118".
# Built with a comprehension so no loop variable leaks into the module namespace.
IFLYTEK_LABELS = {str(i): str(i) for i in range(119)}
|
|
|
# OCNLI label table: NLI relation name -> label id string.
OCNLI_LABELS = dict(
    contradiction='0',
    entailment='1',
    neutral='2',
)
|
|
|
# CMNLI label table: NLI relation name -> label id string.
CMNLI_LABELS = dict(
    contradiction='0',
    entailment='1',
    neutral='2',
)
|
|
|
# TNEWS label table: raw label codes 100..116, with 105 and 111 skipped
# (presumably unused in the dataset -- confirm), renumbered densely to the
# label id strings "0".."14" in ascending code order.
# `tnews_list` keeps the retained offsets from 100 and is kept as a public
# module name for backward compatibility.
tnews_list = [offset for offset in range(17) if offset not in (5, 11)]

TNEWS_LABELS = {
    str(100 + offset): str(idx) for idx, offset in enumerate(tnews_list)
}
|
|
|
# WSC label table: lowercase truth-value string -> label id string.
WSC_LABELS = dict(
    true='0',
    false='1',
)
|
|
|
# ZC label table: sentiment polarity name -> label id string.
ZC_LABELS = dict(
    negative='0',
    positive='1',
)
|
|
|
def get_label_dict(task_name, write2file=False):
    """Return the label mapping for a supported task.

    Args:
        task_name: Task identifier (case-sensitive); one of "AFQMC", "CSL",
            "IFLYTEK", "OCNLI", "TNEWS", "WSC", "CMNLI" or "ZC".
        write2file: When False (the default), return the task's table as
            defined at module level (raw label -> label id string). When
            True, return the inverted table (label id string -> raw label).
            Presumably this is used when writing predictions back out;
            confirm with callers.

    Returns:
        A new dict. A copy is returned in both modes, so a caller that
        mutates the result cannot corrupt the shared module-level tables.

    Raises:
        ValueError: If ``task_name`` is not one of the supported tasks.
    """
    # Dispatch table in place of a long if/elif chain.
    task_tables = {
        "AFQMC": AFQMC_LABELS,
        "CSL": CSL_LABELS,
        "IFLYTEK": IFLYTEK_LABELS,
        "OCNLI": OCNLI_LABELS,
        "TNEWS": TNEWS_LABELS,
        "WSC": WSC_LABELS,
        "CMNLI": CMNLI_LABELS,
        "ZC": ZC_LABELS,
    }

    try:
        label_dict = task_tables[task_name]
    except KeyError:
        # Fail loudly instead of dropping into a debugger. The previous
        # behaviour also left `label_dict` unbound, which then crashed.
        raise ValueError(
            f"Not implemented: unsupported task_name {task_name!r}; "
            f"expected one of {sorted(task_tables)}"
        ) from None

    if write2file:
        # Inversion assumes label ids are unique within a table; a repeated
        # id would keep only the last label that maps to it.
        return {label_id: label for label, label_id in label_dict.items()}

    return dict(label_dict)