import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from transformers import (
    AutoTokenizer,
    BertForSequenceClassification,
    PreTrainedModel,
    PretrainedConfig,
    get_scheduler,
)
from transformers.modeling_outputs import SequenceClassifierOutput
from datasets import load_dataset
from accelerate import Accelerator
from tqdm import tqdm

from LUKE_pipe import generate

MAX_BEAM = 10
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class ClassifierAdapter(nn.Module):
    """Scores each beam candidate by combining the generator's beam logit with the
    logits of a BERT sequence classifier, then projecting to a single score."""

    def __init__(self, l1=3):
        super().__init__()
        self.linear1 = nn.Linear(l1, 1)
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.bert = BertForSequenceClassification.from_pretrained("botcon/right_span_bert")
        self.relu = nn.ReLU()

    def forward(self, questions, answers, logits):
        beam_size = len(answers[0])
        samples = len(questions)
        # Repeat each question once per beam candidate (sample-major order) so that
        # question i is paired with every candidate answer from beam i.
        questions = [question for question in questions for _ in range(beam_size)]
        answers = [answer for beam in answers for answer in beam]
        inputs = self.tokenizer(
            questions,
            answers,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device)
        bert_logits = self.bert(**inputs).logits
        bert_logits = bert_logits.reshape(samples, beam_size, 2)
        # Concatenate the generator's beam logit with BERT's two classification logits,
        # giving l1 = 3 features per candidate.
        logits = torch.FloatTensor(logits).to(device).unsqueeze(-1)
        logits = torch.cat((logits, bert_logits), dim=-1)
        logits = self.relu(logits)
        out = torch.squeeze(self.linear1(logits), dim=-1)
        return out


class HuggingWrapper(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ClassifierAdapter()

    def forward(self, **kwargs):
        labels = kwargs.pop("labels")
        output = self.model(**kwargs)
        # Samples whose beam does not contain the gold answer are labelled MAX_BEAM
        # and excluded from the loss.
        loss_fn = CrossEntropyLoss(ignore_index=MAX_BEAM)
        loss = loss_fn(output, labels)
        return SequenceClassifierOutput(logits=output, loss=loss)


accelerator = Accelerator(mixed_precision="fp16")
model = HuggingWrapper.from_pretrained("botcon/special_bert").to(device)
optimizer = AdamW(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)

batch_size = 2
raw_datasets = load_dataset("squad")
raw_train = raw_datasets["train"]
num_updates = len(raw_train) // batch_size
num_epoch = 2
num_training_steps = num_updates * num_epoch
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_epoch):
    start = 0
    end = batch_size
    steps = 0
    cumu_loss = 0
    training_data = raw_train
    model.train()
    while start < len(training_data):
        optimizer.zero_grad()
        batch_data = raw_train.select(range(start, min(end, len(raw_train))))
        # Beam candidates come from the generator, which is not being trained here.
        with torch.no_grad():
            res = generate(batch_data)
        prediction = []
        predicted_logit = []
        labels = []
        for i in range(len(res)):
            x = res[i]
            ground_answer = batch_data["answers"][i]["text"][0]
            predicted_text = x["prediction_text"]
            # The label is the index of the beam candidate matching the gold answer,
            # or MAX_BEAM (ignored by the loss) when no candidate matches.
            found = False
            for k in range(len(predicted_text)):
                if predicted_text[k] == ground_answer:
                    labels.append(k)
                    found = True
                    break
            if not found:
                labels.append(MAX_BEAM)
            prediction.append(predicted_text)
            predicted_logit.append(x["logits"])
        labels = torch.LongTensor(labels).to(device)
        classifier_out = model(
            questions=batch_data["question"],
            answers=prediction,
            logits=predicted_logit,
            labels=labels,
        )
        loss = classifier_out.loss
        # Skip the update when the loss is NaN (e.g. every label in the batch is MAX_BEAM).
        if not torch.isnan(loss).item():
            cumu_loss += loss.item()
            steps += 1
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
        progress_bar.update(1)
        start += batch_size
        end += batch_size
        # every 100 steps
        if steps % 100 == 0:
print("Cumu loss: {}".format(cumu_loss / 100)) cumu_loss = 0 model.push_to_hub("some_fake_bert")