Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| from transformers import AutoTokenizer, BartForConditionalGeneration, BartConfig | |
| from pytorch_lightning import ( | |
| LightningModule, | |
| Trainer, | |
| ) | |
| from pytorch_lightning.callbacks import LearningRateMonitor | |
| from dataclasses import dataclass | |
| import os | |
| import argparse | |
| import torch | |
| import math | |
| import time | |
| from torch.utils.data._utils.collate import default_collate | |
| from fengshen.data.data_utils.mask_utils import create_masked_lm_predictions | |
| from fengshen.data.universal_datamodule import UniversalDataModule | |
| from fengshen.utils import UniversalCheckpoint | |
| from fengshen.models.model_utils import ( | |
| get_total_steps, | |
| configure_optimizers, | |
| add_module_args, | |
| ) | |
| import numpy as np | |
# One-shot debug flag: RandengBart.training_step prints the first batch on
# global rank 0 and then sets this to True so the dump is not repeated.
SHOW_DATA = False
@dataclass
class BartCollator:
    '''
    Turn raw text samples into padded model inputs for BART pre-training.

    The main logic lives in ``__call__``. Two corruption tasks are applied to
    the encoder input: sentence permutation and text infilling (token masking
    followed by merging runs of consecutive mask tokens into a single one).
    The uncorrupted sequence is used for the labels and decoder input.

    ``setup()`` must be called once before the collator is used; it creates
    the sentence splitter, the random state and the vocabulary lookups.
    '''
    tokenizer: object = None  # tokenizer; read for vocab and cls/sep/mask/pad ids
    max_seq_length: int = 512  # every returned tensor is padded to this length
    masked_lm_prob: float = 0.15  # masking is skipped when this is <= 0
    permute_sentence_ratio: float = 1.0  # permutation is skipped when this is <= 0
    content_key: str = 'text'  # key of the raw text inside each sample dict

    def __post_init__(self):
        # Values may arrive as strings (e.g. from a command line parser);
        # normalise them once so later comparisons and arithmetic are safe.
        self.max_seq_length = int(self.max_seq_length)
        self.masked_lm_prob = float(self.masked_lm_prob)
        self.permute_sentence_ratio = float(self.permute_sentence_ratio)

    def setup(self):
        '''Create the helpers ``__call__`` relies on (splitter, RNG, vocab tables).'''
        from fengshen.data.data_utils.sentence_split import ChineseSentenceSplitter
        self.sentence_split = ChineseSentenceSplitter()
        # Seeded from the wall clock, so masking differs from run to run.
        self.np_rng = np.random.RandomState(seed=((int(time.time()) % 2**32)))
        inv_vocab = {v: k for k, v in self.tokenizer.vocab.items()}
        self.vocab_id_list = list(inv_vocab.keys())
        self.vocab_id_to_token_dict = inv_vocab
        import jieba_fast
        self.zh_tokenizer = jieba_fast.lcut
        # Sentence-ending punctuation in full-width and half-width form.
        # The full-width semicolon / exclamation mark / question mark are
        # written as escapes so they survive Unicode normalisation of the file.
        seg_tokens = ['。', '\uff1b', ';', '\uff01', '!', '\uff1f', '?']
        seg_token_ids = []
        for t in seg_tokens:
            if t in self.tokenizer.vocab:
                seg_token_ids.append(self.tokenizer.vocab[t])
            else:
                print('seg_token "{}" not in vocab'.format(t))
        self.seg_token_ids = set(seg_token_ids)

    def permute_sentences(self, source, full_stops, p=1.0):
        '''
        Return a copy of ``source`` with a share ``p`` of its sentences shuffled.

        source: 1-D tensor of token ids; position 0 is never moved.
        full_stops: boolean tensor of the same length, True on sentence-ending tokens.
        p: fraction of sentences taking part in the shuffle.

        Tokens after the last detected sentence end keep their position.
        '''
        # Tokens that are full stops, where the previous token is not
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        # Plain Python ints make the slicing below independent of tensor indexing rules.
        boundaries = sentence_ends.flatten().tolist()
        result = source.clone()

        num_sentences = len(boundaries)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        ordering = torch.arange(0, num_sentences)
        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]

        # Ignore <bos> at start
        index = 1
        for i in ordering.tolist():
            start = boundaries[i - 1] if i > 0 else 1
            sentence = source[start:boundaries[i]]
            result[index:index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result

    @staticmethod
    def _merge_consecutive_masks(token_ids, mask_id):
        '''Collapse each run of ``mask_id`` in ``token_ids`` into one mask token.'''
        merged = []
        for tok in token_ids:
            tok = int(tok)
            # Compare by value: the ids may be numpy integers, for which an
            # identity check against a Python int is never reliable.
            if merged and tok == mask_id and merged[-1] == mask_id:
                continue
            merged.append(tok)
        return merged

    def _pad(self, values, fill):
        '''Right-pad the 1-D tensor ``values`` with ``fill`` up to ``max_seq_length``.'''
        padded = torch.full((self.max_seq_length,), fill, dtype=torch.long)
        padded[:values.shape[0]] = values
        return padded

    def __call__(self, samples):
        '''
        samples: list of dicts, each looking like {"text": "hello world"}
                 (the text is read from ``content_key``).

        Returns the ``default_collate`` of one dict per non-empty sample with
        the keys ``input_ids``, ``labels``, ``decoder_input_ids`` and
        ``attention_mask``. Samples that yield no sentence are skipped.
        '''
        cls_id = self.tokenizer.cls_token_id
        sep_id = self.tokenizer.sep_token_id
        pad_id = self.tokenizer.pad_token_id

        model_inputs = []
        for s in samples:
            sentences = self.sentence_split.tokenize(s[self.content_key])
            tokenized_sentences = [self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(sent)) for sent in sentences]
            if not tokenized_sentences:
                print('find empty sentence')
                continue

            tokens = [cls_id]
            for sent in tokenized_sentences:
                tokens.extend(sent)
            if tokens[-1] != sep_id:
                tokens.append(sep_id)

            if len(tokens) > self.max_seq_length:
                # Prefer cutting right after the last sentence-ending token so
                # that the final sentence stays complete whenever possible.
                last_pos = self.max_seq_length - 1
                for i in range(self.max_seq_length - 1, 0, -1):
                    if tokens[i - 1] in self.seg_token_ids:
                        last_pos = i
                        break
                tokens = tokens[:last_pos]
                tokens.append(sep_id)
            tokens = torch.LongTensor(tokens)

            # True wherever the token is one of the sentence-ending ids; this
            # stays all-False when none of them exists in the vocabulary.
            full_stops = torch.zeros_like(tokens, dtype=torch.bool)
            for seg_id in self.seg_token_ids:
                full_stops |= torch.eq(tokens, seg_id)

            assert (self.max_seq_length -
                    tokens.shape[0]) >= 0, (tokens.size(), tokens[-1], self.max_seq_length)

            source, target = tokens, tokens.clone()

            if self.permute_sentence_ratio > 0.0:
                source = self.permute_sentences(source, full_stops, self.permute_sentence_ratio)

            if self.masked_lm_prob > 0.0:
                mask_id = self.tokenizer.mask_token_id
                mask_prob = self.masked_lm_prob * 2
                max_predictions_per_seq = mask_prob * len(source)
                (source, _, _, _, _) = create_masked_lm_predictions(
                    source.numpy(), self.vocab_id_list, self.vocab_id_to_token_dict, mask_prob,
                    self.tokenizer.cls_token_id, self.tokenizer.sep_token_id, self.tokenizer.mask_token_id,
                    max_predictions_per_seq, self.np_rng,
                    masking_style='bert', zh_tokenizer=self.zh_tokenizer)
                # The BERT masking helper is used here; merging consecutive
                # masks into a single MASK gives the effect of span masking.
                source = torch.LongTensor(self._merge_consecutive_masks(source, mask_id))

            assert (source >= 0).all()
            # assert (source[1:-1] >= 1).all(), source
            assert (source <= self.tokenizer.vocab_size).all()
            assert source[0] == cls_id
            assert source[-1] == sep_id

            prev_output_tokens = torch.zeros_like(target)
            # match the preprocessing in fairseq
            prev_output_tokens[0] = sep_id
            prev_output_tokens[1:] = target[:-1]

            attention_mask = torch.full((self.max_seq_length,), 0, dtype=torch.long)
            attention_mask[:source.shape[0]] = 1

            model_inputs.append({
                "input_ids": self._pad(source, pad_id),
                # -100 marks the padded label positions.
                "labels": self._pad(target, -100),
                "decoder_input_ids": self._pad(prev_output_tokens, pad_id),
                "attention_mask": attention_mask,
            })
        return default_collate(model_inputs)
