import os
import math
import json
import argparse

import torch
import torch.distributed as dist
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.strategies.ddp import DDPStrategy
import transformers
from transformers import get_scheduler
import evaluate
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2

from dataset import NERDataset, get_collate_fn
from model import build_model
from utils import get_class_to_index


def get_args(notebook=False):
    parser = argparse.ArgumentParser()
    parser.add_argument('--do_train', action='store_true')
    parser.add_argument('--do_valid', action='store_true')
    parser.add_argument('--do_test', action='store_true')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpus', type=int, default=1)
    parser.add_argument('--print_freq', type=int, default=200)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--no_eval', action='store_true')
    # Data
    parser.add_argument('--data_path', type=str, default=None)
    parser.add_argument('--image_path', type=str, default=None)
    parser.add_argument('--train_file', type=str, default=None)
    parser.add_argument('--valid_file', type=str, default=None)
    parser.add_argument('--test_file', type=str, default=None)
    parser.add_argument('--vocab_file', type=str, default=None)
    parser.add_argument('--format', type=str, default='reaction')
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--input_size', type=int, default=224)
    # Training
    parser.add_argument('--epochs', type=int, default=8)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--max_grad_norm', type=float, default=5.)
    parser.add_argument('--scheduler', type=str, choices=['cosine', 'constant'], default='cosine')
    parser.add_argument('--warmup_ratio', type=float, default=0)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument('--load_path', type=str, default=None)
    parser.add_argument('--load_encoder_only', action='store_true')
    parser.add_argument('--train_steps_per_epoch', type=int, default=-1)
    parser.add_argument('--eval_per_epoch', type=int, default=10)
    parser.add_argument('--save_path', type=str, default='output/')
    parser.add_argument('--save_mode', type=str, default='best', choices=['best', 'all', 'last'])
    parser.add_argument('--load_ckpt', type=str, default='best')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--num_train_example', type=int, default=None)
    parser.add_argument('--roberta_checkpoint', type=str, default='roberta-base')
    parser.add_argument('--corpus', type=str, default='chemu')
    parser.add_argument('--cache_dir', type=str, default=None)
    parser.add_argument('--eval_truncated', action='store_true')
    parser.add_argument('--max_seq_length', type=int, default=512)
    args = parser.parse_args([]) if notebook else parser.parse_args()
    return args
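
# Example invocation (assuming this file is saved as train.py; the data file
# names below are placeholders, not files shipped with this repo):
#   python train.py --do_train --do_valid --corpus chemu \
#       --train_file train.json --valid_file dev.json \
#       --save_path output/chemu --gpus 2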


class ChemIENERecognizer(LightningModule):

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model = build_model(args)
        self.validation_step_outputs = []

    def training_step(self, batch, batch_idx):
        sentences, masks, refs, _ = batch
        loss, logits = self.model(input_ids=sentences, attention_mask=masks, labels=refs)
        self.log('train/loss', loss)
        # get_last_lr() returns the lr most recently computed by the scheduler;
        # get_lr() is for internal use and warns when called directly.
        self.log('lr', self.lr_schedulers().get_last_lr()[0], prog_bar=True, logger=False)
        return loss

    def validation_step(self, batch, batch_idx):
        sentences, masks, refs, untruncated = batch
        logits = self.model(input_ids=sentences, attention_mask=masks)[0]
        # Move everything to CPU so the outputs can be gathered and scored later.
        self.validation_step_outputs.append(
            (sentences.to('cpu'), logits.argmax(dim=2).to('cpu'), refs.to('cpu'), untruncated.to('cpu')))

    def on_validation_epoch_end(self):
        # Gather per-rank outputs so that rank 0 can score the full validation set.
        if self.trainer.num_devices > 1:
            gathered_outputs = [None for _ in range(self.trainer.num_devices)]
            dist.all_gather_object(gathered_outputs, self.validation_step_outputs)
            gathered_outputs = sum(gathered_outputs, [])
        else:
            gathered_outputs = self.validation_step_outputs
        sentences = [list(output[0]) for output in gathered_outputs]
        class_to_index = get_class_to_index(self.args.corpus)
        index_to_class = {class_to_index[key]: key for key in class_to_index}
        predictions = [list(output[1]) for output in gathered_outputs]
        labels = [list(output[2]) for output in gathered_outputs]
        untruncateds = [list(output[3]) for output in gathered_outputs]
        # A label of -100 marks special/padding tokens; they are excluded everywhere.
        untruncateds = [[index_to_class[int(label.item())] for label in sentence if int(label.item()) != -100]
                        for batched in untruncateds for sentence in batched]
        output = {
            'sentences': [[int(word.item()) for (word, label) in zip(sentence_w, sentence_l) if label != -100]
                          for (batched_w, batched_l) in zip(sentences, labels)
                          for (sentence_w, sentence_l) in zip(batched_w, batched_l)],
            'predictions': [[index_to_class[int(pred.item())] for (pred, label) in zip(sentence_p, sentence_l)
                             if label != -100]
                            for (batched_p, batched_l) in zip(predictions, labels)
                            for (sentence_p, sentence_l) in zip(batched_p, batched_l)],
            'groundtruth': [[index_to_class[int(label.item())] for label in sentence if label != -100]
                            for batched in labels for sentence in batched],
        }
        name = self.eval_dataset.name
        scores = [0]
        if self.trainer.is_global_zero:
            if not self.args.no_eval:
                metric = evaluate.load('seqeval', cache_dir=self.args.cache_dir)
                # Model inputs are truncated to max_seq_length, so predictions are
                # padded with 'O' up to the length of the untruncated references.
                predictions = [preds + ['O'] * (len(full_groundtruth) - len(preds))
                               for (preds, full_groundtruth) in zip(output['predictions'], untruncateds)]
                # Computed for completeness; only the report below is actually used.
                all_metrics = metric.compute(predictions=predictions, references=untruncateds)
                # seqeval's classification_report takes (y_true, y_pred):
                # references first, predictions second.
                if self.args.eval_truncated:
                    report = classification_report(output['groundtruth'], output['predictions'],
                                                   mode='strict', scheme=IOB2, output_dict=True)
                else:
                    report = classification_report(untruncateds, predictions,
                                                   mode='strict', scheme=IOB2, output_dict=True)
                self.print(report)
                scores = [report['micro avg']['f1-score']]
            with open(os.path.join(self.trainer.default_root_dir, f'prediction_{name}.json'), 'w') as f:
                json.dump(output, f)
        if self.trainer.num_devices > 1:
            # Every rank must enter the broadcast; only rank 0 holds the real score.
            dist.broadcast_object_list(scores)
        self.log('val/score', scores[0], prog_bar=True, rank_zero_only=True)
        self.validation_step_outputs.clear()

    # Reuse the validation logic for the test split; main() points eval_dataset
    # at the test dataset before calling trainer.test().
    test_step = validation_step
    on_test_epoch_end = on_validation_epoch_end

    def configure_optimizers(self):
        num_training_steps = self.trainer.num_training_steps
        self.print(f'Num training steps: {num_training_steps}')
        num_warmup_steps = int(num_training_steps * self.args.warmup_ratio)
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        scheduler = get_scheduler(self.args.scheduler, optimizer, num_warmup_steps, num_training_steps)
        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
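
# Quick sanity check for the strict IOB2 scoring used above. Values are
# illustrative only; the predicted entity is truncated, so strict matching
# counts it as a miss:
#   >>> from seqeval.metrics import classification_report
#   >>> from seqeval.scheme import IOB2
#   >>> y_true = [['B-EXAMPLE', 'I-EXAMPLE', 'O']]
#   >>> y_pred = [['B-EXAMPLE', 'O', 'O']]
#   >>> classification_report(y_true, y_pred, mode='strict', scheme=IOB2,
#   ...                       output_dict=True)['micro avg']['f1-score']
#   0.0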


class NERDataModule(LightningDataModule):

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.collate_fn = get_collate_fn()

    def prepare_data(self):
        # Called explicitly in main() so that every DDP process builds its datasets.
        args = self.args
        if args.do_train:
            self.train_dataset = NERDataset(args, args.train_file, split='train')
        if args.do_train or args.do_valid:
            self.val_dataset = NERDataset(args, args.valid_file, split='valid')
        if args.do_test:
            self.test_dataset = NERDataset(args, args.test_file, split='valid')

    def print_stats(self):
        if self.args.do_train:
            print(f'Train dataset: {len(self.train_dataset)}')
        if self.args.do_train or self.args.do_valid:
            print(f'Valid dataset: {len(self.val_dataset)}')
        if self.args.do_test:
            print(f'Test dataset: {len(self.test_dataset)}')

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset, batch_size=self.args.batch_size,
            num_workers=self.args.num_workers, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset, batch_size=self.args.batch_size,
            num_workers=self.args.num_workers, collate_fn=self.collate_fn)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset, batch_size=self.args.batch_size,
            num_workers=self.args.num_workers, collate_fn=self.collate_fn)
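
# Standalone usage sketch for the datamodule (file names below are placeholders,
# not files shipped with this repo):
#   args = get_args(notebook=True)
#   args.do_train = True
#   args.train_file, args.valid_file = 'train.json', 'dev.json'
#   dm = NERDataModule(args)
#   dm.prepare_data()
#   dm.print_stats()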


class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    # Override so the checkpoint keeps its fixed filename ('best') instead of
    # interpolating the monitored metric into it.
    def _get_metric_interpolated_filepath_name(self, monitor_candidates, trainer, del_filepath=None) -> str:
        return self.format_checkpoint_name(monitor_candidates)


def main():
    transformers.utils.logging.set_verbosity_error()
    args = get_args()
    pl.seed_everything(args.seed, workers=True)

    if args.do_train:
        model = ChemIENERecognizer(args)
    else:
        model = ChemIENERecognizer.load_from_checkpoint(
            os.path.join(args.save_path, 'checkpoints/best.ckpt'), strict=False, args=args)

    dm = NERDataModule(args)
    dm.prepare_data()
    dm.print_stats()

    checkpoint = ModelCheckpoint(monitor='val/score', mode='max', save_top_k=1, filename='best', save_last=True)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = pl.loggers.TensorBoardLogger(args.save_path, name='', version='')

    trainer = pl.Trainer(
        strategy=DDPStrategy(find_unused_parameters=False),
        accelerator='gpu',
        precision=16,
        devices=args.gpus,
        logger=logger,
        default_root_dir=args.save_path,
        callbacks=[checkpoint, lr_monitor],
        max_epochs=args.epochs,
        gradient_clip_val=args.max_grad_norm,
        accumulate_grad_batches=args.gradient_accumulation_steps,
        check_val_every_n_epoch=args.eval_per_epoch,
        log_every_n_steps=10,
        deterministic='warn')

    if args.do_train:
        # Optimizer steps per epoch times epochs; consumed by configure_optimizers.
        trainer.num_training_steps = math.ceil(
            len(dm.train_dataset) / (args.batch_size * args.gpus * args.gradient_accumulation_steps)) * args.epochs
        model.eval_dataset = dm.val_dataset
        ckpt_path = os.path.join(args.save_path, 'checkpoints/last.ckpt') if args.resume else None
        trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)
        model = ChemIENERecognizer.load_from_checkpoint(checkpoint.best_model_path, args=args)

    if args.do_valid:
        model.eval_dataset = dm.val_dataset
        trainer.validate(model, datamodule=dm)

    if args.do_test:
        # The test path reuses the validation logic, so it also reads eval_dataset.
        model.eval_dataset = dm.test_dataset
        trainer.test(model, datamodule=dm)


if __name__ == "__main__":
    main()
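
# Artifacts written under --save_path by a full run:
#   checkpoints/best.ckpt and checkpoints/last.ckpt  (ModelCheckpoint above)
#   prediction_<name>.json per evaluated dataset     (on_validation_epoch_end)
#   TensorBoard event files                          (TensorBoardLogger)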