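"""Fine-tune a seq2seq model (e.g. mT5) for summarization with pytorch_lightning,
using the fengshen data and checkpoint utilities.

Example launch (a minimal sketch: the script name and paths are placeholders;
--gpus, --max_epochs and --default_root_dir come from the pytorch_lightning
Trainer argparse integration, the rest from the argparse setup in main() below):

    python seq2seq_summary.py \
        --pretrained_model_path google/mt5-small \
        --max_enc_length 1024 --max_dec_length 256 \
        --gpus 1 --max_epochs 3 --default_root_dir ./exp
"""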
# make the local fengshen package importable before the fengshen imports below
import sys
sys.path.append('../../../')

import torch
import os
import argparse
import json
import pytorch_lightning as pl
from fengshen.models.model_utils import add_module_args
from fengshen.data.task_dataloader.task_datasets import AbstractCollator
from fengshen.data.universal_datamodule import UniversalDataModule
from fengshen.utils.universal_checkpoint import UniversalCheckpoint
from fengshen.utils.utils import chinese_char_tokenize
from torchmetrics.text.rouge import ROUGEScore
from pytorch_lightning import Trainer, loggers
from pytorch_lightning.callbacks import LearningRateMonitor
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# os.environ["CUDA_VISIBLE_DEVICES"] = '3,4'


class FinetuneSummary(pl.LightningModule):

    @staticmethod
    def add_model_specific_args(parent_args):
        parser = parent_args.add_argument_group('BaseModel')
        parser.add_argument('--rouge_keys', default='rougeL,rouge1,rouge2', type=str)
        return parent_args

    def __init__(self, args, tokenizer=None):
        super().__init__()
        self.save_hyperparameters(args)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            args.pretrained_model_path)
        self.tokenizer = tokenizer
        assert self.tokenizer, "tokenizer is None!"
        self.rouge_keys = tuple(args.rouge_keys.split(','))
        # identity normalizer: the inputs are already space-separated tokens
        self.rouge_metric = ROUGEScore(rouge_keys=self.rouge_keys, normalizer=lambda x: x)

    def setup(self, stage) -> None:
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()
            # Calculate total optimization steps for the LR scheduler:
            # steps per epoch (dataset / effective batch / accumulation) x epochs
            tb_size = self.hparams.train_batchsize * max(1, self.trainer.gpus)
            ab_size = self.trainer.accumulate_grad_batches
            steps_per_epoch = (len(train_loader.dataset) // tb_size) // ab_size
            self.total_steps = steps_per_epoch * self.trainer.max_epochs
            print('total_steps is :', self.total_steps)
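            # Worked example (illustrative numbers): 10000 training samples,
            # train_batchsize=8 on 2 GPUs, accumulate_grad_batches=4, 3 epochs
            # -> steps_per_epoch = (10000 // 16) // 4 = 156, total_steps = 468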

    def training_step(self, batch, batch_idx):
        output = self.model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'], labels=batch['labels'])
        self.log('train_loss', output.loss, sync_dist=True)
        return output.loss

    def on_validation_start(self) -> None:
        # remove any stale per-rank prediction file before validation starts
        prefix, ext = os.path.splitext(self.hparams.output_save_path)
        file_path_rank = '{}_{}{}'.format(
            prefix, self.trainer._accelerator_connector.cluster_environment.global_rank(), ext)
        if os.path.exists(file_path_rank):
            print('rm {}'.format(file_path_rank))
            os.remove(file_path_rank)

    def validation_step(self, batch, batch_idx):
        output = self.model(input_ids=batch['input_ids'],
                            attention_mask=batch['attention_mask'], labels=batch['labels'])
        generated_ids = self.model.generate(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            max_length=self.hparams.max_dec_length
        )
        preds = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        # replace the -100 label padding with the pad token before decoding
        labels = torch.where(batch['labels'] != -100, batch['labels'],
                             self.tokenizer.pad_token_id)
        labels = self.tokenizer.batch_decode(
            labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        # save predictions separately for every rank
        prefix, ext = os.path.splitext(self.hparams.output_save_path)
        file_path_rank = '{}_{}{}'.format(
            prefix, self.trainer._accelerator_connector.cluster_environment.global_rank(), ext)
        self.save_prediction_to_file(preds=preds, texts=batch['text'],
                                     summarys=batch['summary'], file_path=file_path_rank)
        # Chinese characters must be space-separated for the ROUGE metric
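        # e.g. chinese_char_tokenize('天气真好') -> '天 气 真 好' (illustrative;
        # see fengshen.utils.utils for the exact behavior on mixed text)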
        new_preds = [chinese_char_tokenize(p) for p in preds]
        new_labels = [chinese_char_tokenize(label) for label in labels]
        # update the metric
        self.rouge_metric.update(preds=new_preds, target=new_labels)
        self.log('val_loss', output.loss, sync_dist=True)

    def validation_epoch_end(self, outputs):
        # aggregate the metric across all processes
        rouge_dict = self.rouge_metric.compute()
        # reset the metric after each validation run
        self.rouge_metric.reset()
        for k, v in rouge_dict.items():
            self.log('val_{}'.format(k), v, sync_dist=True)
        if self.trainer._accelerator_connector.cluster_environment.global_rank() == 0:
            print('rouge:\n', rouge_dict)

    def on_save_checkpoint(self, checkpoint) -> None:
        # also export the model in HF format (rank 0 only) so it can be
        # reloaded later with AutoModelForSeq2SeqLM.from_pretrained
        if self.trainer._accelerator_connector.cluster_environment.global_rank() == 0:
            self.model.save_pretrained(os.path.join(
                self.trainer.checkpoint_callback.dirpath,
                'hf_pretrained_epoch{}_step{}'.format(checkpoint['epoch'], checkpoint['global_step'])))

    def save_prediction_to_file(self, preds, texts, summarys, file_path):
        # append one JSON object per example: {"pred": ..., "label": ..., "text": ...}
        with open(file_path, 'a', encoding='utf-8') as f:
            for idx, pred in enumerate(preds):
                tmp_result = {
                    'pred': pred,
                    'label': summarys[idx],
                    'text': texts[idx],
                }
                f.write(json.dumps(tmp_result, ensure_ascii=False) + '\n')

    def predict_step(self, batch, batch_idx):
        texts = batch['text']
        # output summary and metrics
        generated_ids = self.model.generate(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            max_length=self.hparams.max_dec_length
        )
        preds = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        # replace the -100 label padding before decoding, as in validation_step
        labels = torch.where(batch['labels'] != -100, batch['labels'],
                             self.tokenizer.pad_token_id)
        labels = self.tokenizer.batch_decode(
            labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        print(batch_idx, len(preds), len(labels))
        # save_prediction_to_file requires a target path; write to the configured output file
        self.save_prediction_to_file(preds, texts, labels,
                                     file_path=self.hparams.output_save_path)

    def configure_optimizers(self):
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)
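

# Illustrative sketch, not used by the training flow above: reloading the
# HF-format directory exported by on_save_checkpoint for standalone
# generation. on_save_checkpoint only saves the model weights, so the
# tokenizer comes from the original pretrained model; both paths and the
# prompt handling are placeholder assumptions.
def generate_summary(text, ckpt_dir='./ckpt/hf_pretrained_epoch0_step100',
                     tokenizer_path='google/mt5-small', prompt='summarize:',
                     max_dec_length=256):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_dir)
    # the training collator is assumed to prepend args.prompt, so mirror that here
    inputs = tokenizer(prompt + text, return_tensors='pt',
                       truncation=True, max_length=1024)
    generated_ids = model.generate(input_ids=inputs['input_ids'],
                                   attention_mask=inputs['attention_mask'],
                                   max_length=max_dec_length)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True,
                                  clean_up_tokenization_spaces=True)[0]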


def main():
    total_parser = argparse.ArgumentParser("Summary Task")
    total_parser.add_argument('--do_eval_only',
                              action='store_true',
                              default=False)
    total_parser.add_argument('--pretrained_model_path',
                              default='google/mt5-small',
                              type=str)
    total_parser.add_argument('--output_save_path',
                              default='./predict.json',
                              type=str)
    total_parser.add_argument('--self_tokenizer',
                              action='store_true',
                              default=False)
    total_parser.add_argument('--max_enc_length', default=1024, type=int)
    total_parser.add_argument('--max_dec_length', default=256, type=int)
    total_parser.add_argument('--prompt', default='summarize:', type=str)
    # * Args for data preprocessing
    total_parser = UniversalDataModule.add_data_specific_args(total_parser)
    # * Args for training
    total_parser = add_module_args(total_parser)
    total_parser = Trainer.add_argparse_args(total_parser)
    total_parser = UniversalCheckpoint.add_argparse_args(total_parser)
    # * Args for base model
    total_parser = FinetuneSummary.add_model_specific_args(total_parser)
    args = total_parser.parse_args()

    if args.self_tokenizer:
        from fengshen.examples.pegasus.tokenizers_pegasus import PegasusTokenizer
        tokenizer = PegasusTokenizer.from_pretrained(args.pretrained_model_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, use_fast=False)
    collator = AbstractCollator(tokenizer, args.max_enc_length,
                                args.max_dec_length, args.prompt)
    data_model = UniversalDataModule(tokenizer=tokenizer, args=args, collate_fn=collator)
    model = FinetuneSummary(args, tokenizer)

    if not args.do_eval_only:
        lr_monitor = LearningRateMonitor(logging_interval='step')
        logger = loggers.TensorBoardLogger(save_dir=os.path.join(
            args.default_root_dir, 'log/'))
        checkpoint_callback = UniversalCheckpoint(args)
        trainer = Trainer.from_argparse_args(args,
                                             logger=logger,
                                             callbacks=[lr_monitor,
                                                        checkpoint_callback])
        trainer.fit(model, data_model)
    else:
        trainer = Trainer.from_argparse_args(args)
        # trainer.predict(model, data_model)
        trainer.validate(model, data_model)


if __name__ == '__main__':
    main()