class RandengBart(LightningModule):
    """Lightning module that trains a ``BartForConditionalGeneration`` built from a config.

    The model is created from ``BartConfig.from_pretrained(args.model_path)``;
    only the configuration is read there, the network itself is constructed
    from that config.
    """

    @staticmethod
    def add_module_specific_args(parent_parser):
        """Register the Randeng BART command line options on ``parent_parser`` and return it."""
        parser = parent_parser.add_argument_group('Randeng BART')
        parser.add_argument('--masked_lm_prob', type=float, default=0.15)
        parser.add_argument('--max_seq_length', type=int, default=512)
        parser.add_argument('--sample_content_key', type=str, default='text')
        # Parsed as float: the value is compared against 0.0 and used in
        # arithmetic, which a string from the command line cannot do.
        parser.add_argument('--permute_sentence_ratio', type=float, default=1.0)
        return parent_parser

    def __init__(self, args, tokenizer, **kwargs) -> None:
        super().__init__()
        self.save_hyperparameters(args)
        config = BartConfig.from_pretrained(args.model_path)
        self.model = BartForConditionalGeneration(config)
        self.tokenizer = tokenizer

    def setup(self, stage) -> None:
        if stage == 'fit':
            self.total_steps = get_total_steps(self.trainer, self.hparams)

    def configure_optimizers(self):
        return configure_optimizers(self)

    def detokenize(self, token_ids):
        """Convert token ids back to a string with the tokenizer."""
        toks = self.tokenizer.convert_ids_to_tokens(token_ids)
        return self.tokenizer.convert_tokens_to_string(toks)

    def training_step(self, batch, batch_idx):
        """Run one training batch, log ``train_loss`` / ``train_acc`` and return the loss."""
        if self.trainer.global_rank == 0:
            global SHOW_DATA
            if not SHOW_DATA:
                # Dump the first sample of the first batch once, for inspection.
                SHOW_DATA = True
                print('source: {}'.format(batch['input_ids'][0]))
                print('target: {}'.format(batch['labels'][0]))
                print('decoder source: {}'.format(batch['decoder_input_ids'][0]))
                print('source: {}'.format(self.detokenize(batch['input_ids'][0])))
                print('decoder source: {}'.format(self.detokenize(batch['decoder_input_ids'][0])))
                # -100 entries are not real token ids, so drop them before decoding.
                label_idx = batch['labels'][0] != -100
                print('target: {}'.format(self.detokenize(
                    batch['labels'][0][label_idx])))
        output = self.model(**batch)
        acc = self.comput_metrix(output.logits, batch['labels'])
        self.log('train_loss', output.loss, sync_dist=True)
        self.log('train_acc', acc, sync_dist=True)
        return output.loss

    def comput_metrix(self, logits, labels):
        """Return the token accuracy over positions whose label is not -100.

        logits: model scores with the vocabulary on the last dimension.
        labels: target ids with the same leading shape; -100 marks ignored positions.

        Returns a scalar tensor; 0 when no position is supervised.
        """
        label_idx = labels != -100
        y_true = labels[label_idx].view(size=(-1,))
        if y_true.shape[0] == 0:
            # Avoid the 0/0 division, which would log NaN.
            return torch.zeros((), device=logits.device)
        scores = logits[label_idx].view(-1, logits.size(-1))
        y_pred = torch.argmax(scores, dim=-1).view(size=(-1,))
        # Both sides are integer ids, so they are compared directly.
        corr = torch.eq(y_pred, y_true)
        return torch.sum(corr.float()) / y_true.shape[0]

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and log ``val_loss`` / ``val_acc``."""
        output = self.model(**batch)
        acc = self.comput_metrix(output.logits, batch['labels'])
        self.log('val_loss', output.loss, sync_dist=True)
        self.log('val_acc', acc, sync_dist=True)

    def on_load_checkpoint(self, checkpoint) -> None:
        # Compatibility with older Lightning versions, which reset the step
        # count to 0 when training is resumed from a checkpoint.
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset
if __name__ == '__main__':
    # Each component adds its own options to one shared parser, in this order.
    args_parser = argparse.ArgumentParser()
    for extend_parser in (
        add_module_args,
        UniversalDataModule.add_data_specific_args,
        Trainer.add_argparse_args,
        RandengBart.add_module_specific_args,
        UniversalCheckpoint.add_argparse_args,
    ):
        args_parser = extend_parser(args_parser)
    args = args_parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    collator = BartCollator(
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        masked_lm_prob=args.masked_lm_prob,
        content_key=args.sample_content_key,
        permute_sentence_ratio=args.permute_sentence_ratio,
    )
    # Prepare the collator's extra helpers before it is handed to the data module.
    collator.setup()

    data_module = UniversalDataModule(tokenizer=tokenizer, args=args, collate_fn=collator)
    module = RandengBart(args, tokenizer=tokenizer)
    callbacks = [
        LearningRateMonitor(logging_interval='step'),
        UniversalCheckpoint(args),
    ]

    # Compatibility: if the checkpoint path does not exist, drop the argument,
    # otherwise resuming would raise an error.
    ckpt_path = args.load_ckpt_path
    if ckpt_path is not None and not os.path.exists(ckpt_path):
        print('--------warning no checkpoint found--------, remove args')
        args.load_ckpt_path = None

    trainer = Trainer.from_argparse_args(args, callbacks=callbacks)
    trainer.fit(module, data_module, ckpt_path=args.load_ckpt_path)
