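"""Train a lightweight re-ranking adapter over beam-search QA predictions.

For each SQuAD example, `LUKE_pipe.generate` returns a list of candidate answer
spans together with their logits. A BERT sequence classifier scores each
(question, candidate) pair, and a single linear layer combines that score with
the span logit to rank the candidates. The training label is the index of the
candidate matching the ground-truth answer; examples with no matching candidate
are ignored by the loss.
"""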
import torch.nn as nn
import torch
from transformers import AutoTokenizer, BertForSequenceClassification, PreTrainedModel, PretrainedConfig, get_scheduler
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from LUKE_pipe import generate
from datasets import load_dataset
from accelerate import Accelerator
from tqdm import tqdm

MAX_BEAM = 10  # maximum beam size; also used as the "no match" ignore label
tf32 = True
# Enable TF32 matmuls/convolutions on GPUs that support it for faster training.
torch.backends.cuda.matmul.allow_tf32 = tf32
torch.backends.cudnn.allow_tf32 = tf32
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class ClassifierAdapter(nn.Module):
    def __init__(self, l1=3):
        super().__init__()
        # Three input features per candidate: the generator's span logit plus
        # BERT's two sequence-classification logits.
        self.linear1 = nn.Linear(l1, 1)
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.bert = BertForSequenceClassification.from_pretrained("botcon/right_span_bert")
        self.relu = nn.ReLU()

    def forward(self, questions, answers, logits):
        beam_size = len(answers[0])
        samples = len(questions)
        # Repeat each question once per beam candidate so the flattened
        # (question, answer) pairs stay grouped by sample, matching the
        # reshape to (samples, beam_size, 2) below.
        questions = [question for question in questions for _ in range(beam_size)]
        answers = [answer for beam in answers for answer in beam]
        inputs = self.tokenizer(
            questions,
            answers,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).to(device)
        bert_logits = self.bert(**inputs).logits
        bert_logits = bert_logits.reshape(samples, beam_size, 2)
        # Combine the span logit with BERT's two logits into a 3-feature vector
        # per candidate, then score each candidate with the linear layer.
        logits = torch.FloatTensor(logits).to(device).unsqueeze(-1)
        logits = torch.cat((logits, bert_logits), dim=-1)
        logits = self.relu(logits)
        out = torch.squeeze(self.linear1(logits), dim=-1)
        return out

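# PreTrainedModel wrapper so the adapter supports from_pretrained / push_to_hub
# and returns a SequenceClassifierOutput carrying the loss.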
class HuggingWrapper(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ClassifierAdapter()

    def forward(self, **kwargs):
        labels = kwargs.pop("labels")
        output = self.model(**kwargs)
        # Labels equal to MAX_BEAM (ground truth missing from the beam) are ignored.
        loss_fn = CrossEntropyLoss(ignore_index=MAX_BEAM)
        loss = loss_fn(output, labels)
        return SequenceClassifierOutput(logits=output, loss=loss)
    
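# Training setup: fp16 mixed precision via Accelerate, AdamW, and a linear LR
# schedule over two epochs of SQuAD with a manual batch size of 2.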
accelerator = Accelerator(mixed_precision="fp16")
model = HuggingWrapper.from_pretrained("botcon/special_bert").to(device)
optimizer = AdamW(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)
batch_size = 2
raw_datasets = load_dataset("squad")
raw_train = raw_datasets["train"]
num_updates = len(raw_train) // batch_size
num_epoch = 2
num_training_steps = num_updates * num_epoch
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

progress_bar = tqdm(range(num_training_steps))

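# Training loop: slice SQuAD into batches by hand, generate beam candidates
# under no_grad, label the candidate that matches the ground-truth answer,
# and train the adapter to pick it.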
for epoch in range(num_epoch):
    start = 0
    end = batch_size
    steps = 0
    cumu_loss = 0
    training_data = raw_train
    model.train()
    while start < len(training_data):
        optimizer.zero_grad()
        batch_data = raw_train.select(range(start, min(end, len(raw_train))))
        with torch.no_grad():
            # Beam candidates (texts + logits) come from the external QA
            # generator in LUKE_pipe; it is not updated by this training loop.
            res = generate(batch_data)
            prediction = []
            predicted_logit = []
            labels = []
            for i in range(len(res)):
                x = res[i]
                ground_answer = batch_data["answers"][i]["text"][0]
                predicted_text = x["prediction_text"]
                # The label is the index of the candidate that exactly matches
                # the ground-truth answer; MAX_BEAM marks "no match".
                found = False
                for k in range(len(predicted_text)):
                    if predicted_text[k] == ground_answer:
                        labels.append(k)
                        found = True
                        break
                if not found:
                    labels.append(MAX_BEAM)
                prediction.append(predicted_text)
                predicted_logit.append(x["logits"])
        labels = torch.LongTensor(labels).to(device)
        classifier_out = model(questions=batch_data["question"], answers=prediction, logits=predicted_logit, labels=labels)
        loss = classifier_out.loss
        # The loss is NaN when every label in the batch is the ignore index
        # (no beam candidate matched the ground truth), so skip the update.
        if not torch.isnan(loss).item():
            cumu_loss += loss.item()
            steps += 1
            accelerator.backward(loss)
            optimizer.step()
        lr_scheduler.step()
        progress_bar.update(1)
        start += batch_size
        end += batch_size
        # Report the average loss every 100 optimizer steps.
        if steps > 0 and steps % 100 == 0:
            print("Cumu loss: {}".format(cumu_loss / 100))
            cumu_loss = 0

# Hub repository names cannot contain spaces.
model.push_to_hub("Adapter-Bert")