import argparse
import gc
import os
import sys
from pathlib import Path
from typing import List

import pytorch_lightning as pl
import torch
from finetune import SummarizationModule, TranslationModule
from finetune import main as ft_main
from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
from torch import nn

from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
from transformers.models.bart.modeling_bart import shift_tokens_right
from utils import calculate_bleu, check_output_dir, freeze_params, label_smoothed_nll_loss, use_task_specific_params

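# lightning_base lives outside this directory; extend sys.path so the import below resolves.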
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import generic_train


class SummarizationDistiller(SummarizationModule):
    """Supports T5, Bart, Pegasus and other models that inherit from Bart."""

    loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
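    # The order above matches the tuple returned by _step(); the base module logs each value under its name.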

    def __init__(self, hparams):
        assert Path(hparams.data_dir).exists()
        self.output_dir = Path(hparams.output_dir)
        self.output_dir.mkdir(exist_ok=True)

        save_dir = self.output_dir.joinpath("student")

        hparams.model_name_or_path = str(save_dir)
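        # Build the teacher (kept in eval mode) and the student: either an explicit --student checkpoint,
        # or a copy of alternating teacher layers saved under output_dir/student.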
        teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
        use_task_specific_params(teacher, hparams.task)
        if hparams.student is not None:
            student = AutoModelForSeq2SeqLM.from_pretrained(hparams.student)
            use_task_specific_params(student, hparams.task)
            e_layer_ids, d_layer_ids = None, None
        else:
            student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
                teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
            )

        if hparams.length_penalty != -1:
            student.config.length_penalty = hparams.length_penalty
        hparams.tokenizer_name = hparams.teacher
        super().__init__(hparams, model=student, config=student.config)
        assert student.config.model_type == teacher.config.model_type, (
            f"teacher and student model types should be the same, got {student.config.model_type} !="
            f" {teacher.config.model_type}"
        )

        if student.config.model_type == "t5":
            student_encoder_layers = len(student.get_encoder().block)
            student_decoder_layers = len(student.get_decoder().block)
            teacher_encoder_layers = len(teacher.get_encoder().block)
            teacher_decoder_layers = len(teacher.get_decoder().block)
        else:
            student_encoder_layers = student.config.encoder_layers
            student_decoder_layers = student.config.decoder_layers
            teacher_encoder_layers = teacher.config.encoder_layers
            teacher_decoder_layers = teacher.config.decoder_layers
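        # Hidden-state supervision is only meaningful when the student was derived from this teacher
        # (same base model) and --alpha_hid is positive; otherwise only the CE and LM losses are used.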
        self.different_base_models = not (hparams.student is None or hparams.teacher == hparams.student)
        self.do_calc_hidden_loss = (not self.different_base_models) and hparams.alpha_hid > 0
        self.different_encoder = self.different_base_models or (student_encoder_layers != teacher_encoder_layers)

        self.teacher = teacher
        freeze_params(self.teacher)
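        # If the student encoder is an exact copy of the teacher's, _step() feeds the student's encoder output
        # to the teacher decoder, so the teacher's own encoder is never used: delete it to save memory.
        # BART-like models expose it as .model.encoder, T5 as .encoder.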
        if not self.different_encoder:
            try:
                del self.teacher.model.encoder
            except AttributeError:
                del self.teacher.encoder

        if e_layer_ids is None:
            e_layer_ids = list(range(student_encoder_layers))
        if d_layer_ids is None:
            d_layer_ids = list(range(student_decoder_layers))

        self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids
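        # For the hidden-state loss, pick which teacher layer supervises each student layer:
        # get_layers_to_supervise() with --supervise_forward, otherwise match layers by index.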
        if self.do_calc_hidden_loss:
            if hparams.supervise_forward:
                self.e_matches = get_layers_to_supervise(
                    n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers
                )
                self.d_matches = get_layers_to_supervise(
                    n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers
                )
            else:
                self.e_matches = self.e_layer_ids
                self.d_matches = self.d_layer_ids
        else:
            self.e_matches = None
            self.d_matches = None
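        # Soft-target loss: KL divergence between temperature-softened student and teacher token distributions.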
        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        self.temperature = 2.0
        self.alpha_mlm = hparams.alpha_mlm
        self.alpha_ce = hparams.alpha_ce
        self.alpha_hid = hparams.alpha_hid
        gc.collect()
        torch.cuda.empty_cache()

    def calc_ce_loss(self, mask, s_logits, t_logits):
        """Copied from the DistilBERT example (transformers/examples/distillation/)."""
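        # Keep only the positions allowed by the mask, then compute the KL divergence between the
        # temperature-softened teacher and student distributions, scaled by temperature**2.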
        sel_mask = mask[:, :, None].expand_as(s_logits)
        vocab_size = s_logits.size(-1)
        s_logits_slct = torch.masked_select(s_logits, sel_mask)
        t_logits_slct = torch.masked_select(t_logits, sel_mask)
        s_logits_slct = s_logits_slct.view(-1, vocab_size)
        t_logits_slct = t_logits_slct.view(-1, vocab_size)
        assert t_logits_slct.size() == s_logits_slct.size()
        loss_ce = (
            self.ce_loss_fct(
                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )
        return loss_ce

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        SummarizationModule.add_model_specific_args(parser, root_dir)
        add_distill_args(parser)
        return parser

    def _step(self, batch: dict) -> tuple:
        """Compute the loss for a batch."""
        pad_token_id = self.tokenizer.pad_token_id
        input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, pad_token_id)
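        # Student forward pass; hidden states are only requested when the hidden-state loss is active.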
        student_outputs = self(
            input_ids,
            attention_mask=src_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=self.do_calc_hidden_loss,
            output_attentions=False,
            use_cache=False,
        )
        lm_logits = student_outputs["logits"]
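        # Hard-label loss against the ground-truth targets: plain cross-entropy, or label-smoothed NLL.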
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        if self.hparams.label_smoothing == 0:
            loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
            student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
        else:
            lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
            student_lm_loss, _ = label_smoothed_nll_loss(
                lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
            )

        def zero_tensor():
            return torch.tensor(0.0).type_as(student_lm_loss)
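        # By default the teacher decoder consumes the student's encoder output; a separate teacher encoder
        # pass only replaces it when the two models come from different base checkpoints.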
        teacher_enc_outputs = student_outputs["encoder_last_hidden_state"]
        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
        if self.different_encoder:
            all_teacher_encoder_outputs = self.teacher.get_encoder()(
                input_ids,
                attention_mask=src_mask,
                output_hidden_states=self.do_calc_hidden_loss,
            )
            if self.different_base_models:
                teacher_enc_outputs = all_teacher_encoder_outputs["last_hidden_state"]
            elif self.do_calc_hidden_loss:
                hid_loss_enc = self.calc_hidden_loss(
                    src_mask,
                    student_outputs["encoder_hidden_states"],
                    all_teacher_encoder_outputs["hidden_states"],
                    self.e_matches,
                    normalize_hidden=self.hparams.normalize_hidden,
                )
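        # Teacher forward pass (its parameters are frozen) to obtain the soft-target logits and,
        # when needed, the decoder hidden states for intermediate supervision.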
        teacher_outputs = self.teacher(
            input_ids,
            attention_mask=src_mask,
            encoder_outputs=(teacher_enc_outputs,),
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=self.do_calc_hidden_loss,
            use_cache=False,
        )
        dec_mask = decoder_input_ids.ne(pad_token_id)
        loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs["logits"])
        if self.do_calc_hidden_loss:
            hid_loss_dec = self.calc_hidden_loss(
                dec_mask,
                student_outputs["decoder_hidden_states"],
                teacher_outputs["decoder_hidden_states"],
                self.d_matches,
                normalize_hidden=self.hparams.normalize_hidden,
            )
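        # Weighted sum of the loss terms; the unweighted components are returned as well so they can be logged.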
        blended_loss = (
            self.alpha_ce * loss_ce
            + self.alpha_mlm * student_lm_loss
            + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
        )
        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec

    @staticmethod
    def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
        """MSE(student_hid, teacher_hid[matches]); "intermediate supervision" in the paper, inspired by TinyBERT."""
        msg = "expected list or tuple for hidden_states, got tensor of shape: "
        assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
        assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
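        # Stack the supervised student layers against their matched teacher layers, then average the MSE
        # over non-padding positions only (the attention mask is broadcast over layers and hidden size).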
        mask = attention_mask.to(hidden_states[0])
        valid_count = mask.sum() * hidden_states[0].size(-1)
        student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
        teacher_states = torch.stack([hidden_states_T[j] for j in matches])
        assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
        if normalize_hidden:
            student_states = nn.functional.layer_norm(student_states, student_states.shape[1:])
            teacher_states = nn.functional.layer_norm(teacher_states, teacher_states.shape[1:])
        mse = nn.functional.mse_loss(student_states, teacher_states, reduction="none")
        masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
        return masked_mse


def add_distill_args(parser):
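    # Distillation-specific CLI flags: the teacher (and optional student) checkpoints, the weights of the
    # three loss terms (--alpha_ce, --alpha_mlm, --alpha_hid), the size of a copied student, and options
    # controlling the generation length penalty and hidden-state supervision.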
    parser.add_argument("--teacher", type=str)
    parser.add_argument("--alpha_ce", default=0.8, type=float)
    parser.add_argument("--alpha_mlm", default=0.2, type=float)
    parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
    parser.add_argument("--student", type=str, required=False)
    parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
    parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
    parser.add_argument("--no_teacher", action="store_true", default=False)
    parser.add_argument("--length_penalty", type=float, default=-1)
    parser.add_argument("--supervise_forward", action="store_true", default=False)
    parser.add_argument("--normalize_hidden", action="store_true", default=False)


class TranslationDistiller(SummarizationDistiller):
    """Supports T5, mBART, Marian, and other models that inherit from Bart."""

    mode = "translation"
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        assert hparams.src_lang is not None
        assert hparams.tgt_lang is not None
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
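        # For mBART, decoding starts from the target-language code, so record it when the config
        # does not define a decoder_start_token_id of its own.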
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        TranslationModule.add_model_specific_args(parser, root_dir)
        add_distill_args(parser)
        return parser


def create_module(args):
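    # Distillation needs a teacher; with --no_teacher we fall back to the plain fine-tuning modules.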
    if args.no_teacher:
        module_cls = TranslationModule if "translation" in args.task else SummarizationModule
    else:
        module_cls = TranslationDistiller if "translation" in args.task else SummarizationDistiller
    args.setup_cls: str = module_cls.__name__
    print(f"using module {args.setup_cls}")
    model = module_cls(args)
    return model


def distill_main(args):
    Path(args.output_dir).mkdir(exist_ok=True)
    check_output_dir(args, expected_items=3)

    model = create_module(args)
    return ft_main(args, model=model)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    distill_main(args)