Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2021 The IDEA Authors. All rights reserved. | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import json | |
import os | |
from sklearn import metrics | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import Dataset, DataLoader, ConcatDataset | |
import pytorch_lightning as pl | |
from collections import defaultdict | |
from transformers import AutoConfig, AutoModel, get_cosine_schedule_with_warmup | |
from loss import FocalLoss, LabelSmoothingCorrectionCrossEntropy | |
class CustomDataset(Dataset): | |
def __init__(self, file, tokenizer, max_len, mode='no_test'): | |
self.tokenizer = tokenizer | |
self.max_len = max_len | |
self.mode = mode | |
self.ex_list = [] | |
with open('./dataset/' + file, "r", encoding='utf-8') as f: | |
for line in f: | |
sample = json.loads(line) | |
query = sample["query"] | |
title = sample["title"] | |
id = int(sample["id"]) | |
if self.mode == 'no_test': | |
relevant = int(sample["label"]) | |
self.ex_list.append((query, title, relevant, id)) | |
else: | |
self.ex_list.append((query, title, id)) | |
def __len__(self): | |
return len(self.ex_list) | |
def __getitem__(self, index): | |
if self.mode == 'no_test': | |
query, title, relevant, id = self.ex_list[index] | |
else: | |
query, title, id = self.ex_list[index] | |
inputs = self.tokenizer.encode_plus( | |
query, title, | |
truncation=True, | |
add_special_tokens=True, | |
max_length=self.max_len, | |
padding='max_length', | |
return_token_type_ids=True | |
) | |
ids = inputs['input_ids'] | |
mask = inputs['attention_mask'] | |
token_type_ids = inputs["token_type_ids"] | |
if self.mode == 'no_test': | |
return { | |
'ids': torch.tensor(ids, dtype=torch.long), | |
'mask': torch.tensor(mask, dtype=torch.long), | |
'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long), | |
'targets': torch.tensor(relevant, dtype=torch.float), | |
'id': torch.tensor(id, dtype=torch.long) | |
} | |
else: | |
return { | |
'ids': torch.tensor(ids, dtype=torch.long), | |
'mask': torch.tensor(mask, dtype=torch.long), | |
'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long), | |
'id': torch.tensor(id, dtype=torch.long) | |
} | |
class CustomDataModule(pl.LightningDataModule): | |
def __init__(self, args, tokenizer): | |
super().__init__() | |
self.args = args | |
self.tokenizer = tokenizer | |
self.max_len = self.args.max_seq_length | |
self.train_dataset = None | |
self.val_dataset = None | |
def setup(self, stage): | |
data_path = "./dataset" | |
assert os.path.exists(os.path.join(data_path, 'train.json')) | |
assert os.path.exists(os.path.join(data_path, 'dev.json')) | |
assert os.path.exists(os.path.join(data_path, 'test_public.json')) | |
if stage == 'fit': | |
self.train_dataset = CustomDataset('train.json', self.tokenizer, self.max_len) | |
self.val_dataset = CustomDataset('dev.json', self.tokenizer, self.max_len) | |
self.test_dataset = CustomDataset('test_public.json', self.tokenizer, self.max_len) | |
elif stage == 'test': | |
self.test_dataset = CustomDataset('test_public.json', self.tokenizer, self.max_len) | |
def train_dataloader(self): | |
full_dataset = ConcatDataset([self.train_dataset, self.val_dataset]) | |
train_dataloader = DataLoader( | |
full_dataset, | |
batch_size=self.args.batch_size, | |
num_workers=4, | |
shuffle=True, | |
pin_memory=True, | |
drop_last=True) | |
return train_dataloader | |
def val_dataloader(self): | |
val_dataloader = DataLoader( | |
self.test_dataset, | |
batch_size=self.args.val_batch_size, | |
num_workers=4, | |
shuffle=False, | |
pin_memory=True, | |
drop_last=False) | |
return val_dataloader | |
def test_dataloader(self): | |
test_dataloader = DataLoader( | |
self.test_dataset, | |
batch_size=self.args.val_batch_size, | |
num_workers=4, | |
shuffle=False, | |
pin_memory=True, | |
drop_last=False) | |
return test_dataloader | |
class CustomModel(pl.LightningModule): | |
def __init__(self, args): | |
super().__init__() | |
self.args = args | |
self.model = self.args.model_name | |
self.cache_dir = self.args.model_path | |
self.scheduler = self.args.scheduler | |
self.step_scheduler_after = "batch" | |
self.optimizer = self.args.optimizer | |
self.pooler = self.args.use_original_pooler | |
self.category = self.args.cate_performance | |
self.loss_func = self.args.loss_function | |
hidden_dropout_prob: float = 0.1 | |
layer_norm_eps: float = 1e-7 | |
config = AutoConfig.from_pretrained(self.model, cache_dir=self.cache_dir) | |
config.update( | |
{ | |
"output_hidden_states": False, | |
"hidden_dropout_prob": hidden_dropout_prob, | |
"layer_norm_eps": layer_norm_eps, | |
} | |
) | |
self.transformer = AutoModel.from_pretrained(self.model, config=config, cache_dir=self.cache_dir) | |
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |
self.linear = torch.nn.Linear(config.hidden_size, self.args.num_labels, bias=True) # 分三类 | |
def configure_optimizers(self): | |
"""Prepare optimizer and schedule""" | |
model = self.transformer | |
no_decay = ["bias", "LayerNorm.weight"] | |
optimizer_grouped_parameters = [ | |
{ | |
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], | |
"weight_decay": 0.01, | |
}, | |
{ | |
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], | |
"weight_decay": 0.0, | |
}, | |
] | |
optimizer_index = ['Adam', 'AdamW'].index(self.optimizer) | |
optimizer = [ | |
torch.optim.Adam(optimizer_grouped_parameters, lr=self.args.learning_rate), | |
torch.optim.AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate)][optimizer_index] | |
scheduler_index = ['StepLR', 'CosineWarmup', 'CosineAnnealingLR'].index(self.scheduler) | |
scheduler = [ | |
torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.args.warmup_step, | |
gamma=self.args.warmup_proportion), | |
get_cosine_schedule_with_warmup( | |
optimizer, | |
num_warmup_steps=int(self.args.warmup_proportion * self.total_steps), | |
num_training_steps=self.total_steps, | |
), | |
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=2e-06)][scheduler_index] | |
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} | |
return [optimizer], [scheduler] | |
def setup(self, stage=None): | |
if stage != "fit": | |
return | |
# calculate total steps | |
train_dataloader = self.trainer.datamodule.train_dataloader() | |
gpus = 0 if self.trainer.gpus is None else self.trainer.gpus | |
tb_size = self.args.batch_size * max(1, gpus) | |
ab_size = self.trainer.accumulate_grad_batches * float(self.trainer.max_epochs) | |
self.total_steps = (len(train_dataloader.dataset) // tb_size) // ab_size | |
def loss(self, outputs, targets): | |
lossf_index = ['CE', 'Focal', 'LSCE_correction'].index(self.loss_func) | |
loss_fct = [nn.CrossEntropyLoss(), FocalLoss(), LabelSmoothingCorrectionCrossEntropy()][lossf_index] | |
loss = loss_fct(outputs, targets) | |
return loss | |
def category_performance_measure(self, labels_right, labels_pred, num_label=3): | |
text_labels = [i for i in range(num_label)] | |
TP = dict.fromkeys(text_labels, 0) # 预测正确的各个类的数目 | |
TP_FP = dict.fromkeys(text_labels, 0) # 测试数据集中各个类的数目 | |
TP_FN = dict.fromkeys(text_labels, 0) # 预测结果中各个类的数目 | |
label_dict = defaultdict(list) | |
for num in range(num_label): | |
label_dict[num].append(str(num)) | |
# 计算TP等数量 | |
for i in range(0, len(labels_right)): | |
TP_FP[labels_right[i]] += 1 | |
TP_FN[labels_pred[i]] += 1 | |
if labels_right[i] == labels_pred[i]: | |
TP[labels_right[i]] += 1 | |
# 计算准确率P,召回率R,F1值 | |
results = [] | |
for key in TP_FP: | |
P = float(TP[key]) / float(TP_FP[key] + 1e-9) | |
R = float(TP[key]) / float(TP_FN[key] + 1e-9) | |
F1 = P * R * 2 / (P + R) if (P + R) != 0 else 0 | |
# results.append("%s:\t P:%f\t R:%f\t F1:%f" % (key, P, R, F1)) | |
results.append(F1) | |
return results | |
def monitor_metrics(self, outputs, targets): | |
pred = torch.argmax(outputs, dim=1).cpu().numpy().tolist() | |
targets = targets.int().cpu().numpy().tolist() | |
if self.category: | |
category_results = self.category_performance_measure( | |
labels_right=targets, | |
labels_pred=pred, | |
num_label=self.args.num_labels | |
) | |
return {"f1": category_results} | |
else: | |
f1_score = metrics.f1_score(targets, pred, average="macro") | |
return {"f1": f1_score} | |
def forward(self, ids, mask, token_type_ids, labels): | |
transformer_out = self.transformer(input_ids=ids, attention_mask=mask, token_type_ids=token_type_ids) | |
if self.pooler: | |
pooler_output = transformer_out.pooler_output | |
else: | |
sequence_output = transformer_out.last_hidden_state | |
pooler_output = torch.mean(sequence_output, dim=1) | |
logits = self.linear(self.dropout(pooler_output)) | |
labels_hat = torch.argmax(logits, dim=1) | |
correct_count = torch.sum(labels == labels_hat) | |
return logits, correct_count | |
def predict(self, ids, mask, token_type_ids): | |
transformer_out = self.transformer(input_ids=ids, attention_mask=mask, token_type_ids=token_type_ids) | |
pooler_output = transformer_out.pooler_output | |
logits = self.linear(self.dropout(pooler_output)) | |
logits = torch.argmax(logits, dim=1) | |
return logits | |
def training_step(self, batch, batch_idx): | |
ids, mask, token_type_ids, labels = batch['ids'], batch['mask'], batch['token_type_ids'], batch['targets'] | |
logits, correct_count = self.forward(ids, mask, token_type_ids, labels) | |
loss = self.loss(logits, labels.long()) | |
f1 = self.monitor_metrics(logits, labels)["f1"] | |
self.log("train_loss", loss, logger=True, prog_bar=True) | |
self.log('train_acc', correct_count.float() / len(labels), logger=True, prog_bar=True) | |
if self.category: | |
self.log("train_f1_key0", f1[0], logger=True, prog_bar=True) | |
self.log("train_f1_key1", f1[1], logger=True, prog_bar=True) | |
self.log("train_f1_key2", f1[2], logger=True, prog_bar=True) | |
else: | |
self.log("train_f1", f1, logger=True, prog_bar=True) | |
return loss | |
def validation_step(self, batch, batch_idx): | |
ids, mask, token_type_ids, labels = batch['ids'], batch['mask'], batch['token_type_ids'], batch['targets'] | |
logits, correct_count = self.forward(ids, mask, token_type_ids, labels) | |
loss = self.loss(logits, labels.long()) | |
f1 = self.monitor_metrics(logits, labels)["f1"] | |
self.log("val_loss", loss, logger=True, prog_bar=True) | |
self.log("val_acc", correct_count.float() / len(labels), logger=True, prog_bar=True) | |
if self.category: | |
self.log("val_f1_key0", f1[0], logger=True, prog_bar=True) | |
self.log("val_f1_key1", f1[1], logger=True, prog_bar=True) | |
self.log("val_f1_key2", f1[2], logger=True, prog_bar=True) | |
else: | |
self.log("val_f1", f1, logger=True, prog_bar=True) | |
def test_step(self, batch, batch_idx): | |
ids, mask, token_type_ids, labels = batch['ids'], batch['mask'], batch['token_type_ids'], batch['targets'] | |
logits, correct_count = self.forward(ids, mask, token_type_ids, labels) | |
loss = self.loss(logits, labels.long()) | |
f1 = self.monitor_metrics(logits, labels)["f1"] | |
self.log("test_loss", loss, logger=True, prog_bar=True) | |
self.log("test_acc", correct_count.float() / len(labels), logger=True, prog_bar=True) | |
if self.category: | |
self.log("test_f1_key0", f1[0], logger=True, prog_bar=True) | |
self.log("test_f1_key1", f1[1], logger=True, prog_bar=True) | |
self.log("test_f1_key2", f1[2], logger=True, prog_bar=True) | |
else: | |
self.log("test_f1", f1, logger=True, prog_bar=True) | |
return {"test_loss": loss, "logits": logits, "labels": labels} | |
def predict_step(self, batch, batch_idx, dataloader_idx): | |
ids, mask, token_type_ids, id = batch['ids'], batch['mask'], batch['token_type_ids'], batch['id'] | |
logits = self.predict(ids, mask, token_type_ids) | |
return {'id': id.cpu().numpy().tolist(), 'logits': logits.cpu().numpy().tolist()} | |