import torch import os import argparse import json import pytorch_lightning as pl from fengshen.models.model_utils import add_module_args from fengshen.data.task_dataloader.task_datasets import AbstractCollator from fengshen.data.universal_datamodule import UniversalDataModule from fengshen.utils.universal_checkpoint import UniversalCheckpoint from fengshen.utils.utils import chinese_char_tokenize from torchmetrics.text.rouge import ROUGEScore from pytorch_lightning import Trainer, loggers from pytorch_lightning.callbacks import LearningRateMonitor from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import sys sys.path.append('../../../') # os.environ["CUDA_VISIBLE_DEVICES"] = '3,4' class FinetuneSummary(pl.LightningModule): @staticmethod def add_model_specific_args(parent_args): parser = parent_args.add_argument_group('BaseModel') parser.add_argument('--rouge_keys', default='rougeL,rouge1,rouge2', type=str) return parent_args def __init__(self, args, tokenizer=None): super().__init__() self.save_hyperparameters(args) self.model = AutoModelForSeq2SeqLM.from_pretrained( args.pretrained_model_path) self.tokenizer = tokenizer assert self.tokenizer, "tokenizer is None!" self.rouge_keys = tuple(args.rouge_keys.split(',')) self.rouge_metric = ROUGEScore(rouge_keys=self.rouge_keys, normalizer=lambda x: x) def setup(self, stage) -> None: if stage == 'fit': train_loader = self.trainer._data_connector._train_dataloader_source.dataloader() # Calculate total steps tb_size = self.hparams.train_batchsize * max(1, self.trainer.gpus) ab_size = self.trainer.accumulate_grad_batches * \ float(self.trainer.max_epochs) self.total_steps = ( len(train_loader.dataset) // tb_size) // ab_size print('total_steps is :', self.total_steps) def training_step(self, batch, batch_idx): output = self.model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels']) self.log('train_loss', output.loss, sync_dist=True) return output.loss def on_validation_start(self) -> None: # rm file at validation start prefix, ext = os.path.splitext(self.hparams.output_save_path) file_path_rank = '{}_{}{}'.format( prefix, self.trainer._accelerator_connector.cluster_environment.global_rank(), ext) if os.path.exists(file_path_rank): print('rm {}'.format(file_path_rank)) os.remove(file_path_rank) def validation_step(self, batch, batch_idx): output = self.model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels']) generated_ids = self.model.generate( input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], max_length=self.hparams.max_dec_length ) preds = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) labels = torch.where(batch['labels'] != -100, batch['labels'], self.tokenizer.pad_token_id) labels = self.tokenizer.batch_decode( labels, skip_special_tokens=True, clean_up_tokenization_spaces=True) # save preds for every rank prefix, ext = os.path.splitext(self.hparams.output_save_path) file_path_rank = '{}_{}{}'.format( prefix, self.trainer._accelerator_connector.cluster_environment.global_rank(), ext) self.save_prediction_to_file(preds=preds, texts=batch['text'], summarys=batch['summary'], file_path=file_path_rank) # you need to split chinese char with space for rouge metric new_preds = [chinese_char_tokenize(p) for p in preds] new_labels = [chinese_char_tokenize(label) for label in labels] # update metric self.rouge_metric.update(preds=new_preds, target=new_labels) self.log('val_loss', output.loss, sync_dist=True) def validation_epoch_end(self, outputs): # compute metric for all process rouge_dict = self.rouge_metric.compute() # reset the metric after once validation self.rouge_metric.reset() for k, v in rouge_dict.items(): self.log('val_{}'.format(k), v, sync_dist=True) if self.trainer._accelerator_connector.cluster_environment.global_rank() == 0: print('rouge:\n', rouge_dict) def on_save_checkpoint(self, checkpoint) -> None: if self.trainer._accelerator_connector.cluster_environment.global_rank() == 0: self.model.save_pretrained(os.path.join( self.trainer.checkpoint_callback.dirpath, 'hf_pretrained_epoch{}_step{}'.format(checkpoint['epoch'], checkpoint['global_step']))) def save_prediction_to_file(self, preds, texts, summarys, file_path): with open(file_path, 'a', encoding='utf-8') as f: for idx, pred in enumerate(preds): text = texts[idx] summary = summarys[idx] tmp_result = dict() tmp_result['pred'] = pred tmp_result['label'] = summary tmp_result['text'] = text json_data = json.dumps(tmp_result, ensure_ascii=False) f.write(json_data + '\n') def predict_step(self, batch, batch_idx): # print(batch) texts = batch['text'] # output summary and metrics generated_ids = self.model.generate( input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], max_length=self.hparams.max_dec_length ) preds = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) labels = self.tokenizer.batch_decode( batch['labels'], skip_special_tokens=True, clean_up_tokenization_spaces=True) print(batch_idx, len(preds), len(labels)) self.save_prediction_to_file(preds, texts, labels) def configure_optimizers(self): from fengshen.models.model_utils import configure_optimizers return configure_optimizers(self) def main(): total_parser = argparse.ArgumentParser("Summary Task") total_parser.add_argument('--do_eval_only', action='store_true', default=False) total_parser.add_argument('--pretrained_model_path', default='google/mt5-small', type=str) total_parser.add_argument('--output_save_path', default='./predict.json', type=str) total_parser.add_argument('--self_tokenizer', action='store_true', default=False) total_parser.add_argument('--max_enc_length', default=1024, type=int) total_parser.add_argument('--max_dec_length', default=256, type=int) total_parser.add_argument('--prompt', default='summarize:', type=str) # * Args for data preprocessing # from fengshen.data.task_dataloader.task_datasets import LCSTSDataModel total_parser = UniversalDataModule.add_data_specific_args(total_parser) # * Args for training total_parser = add_module_args(total_parser) total_parser = Trainer.add_argparse_args(total_parser) total_parser = UniversalCheckpoint.add_argparse_args(total_parser) total_parser = FinetuneSummary.add_model_specific_args(total_parser) # * Args for base model args = total_parser.parse_args() if args.self_tokenizer: from fengshen.examples.pegasus.tokenizers_pegasus import PegasusTokenizer tokenizer = PegasusTokenizer.from_pretrained(args.pretrained_model_path) else: tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, use_fast=False) collator = AbstractCollator(tokenizer, args.max_enc_length, args.max_dec_length, args.prompt) data_model = UniversalDataModule(tokenizer=tokenizer, args=args, collate_fn=collator) model = FinetuneSummary(args, tokenizer) if not args.do_eval_only: lr_monitor = LearningRateMonitor(logging_interval='step') logger = loggers.TensorBoardLogger(save_dir=os.path.join( args.default_root_dir, 'log/')) checkpoint_callback = UniversalCheckpoint(args) trainer = Trainer.from_argparse_args(args, logger=logger, callbacks=[lr_monitor, checkpoint_callback] ) trainer.fit(model, data_model) else: trainer = Trainer.from_argparse_args(args) # trainer.predict(model, data_model) trainer.validate(model, data_model) if __name__ == '__main__': main()